import enum
import math
import struct
import time
from typing import Callable, Optional

from spacepackets.ecss.tc import PusTelecommand
from tmtccmd.config import TmTcDefWrapper
from tmtccmd.config.tmtc import OpCodeEntry
from tmtccmd.logging import get_console_logger
from tmtccmd.tc import QueueHelper
from tmtccmd.tc.pus_3_fsfw_hk import (
    make_sid,
    generate_one_diag_command,
)
from tmtccmd.tc.pus_11_tc_sched import (
    generate_enable_tc_sched_cmd,
    generate_time_tagged_cmd,
)
from tmtccmd.tc.pus_200_fsfw_modes import pack_mode_data, Modes, Subservices
from tmtccmd.tc.pus_20_params import (
    pack_scalar_double_param_app_data,
    pack_fsfw_load_param_cmd,
    pack_boolean_parameter_app_data,
)

from config.definitions import CustomServiceList
from config.object_ids import PL_PCDU_ID

# Module-wide console logger; used to warn the operator about invalid
# interactive input.
LOGGER = get_console_logger()


class OpCodes:
    """Op-code keys accepted by the PL PCDU service.

    Each attribute is a list of equivalent keys (a numeric shortcut and a
    mnemonic). A command is selected when the operator's op code is contained
    in the list. Every group is checked independently, so keys must be unique
    across all attributes; otherwise one op code triggers several commands.
    """

    SWITCH_HPA_ON_PROC = ["0", "proc-hpa"]
    SWITCH_ON = ["2", "on"]
    SWITCH_OFF = ["3", "off"]
    NORMAL_SSR = ["4", "nml-ssr"]
    NORMAL_DRO = ["5", "nml-dro"]
    NORMAL_X8 = ["6", "nml-x8"]
    NORMAL_TX = ["7", "nml-tx"]
    NORMAL_MPA = ["8", "nml-mpa"]
    NORMAL_HPA = ["9", "nml-hpa"]

    # Must not share "8" with NORMAL_MPA: that made op code "8" queue both an
    # MPA mode command and a one-shot HK request.
    REQ_OS_HK = ["1", "hk-os"]

    INJECT_SSR_TO_DRO_FAILURE = ["10", "inject-ssr-dro-fault"]
    INJECT_DRO_TO_X8_FAILURE = ["11", "inject-dro-x8-fault"]
    INJECT_X8_TO_TX_FAILURE = ["12", "inject-x8-tx-fault"]
    INJECT_TX_TO_MPA_FAILURE = ["13", "inject-tx-mpa-fault"]
    INJECT_MPA_TO_HPA_FAILURE = ["14", "inject-mpa-hpa-fault"]
    INJECT_ALL_ON_FAILURE = ["15", "inject-all-on-fault"]


class Info:
    """Human-readable texts used as op-code help and as queued log entries."""

    NORMAL = "PL PCDU ADC modules normal"
    SWITCH_ON = "Switching PL PCDU on"
    SWITCH_OFF = "Switching PL PCDU off"
    NORMAL_SSR = f"{NORMAL}, SSR on"
    # Previously missing the space after the comma, unlike its siblings.
    NORMAL_DRO = f"{NORMAL}, DRO on"
    NORMAL_X8 = f"{NORMAL}, X8 on"
    NORMAL_TX = f"{NORMAL}, TX on"
    NORMAL_MPA = f"{NORMAL}, MPA on"
    NORMAL_HPA = f"{NORMAL}, HPA on"
    REQ_OS_HK = "Request One Shot HK"


@enum.unique
class SetIds(enum.IntEnum):
    """Housekeeping set IDs combined with the PL PCDU object ID into a SID."""

    ADC = 0x00


