# -*- coding: utf-8 -*-
"""reaction_wheels.py
@brief  Tests for the reaction wheel handler
@author J. Meier
@date   20.06.2021
"""
import enum
import struct
from typing import List, Tuple

from spacepackets.ecss.tc import PusTelecommand
from tmtccmd.config import TmtcDefinitionWrapper, OpCodeEntry
from tmtccmd.config.tmtc import tmtc_definitions_provider
from tmtccmd.fsfw.tmtc_printer import FsfwTmTcPrinter
from tmtccmd.pus.s200_fsfw_mode import pack_mode_data, Mode, Subservice
from tmtccmd.pus.s8_fsfw_action import create_action_cmd
from tmtccmd.pus.tc.s3_fsfw_hk import (
    generate_one_hk_command,
    generate_one_diag_command,
    make_sid,
    enable_periodic_hk_command_with_interval,
    disable_periodic_hk_command,
)
from tmtccmd.tmtc import DefaultPusQueueHelper
from tmtccmd.util import ObjectIdU32

from eive_tmtc.config.definitions import CustomServiceList
from eive_tmtc.config.object_ids import RW1_ID, RW2_ID, RW3_ID, RW4_ID
from eive_tmtc.pus_tm.defs import PrintWrapper


class OpCodesDev:
    """Op code keys for a single reaction wheel device service.

    The same keys are registered for all four reaction wheel services in
    :func:`add_rw_cmds` and dispatched in :func:`create_single_rw_cmd`.
    """

    SPEED = "speed"
    ON = "on"
    NML = "nml"
    OFF = "off"
    GET_STATUS = "status"
    GET_TM = "get_tm_set"
    REQ_TM = "req_tm_set"
    ENABLE_STATUS_HK = "enable_status_hk"
    DISABLE_STATUS_HK = "disable_status_hk"


class InfoDev:
    """Human readable descriptions of the :class:`OpCodesDev` op codes.

    Used as op code info in :func:`add_rw_cmds` and as log text in
    :func:`create_single_rw_cmd`.
    """

    SPEED = "Set speed"
    ON = "Set On"
    NML = "Set Normal"
    OFF = "Set Off"
    GET_STATUS = "Get Status HK"
    GET_TM = "Get TM HK"
    REQ_TM = "Request TM HK"
    ENABLE_STATUS_HK = "Enable Status HK"
    DISABLE_STATUS_HK = "Disable Status HK"


class OpCodesAss:
    """Op code keys for the reaction wheel assembly service.

    Every op code can be selected by its number or by its name; both are listed
    and :func:`pack_rw_ass_cmds` matches them with list membership.
    """

    ON = ["0", "on"]
    NML = ["1", "nml"]
    OFF = ["2", "off"]
    # The speed op codes command every reaction wheel (RW1 to RW4) one after another.
    ALL_SPEED_UP = ["3", "speed_up"]
    ALL_SPEED_OFF = ["4", "speed_off"]


class ActionId:
    """Action IDs sent to the reaction wheel handler."""

    # Sent for the REQ_TM op code in create_single_rw_cmd.
    REQUEST_TM = 9


class InfoAss:
    """Human readable descriptions of the :class:`OpCodesAss` op codes.

    Registered as op code info in :func:`add_rw_cmds`.
    """

    ON = "Mode On: 3/4 RWs min. on"
    NML = "Mode Normal: 3/4 RWs min. normal"
    OFF = "Mode Off: All RWs off"
    ALL_SPEED_UP = "Speed up consecutively"
    ALL_SPEED_OFF = "Speed down to 0"


class RwSetId(enum.IntEnum):
    """HK set IDs of the reaction wheel handler."""

    # Status set: requested as diagnostic HK and used for periodic status HK.
    STATUS_SET_ID = 4
    # Temperature set, filled by the GET_TEMPERATURE command.
    TEMPERATURE_SET_ID = 8
    LAST_RESET = 2
    TM_SET = 9


class RwCommandId:
    """Command IDs of the reaction wheel handler.

    Every ID is a four byte ``bytearray`` holding the ID value in big-endian order
    (for example ``SET_SPEED`` is ``00 00 00 06``).
    """

    RESET_MCU = bytearray(b"\x00\x00\x00\x01")
    # Reads status information from the reaction wheel into the dataset with ID 4
    GET_RW_STATUS = bytearray(b"\x00\x00\x00\x04")
    INIT_RW_CONTROLLER = bytearray(b"\x00\x00\x00\x05")
    SET_SPEED = bytearray(b"\x00\x00\x00\x06")
    # Reads the temperature from the reaction wheel into the dataset with ID 8
    GET_TEMPERATURE = bytearray(b"\x00\x00\x00\x08")
    GET_TM = bytearray(b"\x00\x00\x00\x09")


class SpeedDefinitions:
    """Predefined reaction wheel speeds.

    Presumably in units of 0.1 RPM, like the ``speed`` argument of
    :func:`pack_set_speed_command` (``RPM_100 = 1000`` fits that unit).
    """

    # 1000 * 0.1 RPM = 100 RPM
    RPM_100 = 1000
    # NOTE(review): 5000 * 0.1 RPM is 500 RPM, which does not match the name.
    # Either the name or the value is off (5000 RPM would be 50000) - confirm
    # before relying on this constant.
    RPM_5000 = 5000


class RampTime:
    """Predefined ramp times for reaction wheel speed commands."""

    # 1000 ms, presumably for the ramp_time_ms argument of pack_set_speed_command.
    MS_1000 = 1000


@tmtc_definitions_provider
def add_rw_cmds(defs: TmtcDefinitionWrapper):
    """Register the four reaction wheel services and the reaction wheel assembly.

    All four wheel services share one and the same op code table; the assembly
    service gets its own table.

    :param defs: Definition wrapper the services are added to.
    """
    device_op_codes = OpCodeEntry()
    for info, key in (
        (InfoDev.SPEED, OpCodesDev.SPEED),
        (InfoDev.ON, OpCodesDev.ON),
        (InfoDev.OFF, OpCodesDev.OFF),
        (InfoDev.NML, OpCodesDev.NML),
        (InfoDev.REQ_TM, OpCodesDev.REQ_TM),
        (InfoDev.GET_STATUS, OpCodesDev.GET_STATUS),
        (InfoDev.GET_TM, OpCodesDev.GET_TM),
        (InfoDev.ENABLE_STATUS_HK, OpCodesDev.ENABLE_STATUS_HK),
        (InfoDev.DISABLE_STATUS_HK, OpCodesDev.DISABLE_STATUS_HK),
    ):
        device_op_codes.add(info=info, keys=key)

    for wheel_service, wheel_info in (
        (CustomServiceList.REACTION_WHEEL_1, "Reaction Wheel 1"),
        (CustomServiceList.REACTION_WHEEL_2, "Reaction Wheel 2"),
        (CustomServiceList.REACTION_WHEEL_3, "Reaction Wheel 3"),
        (CustomServiceList.REACTION_WHEEL_4, "Reaction Wheel 4"),
    ):
        defs.add_service(
            name=wheel_service.value,
            info=wheel_info,
            op_code_entry=device_op_codes,
        )

    assembly_op_codes = OpCodeEntry()
    for info, keys in (
        (InfoAss.ON, OpCodesAss.ON),
        (InfoAss.NML, OpCodesAss.NML),
        (InfoAss.OFF, OpCodesAss.OFF),
        (InfoAss.ALL_SPEED_UP, OpCodesAss.ALL_SPEED_UP),
        (InfoAss.ALL_SPEED_OFF, OpCodesAss.ALL_SPEED_OFF),
    ):
        assembly_op_codes.add(info=info, keys=keys)
    defs.add_service(
        name=CustomServiceList.RW_ASSEMBLY.value,
        info="Reaction Wheel Assembly",
        op_code_entry=assembly_op_codes,
    )