@enum.unique
class NormalSubmodesMask(enum.IntEnum):
    """Bit positions inside the PL PCDU NORMAL-mode submode.

    Each member is a bit *index*, not a mask value. Members are listed in the
    order in which the stages are switched on; ``submode_mask_to_submode``
    turns a target stage into the bit mask sent with a mode command.
    """

    SOLID_STATE_RELAYS_ADC_ON = 0  # bit 0
    DRO_ON = 1  # bit 1
    X8_ON = 2  # bit 2
    TX_ON = 3  # bit 3
    MPA_ON = 4  # bit 4
    HPA_ON = 5  # bit 5


@enum.unique
class ParamIds(enum.IntEnum):
    """Unique IDs of PL PCDU parameters loadable with a PUS 20 command.

    IDs 0-16 are ``*_BOUND`` limits, 17-21 the ``*_WAIT_TIME`` delays between
    switch-on stages, and 30-35 the boolean ``INJECT_*`` failure flags.
    """

    # --- bounds --------------------------------------------------------
    NEG_V_LOWER_BOUND = 0
    NEG_V_UPPER_BOUND = 1
    DRO_U_LOWER_BOUND = 2
    DRO_U_UPPER_BOUND = 3
    DRO_I_UPPER_BOUND = 4
    X8_U_LOWER_BOUND = 5
    X8_U_UPPER_BOUND = 6
    X8_I_UPPER_BOUND = 7
    TX_U_LOWER_BOUND = 8
    TX_U_UPPER_BOUND = 9
    TX_I_UPPER_BOUND = 10
    MPA_U_LOWER_BOUND = 11
    MPA_U_UPPER_BOUND = 12
    MPA_I_UPPER_BOUND = 13
    HPA_U_LOWER_BOUND = 14
    HPA_U_UPPER_BOUND = 15
    HPA_I_UPPER_BOUND = 16

    # --- wait times between stages -------------------------------------
    SSR_TO_DRO_WAIT_TIME = 17
    DRO_TO_X8_WAIT_TIME = 18
    X8_TO_TX_WAIT_TIME = 19
    TX_TO_MPA_WAIT_TIME = 20
    MPA_TO_HPA_WAIT_TIME = 21

    # --- failure injection flags (IDs 22-29 unused) ---------------------
    INJECT_SSR_TO_DRO_FAILURE = 30
    INJECT_DRO_TO_X8_FAILURE = 31
    INJECT_X8_TO_TX_FAILURE = 32
    INJECT_TX_TO_MPA_FAILURE = 33
    INJECT_MPA_TO_HPA_FAILURE = 34
    INJECT_ALL_ON_FAILURE = 35


def add_pl_pcdu_cmds(defs: TmTcDefWrapper):
    """Register every PL PCDU op code and its help text with ``defs``.

    Every op-code group handled by ``pack_pl_pcdu_commands`` is listed here,
    including the scheduled HPA switch-on procedure.

    :param defs: Telecommand definition wrapper the PL PCDU service is added to.
    """
    entries = (
        (
            OpCodes.SWITCH_HPA_ON_PROC,
            "Switch on PL PCDU HPA with a time-tagged procedure",
        ),
        (OpCodes.SWITCH_ON, Info.SWITCH_ON),
        (OpCodes.SWITCH_OFF, Info.SWITCH_OFF),
        (OpCodes.NORMAL_SSR, Info.NORMAL_SSR),
        (OpCodes.NORMAL_DRO, Info.NORMAL_DRO),
        (OpCodes.NORMAL_X8, Info.NORMAL_X8),
        (OpCodes.NORMAL_TX, Info.NORMAL_TX),
        (OpCodes.NORMAL_MPA, Info.NORMAL_MPA),
        (OpCodes.NORMAL_HPA, Info.NORMAL_HPA),
        (OpCodes.REQ_OS_HK, Info.REQ_OS_HK),
        (
            OpCodes.INJECT_SSR_TO_DRO_FAILURE,
            "Inject failure in SSR to DRO transition",
        ),
        (
            OpCodes.INJECT_DRO_TO_X8_FAILURE,
            "Inject failure in DRO to X8 transition",
        ),
        (
            OpCodes.INJECT_X8_TO_TX_FAILURE,
            "Inject failure in X8 to TX transition",
        ),
        (
            OpCodes.INJECT_TX_TO_MPA_FAILURE,
            "Inject failure in TX to MPA transition",
        ),
        (
            OpCodes.INJECT_MPA_TO_HPA_FAILURE,
            "Inject failure in MPA to HPA transition",
        ),
        (OpCodes.INJECT_ALL_ON_FAILURE, "Inject failure in all on mode"),
    )
    oce = OpCodeEntry()
    for keys, info in entries:
        oce.add(keys=keys, info=info)
    defs.add_service(CustomServiceList.PL_PCDU.value, "PL PCDU", oce)