def create_single_rw_cmd(  # noqa C901: Complexity is okay here.
    object_id: bytes, rw_idx: int, q: DefaultPusQueueHelper, cmd_str: str
):
    """Queue the log entry and telecommand(s) for one op code of a single reaction wheel.

    :param object_id: Object ID of the reaction wheel handler.
    :param rw_idx: Reaction wheel number, only used in the log messages.
    :param q: Queue helper receiving the log entries and telecommands.
    :param cmd_str: One of the :class:`OpCodesDev` values. Unknown op codes queue
        nothing.
    :raises ValueError: if an interactively entered value is not a number, or (for
        the speed op code) is outside the range accepted by pack_set_speed_command.
    """
    # Mode op codes: log text and target mode of the mode command.
    mode_cmds = {
        OpCodesDev.ON: (InfoDev.ON, Mode.ON),
        OpCodesDev.NML: (InfoDev.NML, Mode.NORMAL),
        OpCodesDev.OFF: (InfoDev.OFF, Mode.OFF),
    }
    # Op codes are matched exactly. A substring test (``cmd_str in "..."``) would
    # let "status" also match "enable_status_hk" and "disable_status_hk", so a plain
    # status request would additionally queue HK enable/disable commands.
    if cmd_str == OpCodesDev.SPEED:
        speed, ramp_time = prompt_speed_ramp_time()
        q.add_log_cmd(
            f"RW {rw_idx}: {InfoDev.SPEED} with target "
            f"speed {speed / 10.0} RPM and {ramp_time} ms ramp time"
        )
        q.add_pus_tc(pack_set_speed_command(object_id, speed, ramp_time))
    elif cmd_str in mode_cmds:
        info, mode = mode_cmds[cmd_str]
        q.add_log_cmd(f"RW {rw_idx}: {info}")
        mode_data = pack_mode_data(object_id, mode, 0)
        q.add_pus_tc(
            PusTelecommand(
                service=200, subservice=Subservice.TC_MODE_COMMAND, app_data=mode_data
            )
        )
    elif cmd_str == OpCodesDev.GET_TM:
        q.add_log_cmd(f"RW {rw_idx}: {InfoDev.GET_TM}")
        q.add_pus_tc(
            generate_one_hk_command(
                sid=make_sid(object_id=object_id, set_id=RwSetId.TM_SET)
            )
        )
    elif cmd_str == OpCodesDev.REQ_TM:
        q.add_log_cmd(f"RW {rw_idx}: {InfoDev.REQ_TM}")
        q.add_pus_tc(
            create_action_cmd(object_id=object_id, action_id=ActionId.REQUEST_TM)
        )
    elif cmd_str == OpCodesDev.GET_STATUS:
        q.add_log_cmd(f"RW {rw_idx}: {InfoDev.GET_STATUS}")
        q.add_pus_tc(
            generate_one_diag_command(
                sid=make_sid(object_id=object_id, set_id=RwSetId.STATUS_SET_ID)
            )
        )
    elif cmd_str == OpCodesDev.ENABLE_STATUS_HK:
        q.add_log_cmd(f"RW {rw_idx}: {InfoDev.ENABLE_STATUS_HK}")
        interval = float(input("Please enter HK interval in floating point seconds: "))
        cmds = enable_periodic_hk_command_with_interval(
            True, make_sid(object_id, RwSetId.STATUS_SET_ID), interval
        )
        for cmd in cmds:
            q.add_pus_tc(cmd)
    elif cmd_str == OpCodesDev.DISABLE_STATUS_HK:
        q.add_log_cmd(f"RW {rw_idx}: {InfoDev.DISABLE_STATUS_HK}")
        q.add_pus_tc(
            disable_periodic_hk_command(
                True, make_sid(object_id, RwSetId.STATUS_SET_ID)
            )
        )


def pack_rw_ass_cmds(q: DefaultPusQueueHelper, object_id: bytes, cmd_str: str):
    """Queue the telecommand(s) for one reaction wheel assembly op code.

    Mode op codes send a mode command to the assembly itself. Speed op codes prompt
    for the parameters and send one speed command per reaction wheel (RW1 to RW4).

    :param q: Queue helper receiving the telecommands.
    :param object_id: Object ID of the reaction wheel assembly.
    :param cmd_str: Any key listed in :class:`OpCodesAss`. Unknown op codes queue
        nothing.
    """
    mode_by_op_code = (
        (OpCodesAss.OFF, Mode.OFF),
        (OpCodesAss.ON, Mode.ON),
        (OpCodesAss.NML, Mode.NORMAL),
    )
    for op_code_keys, target_mode in mode_by_op_code:
        if cmd_str not in op_code_keys:
            continue
        mode_data = pack_mode_data(object_id=object_id, mode=target_mode, submode=0)
        q.add_pus_tc(
            PusTelecommand(
                service=200,
                subservice=Subservice.TC_MODE_COMMAND,
                app_data=mode_data,
            )
        )

    all_wheels = [RW1_ID, RW2_ID, RW3_ID, RW4_ID]
    if cmd_str in OpCodesAss.ALL_SPEED_UP:
        target_speed, ramp_ms = prompt_speed_ramp_time()
        rw_speed_up_cmd_consec(q, all_wheels, target_speed, ramp_ms)
    if cmd_str in OpCodesAss.ALL_SPEED_OFF:
        rw_speed_down_cmd_consec(q, all_wheels, prompt_ramp_time())


def prompt_speed_ramp_time() -> Tuple[int, int]:
    """Interactively ask for a target speed and a ramp time.

    No range check happens here. Within this module the values are passed on to
    :func:`pack_set_speed_command`, which validates them.

    :return: ``(speed, ramp_time)`` with the speed in 0.1 RPM and the ramp time in ms.
    :raises ValueError: if an entry is not an integer.
    """
    # The return annotation used to be ``(int, int)``, which is a tuple object and
    # not a valid type expression; ``Tuple[int, int]`` is what type checkers expect.
    speed = int(
        input("Specify speed [0.1 RPM, 0 or range [-65000, -1000] and [1000, 65000]: ")
    )
    return speed, prompt_ramp_time()


def prompt_ramp_time() -> int:
    """Interactively ask for a ramp time in milliseconds.

    :raises ValueError: if the entry is not an integer.
    """
    entered = input("Specify ramp time [ms, range [10, 20000]]: ")
    return int(entered)


def pack_set_speed_command(
    object_id: bytes, speed: int, ramp_time_ms: int
) -> PusTelecommand:
    """Build the telecommand which sets the target speed of one reaction wheel.

    The application data is the object ID, the SET_SPEED command ID, the speed as a
    big-endian signed 32-bit integer and the ramp time as a big-endian unsigned
    16-bit integer. It is sent as a PUS service 8, subservice 128 telecommand.

    :param object_id: The object ID of the reaction wheel handler.
    :param speed: Target speed in 0.1 RPM. Valid values are 0, [-65000, -1000] and
        [1000, 65000].
    :param ramp_time_ms: Time in ms after which the wheel should reach the target
        speed. Valid values are 10 - 20000 ms; 0 is accepted as well.
    :raises ValueError: if the speed or the ramp time is outside the valid ranges.
    :return: The packed telecommand.
    """
    if speed > 0:
        if not 1000 <= speed <= 65000:
            raise ValueError(
                "Invalid RW speed specified. Allowed range is [1000, 65000] 0.1 * RPM"
            )
    elif speed < 0 and not -65000 <= speed <= -1000:
        raise ValueError(
            "Invalid RW speed specified. Allowed range is [-65000, -1000] 0.1 * RPM"
        )

    if ramp_time_ms < 0 or (ramp_time_ms > 0 and not 10 <= ramp_time_ms <= 20000):
        raise ValueError("Invalid Ramp Speed time. Allowed range is [10-20000] ms")

    app_data = bytearray(object_id + RwCommandId.SET_SPEED)
    app_data += struct.pack("!i", speed)
    app_data += ramp_time_ms.to_bytes(length=2, byteorder="big")
    return PusTelecommand(service=8, subservice=128, app_data=app_data)