def pack_pl_pcdu_commands(q: QueueHelper, op_code: str):
    """Queue the log entries and telecommands selected by ``op_code``.

    Every op-code group in ``OpCodes`` is checked independently, so an op code
    listed in several groups queues several commands. All registered
    failure-injection op codes are handled, not only the "all on" one.

    :param q: Queue helper receiving log entries and telecommands.
    :param op_code: Op code entered by the operator (numeric or mnemonic key).
    """
    if op_code in OpCodes.SWITCH_ON:
        pack_pl_pcdu_mode_cmd(q=q, info=Info.SWITCH_ON, mode=Modes.ON, submode=0)
    if op_code in OpCodes.SWITCH_OFF:
        pack_pl_pcdu_mode_cmd(q=q, info=Info.SWITCH_OFF, mode=Modes.OFF, submode=0)

    # (op-code keys, log text, highest stage to be switched on)
    normal_modes = (
        (
            OpCodes.NORMAL_SSR,
            Info.NORMAL_SSR,
            NormalSubmodesMask.SOLID_STATE_RELAYS_ADC_ON,
        ),
        (OpCodes.NORMAL_DRO, Info.NORMAL_DRO, NormalSubmodesMask.DRO_ON),
        (OpCodes.NORMAL_X8, Info.NORMAL_X8, NormalSubmodesMask.X8_ON),
        (OpCodes.NORMAL_TX, Info.NORMAL_TX, NormalSubmodesMask.TX_ON),
        (OpCodes.NORMAL_MPA, Info.NORMAL_MPA, NormalSubmodesMask.MPA_ON),
        (OpCodes.NORMAL_HPA, Info.NORMAL_HPA, NormalSubmodesMask.HPA_ON),
    )
    for keys, info, highest_stage in normal_modes:
        if op_code in keys:
            pack_pl_pcdu_mode_cmd(
                q=q,
                info=info,
                mode=Modes.NORMAL,
                submode=submode_mask_to_submode(highest_stage),
            )

    if op_code in OpCodes.REQ_OS_HK:
        q.add_log_cmd(f"PL PCDU: {Info.REQ_OS_HK}")
        q.add_pus_tc(
            generate_one_diag_command(
                sid=make_sid(object_id=PL_PCDU_ID, set_id=SetIds.ADC)
            )
        )
    if op_code in OpCodes.SWITCH_HPA_ON_PROC:
        hpa_on_procedure(q)

    # (op-code keys, parameter to set, name used in the log entry)
    failure_injections = (
        (
            OpCodes.INJECT_SSR_TO_DRO_FAILURE,
            ParamIds.INJECT_SSR_TO_DRO_FAILURE,
            "SSR to DRO",
        ),
        (
            OpCodes.INJECT_DRO_TO_X8_FAILURE,
            ParamIds.INJECT_DRO_TO_X8_FAILURE,
            "DRO to X8",
        ),
        (
            OpCodes.INJECT_X8_TO_TX_FAILURE,
            ParamIds.INJECT_X8_TO_TX_FAILURE,
            "X8 to TX",
        ),
        (
            OpCodes.INJECT_TX_TO_MPA_FAILURE,
            ParamIds.INJECT_TX_TO_MPA_FAILURE,
            "TX to MPA",
        ),
        (
            OpCodes.INJECT_MPA_TO_HPA_FAILURE,
            ParamIds.INJECT_MPA_TO_HPA_FAILURE,
            "MPA to HPA",
        ),
        (
            OpCodes.INJECT_ALL_ON_FAILURE,
            ParamIds.INJECT_ALL_ON_FAILURE,
            "All On",
        ),
    )
    for keys, param_id, print_str in failure_injections:
        if op_code in keys:
            pack_failure_injection_cmd(q=q, param_id=param_id, print_str=print_str)