def handle_rw_hk_data(
    pw: PrintWrapper, object_id: ObjectIdU32, set_id: int, hk_data: bytes
):
    """Print the content of a reaction wheel housekeeping packet.

    Dispatches on ``set_id`` to one printer per supported set (status, last reset,
    TM). Other set IDs are reported with a single log line.

    :param pw: Print wrapper used for all output.
    :param object_id: Object ID of the reaction wheel handler which sent the packet.
    :param set_id: HK set ID, compared against :class:`RwSetId`.
    :param hk_data: Raw HK payload: the big-endian packed set variables; for the
        status and TM sets the validity buffer follows them.
    :raises struct.error: if ``hk_data`` is shorter than the layout of the set.
    """
    if set_id == RwSetId.STATUS_SET_ID:
        _print_rw_status_set(pw, object_id, set_id, hk_data)
    elif set_id == RwSetId.LAST_RESET:
        _print_rw_last_reset_set(pw, object_id, set_id, hk_data)
    elif set_id == RwSetId.TM_SET:
        _print_rw_tm_set(pw, object_id, set_id, hk_data)
    else:
        pw.dlog(
            f"Received unknown HK set ID {set_id} from Reaction Wheel {object_id.name}"
        )


def _print_rw_status_set(
    pw: PrintWrapper, object_id: ObjectIdU32, set_id: int, hk_data: bytes
):
    """Print the status set: temperature, speeds, state and current limit mode."""
    pw.dlog(
        f"Received Status HK (ID {set_id}) from Reaction Wheel {object_id.name}"
    )
    # The temperature is decoded as signed ("i"). Decoded as unsigned ("I"), a
    # temperature below 0 C would print as a value close to 2**32. Both are 4 bytes.
    fmt_str = "!iiiBB"
    inc_len = struct.calcsize(fmt_str)
    (temp, speed, ref_speed, state, clc_mode) = struct.unpack(
        fmt_str, hk_data[:inc_len]
    )
    # Speeds are transmitted in 0.1 RPM.
    speed_rpm = speed / 10.0
    ref_speed_rpm = ref_speed / 10.0
    pw.dlog(
        f"Temperature {temp} C | Speed {speed_rpm} rpm | Reference Speed"
        f" {ref_speed_rpm} rpm"
    )
    pw.dlog(
        f"State {state}. 0: Error, 1: Idle, 2: Coasting, 3: Running, speed stable, "
        "4: Running, speed  changing"
    )
    pw.dlog(
        f"Current Limit Control mode {clc_mode}. 0: Low Current Mode (0.3 A), "
        "1: High Current Mode (0.6 A)"
    )
    pw.dlog(FsfwTmTcPrinter.get_validity_buffer(hk_data[inc_len:], 5))


def _print_rw_last_reset_set(
    pw: PrintWrapper, object_id: ObjectIdU32, set_id: int, hk_data: bytes
):
    """Print the last reset set: cached and current reset status."""
    pw.dlog(
        f"Received Last Reset HK (ID {set_id}) from Reaction Wheel {object_id.name}"
    )
    fmt_str = "!BB"
    inc_len = struct.calcsize(fmt_str)
    (last_not_cleared_reset_status, current_reset_status) = struct.unpack(
        fmt_str, hk_data[:inc_len]
    )
    pw.dlog(
        f"Last Non-Cleared (Cached) Reset Status {last_not_cleared_reset_status} | "
        f"Current Reset Status {current_reset_status}"
    )