def _make_pl_pcdu_mode_tc(mode: Modes, submode: int) -> PusTelecommand:
    """Build (but do not queue) a PUS service 200 mode command for the PL PCDU."""
    return PusTelecommand(
        service=200,
        subservice=Subservices.TC_MODE_COMMAND,
        app_data=pack_mode_data(object_id=PL_PCDU_ID, mode=mode, submode=submode),
    )


def hpa_on_procedure(q: QueueHelper):
    """Queue a time-tagged command sequence that switches the PL PCDU up to HPA.

    The operator is asked for the DRO to X8 delay. If the prompt is cancelled,
    a default of 900 seconds is used.

    The procedure then enables the TC scheduler and inserts one time-tagged
    mode command per step:
    - PL PCDU on.
    - NORMAL mode with SSR, then DRO, X8, TX, MPA and HPA cumulatively on.

    Release timing:
    - The first command is released 10 s after the current time.
    - The X8 step follows the DRO step after the requested delay.
    - Every other step follows its predecessor after 5 s.

    :param q: Queue helper receiving the log entry and telecommands.
    """
    delay_dro_to_x8 = request_wait_time()
    if delay_dro_to_x8 is None:
        # Cancelling the prompt does not abort the procedure; a default is used.
        delay_dro_to_x8 = 900
    q.add_log_cmd(
        f"Starting procedure to switch on PL PCDU HPA with DRO to X8 "
        f"delay of {delay_dro_to_x8} seconds"
    )
    # (seconds after the previous step, highest stage switched on)
    normal_stages = (
        (5, NormalSubmodesMask.SOLID_STATE_RELAYS_ADC_ON),
        (5, NormalSubmodesMask.DRO_ON),
        (delay_dro_to_x8, NormalSubmodesMask.X8_ON),
        (5, NormalSubmodesMask.TX_ON),
        (5, NormalSubmodesMask.MPA_ON),
        (5, NormalSubmodesMask.HPA_ON),
    )
    steps = [(10, _make_pl_pcdu_mode_tc(mode=Modes.ON, submode=0))]
    steps.extend(
        (
            delay,
            _make_pl_pcdu_mode_tc(
                mode=Modes.NORMAL, submode=submode_mask_to_submode(stage)
            ),
        )
        for delay, stage in normal_stages
    )

    q.add_pus_tc(generate_enable_tc_sched_cmd())
    release_time = time.time()
    for delay, tc in steps:
        release_time += delay
        q.add_pus_tc(
            generate_time_tagged_cmd(
                # "!I" only accepts ints; time.time() and the operator delay
                # are floats, and packing a float raises struct.error.
                release_time=struct.pack("!I", int(release_time)),
                tc_to_insert=tc,
            )
        )


def request_wait_time(
    read_input: Callable[[str], str] = input,
) -> Optional[float]:
    """Ask the operator for a wait time in seconds until a valid reply is given.

    Accepted replies are ``x``/``X`` (cancel) or a finite number greater than
    zero; surrounding whitespace is ignored. Anything else logs a warning and
    re-prompts.

    :param read_input: Callable that shows the prompt and returns the reply.
        Defaults to the built-in ``input``; pass another callable for
        non-interactive use.
    :return: The wait time in seconds, or ``None`` if the operator cancelled.
    """
    while True:
        reply = read_input(
            "Please enter DRO to X8 wait time in seconds, x to cancel: "
        ).strip()
        if reply.lower() == "x":
            return None
        try:
            wait_time = float(reply)
        except ValueError:
            LOGGER.warning("Invalid input")
            continue
        # float() also parses "nan" and "inf"; neither is a usable delay
        # ("nan" would even slip past the <= 0 check).
        if not math.isfinite(wait_time) or wait_time <= 0:
            LOGGER.warning("Invalid input")
            continue
        return wait_time


def submode_mask_to_submode(on_tgt: NormalSubmodesMask) -> int:
    """Convert the highest stage to switch on into a NORMAL-mode submode mask.

    Stages are switched on cumulatively: the result has the bit of ``on_tgt``
    and every lower bit set. For example, ``SOLID_STATE_RELAYS_ADC_ON`` (bit 0)
    gives ``0b1``, ``DRO_ON`` (bit 1) gives ``0b11`` and ``HPA_ON`` (bit 5)
    gives ``0b111111``.

    :param on_tgt: Highest stage that should be switched on.
    :return: Submode bit mask with bits ``0..on_tgt`` set.
    :raises ValueError: If ``on_tgt`` is not a ``NormalSubmodesMask`` value
        (previously such input silently returned ``None``).
    """
    highest_bit = NormalSubmodesMask(on_tgt)
    return (1 << (highest_bit + 1)) - 1


def pack_wait_time_cmd(q: QueueHelper, param_id: int, print_str: str):
    """Prompt for a wait time and queue a parameter-load command setting it.

    Nothing is queued if the operator cancels the prompt.

    NOTE(review): ``request_wait_time`` always prompts for the "DRO to X8" wait
    time, regardless of ``print_str``. Confirm this is acceptable for other
    transitions.

    :param q: Queue helper receiving the log entry and telecommand.
    :param param_id: Unique ID of the PL PCDU wait-time parameter to set.
    :param print_str: Human-readable transition name used in log messages.
    """
    wait_time = request_wait_time()
    if wait_time is None:
        # Check before logging: the old code queued "wait time to None".
        LOGGER.info(f"Updating {print_str} wait time cancelled")
        return
    q.add_log_cmd(f"Updating {print_str} wait time to {wait_time}")
    param_data = pack_scalar_double_param_app_data(
        object_id=PL_PCDU_ID,
        domain_id=0,
        unique_id=param_id,
        parameter=wait_time,
    )
    q.add_pus_tc(pack_fsfw_load_param_cmd(app_data=param_data))


def pack_failure_injection_cmd(q: QueueHelper, param_id: int, print_str: str):
    """Queue a parameter-load command that sets a PL PCDU boolean flag to True.

    :param q: Queue helper receiving the log entry and telecommand.
    :param param_id: Unique ID of the PL PCDU parameter to set.
    :param print_str: Human-readable failure name used in the log entry.
    """
    q.add_log_cmd(f"Inserting {print_str} error")
    app_data = pack_boolean_parameter_app_data(
        object_id=PL_PCDU_ID,
        domain_id=0,
        unique_id=param_id,
        parameter=True,
    )
    load_cmd = pack_fsfw_load_param_cmd(app_data=app_data)
    q.add_pus_tc(load_cmd)


def pack_pl_pcdu_mode_cmd(q: QueueHelper, info: str, mode: Modes, submode: int):
    """Queue a log entry followed by a PUS service 200 PL PCDU mode command.

    :param q: Queue helper receiving the log entry and telecommand.
    :param info: Text queued as the log entry.
    :param mode: Target mode.
    :param submode: Target submode (e.g. a mask from ``submode_mask_to_submode``).
    """
    q.add_log_cmd(info)
    mode_cmd = PusTelecommand(
        service=200,
        subservice=Subservices.TC_MODE_COMMAND,
        app_data=pack_mode_data(object_id=PL_PCDU_ID, mode=mode, submode=submode),
    )
    q.add_pus_tc(mode_cmd)