def _print_rw_tm_set(
    pw: PrintWrapper, object_id: ObjectIdU32, set_id: int, hk_data: bytes
):
    """Print the TM set: temperatures, pressure, speeds, packet and UART/SPI counters."""
    pw.dlog(f"Received TM HK (ID {set_id}) from Reaction Wheel {object_id.name}")
    # Eight mixed fields followed by 16 unsigned 32-bit counters (24 variables).
    fmt_str = "!BiffBBii16I"
    inc_len = struct.calcsize(fmt_str)
    fields = struct.unpack(fmt_str, hk_data[:inc_len])
    (
        last_reset_status,
        mcu_temp,
        pressure_sens_temp,
        pressure,
        state,
        clc_mode,
        current_speed,
        ref_speed,
        num_invalid_crc_packets,
        num_invalid_len_packets,
        num_invalid_cmd_packets,
        num_of_cmd_executed_requests,
        num_of_cmd_replies,
        uart_num_of_bytes_written,
        uart_num_of_bytes_read,
        uart_num_parity_errors,
        uart_num_noise_errors,
        uart_num_frame_errors,
        uart_num_reg_overrun_errors,
        uart_total_num_errors,
        spi_num_bytes_written,
        spi_num_bytes_read,
        spi_num_reg_overrun_errors,
        spi_total_num_errors,
    ) = fields

    pw.dlog(
        f"MCU Temperature {mcu_temp} | Pressure Sensor Temperature"
        f" {pressure_sens_temp} C"
    )
    # Raw pressure value; its unit is not defined in this module.
    pw.dlog(f"Pressure {pressure}")
    pw.dlog(f"Last Reset Status {last_reset_status}")
    pw.dlog(
        f"Current Limit Control mode {clc_mode}. 0: Low Current Mode (0.3 A), "
        "1: High Current Mode (0.6 A)"
    )
    # Speeds are transmitted in 0.1 RPM, as in the status set.
    pw.dlog(
        f"Speed {current_speed / 10.0} rpm | Reference Speed {ref_speed / 10.0} rpm"
    )
    pw.dlog(
        f"State {state}. 0: Error, 1: Idle, 2: Coasting, 3: Running, speed stable, "
        "4: Running, speed  changing"
    )
    pw.dlog("Number Of Invalid Packets:")
    pw.dlog("CRC | Length | CMD")
    pw.dlog(
        f"{num_invalid_crc_packets} | {num_invalid_len_packets} |"
        f" {num_invalid_cmd_packets}"
    )
    pw.dlog(
        f"Num Of CMD Executed Requests {num_of_cmd_executed_requests} | "
        f"Num of CMD Replies {num_of_cmd_replies}"
    )
    pw.dlog("UART COM information:")
    pw.dlog(
        "NumBytesWritten | NumBytesRead | ParityErrs | NoiseErrs | FrameErrs | "
        "RegOverrunErrs | TotalErrs"
    )
    pw.dlog(
        f"{uart_num_of_bytes_written} | {uart_num_of_bytes_read} |"
        f" {uart_num_parity_errors} | {uart_num_noise_errors} |"
        f" {uart_num_frame_errors} | {uart_num_reg_overrun_errors} |"
        f" {uart_total_num_errors}"
    )
    pw.dlog("SPI COM Info:")
    pw.dlog("NumBytesWritten | NumBytesRead | RegOverrunErrs | TotalErrs")
    pw.dlog(
        f"{spi_num_bytes_written} | {spi_num_bytes_read} |"
        f" {spi_num_reg_overrun_errors} | {spi_total_num_errors}"
    )
    # One validity entry per set variable unpacked above.
    # NOTE(review): assumes the validity buffer directly follows these variables,
    # as it does for the status set - confirm against the handler's TM set.
    pw.dlog(
        FsfwTmTcPrinter.get_validity_buffer(
            validity_buffer=hk_data[inc_len:], num_vars=len(fields)
        )
    )


def rw_speed_up_cmd_consec(
    q: DefaultPusQueueHelper, obids: List[bytes], speed: int, ramp_time: int
):
    """Queue one set-speed telecommand per reaction wheel, in the order of ``obids``.

    :param q: Queue helper receiving the telecommands.
    :param obids: Object IDs of the reaction wheel handlers.
    :param speed: Target speed in 0.1 RPM, applied to every wheel.
    :param ramp_time: Ramp time in ms, applied to every wheel.
    :raises ValueError: from pack_set_speed_command for out-of-range values.
    """
    for wheel_id in obids:
        speed_tc = pack_set_speed_command(
            object_id=wheel_id, speed=speed, ramp_time_ms=ramp_time
        )
        q.add_pus_tc(speed_tc)


def rw_speed_down_cmd_consec(
    q: DefaultPusQueueHelper, obids: List[bytes], ramp_time: int
):
    """Queue one telecommand per reaction wheel that ramps its speed down to 0.

    This is a speed-up sequence with a target speed of 0, so it delegates to
    :func:`rw_speed_up_cmd_consec`.

    :param q: Queue helper receiving the telecommands.
    :param obids: Object IDs of the reaction wheel handlers, commanded in order.
    :param ramp_time: Ramp time in ms, applied to every wheel.
    :raises ValueError: from pack_set_speed_command for an out-of-range ramp time.
    """
    rw_speed_up_cmd_consec(q, obids, 0, ramp_time)