import datetime
import enum
import logging
import struct

from tmtccmd.config.tmtc import (
    CmdTreeNode,
)
from tmtccmd.fsfw.tmtc_printer import FsfwTmTcPrinter
from tmtccmd.pus.s20_fsfw_param import create_load_param_cmd
from tmtccmd.pus.s20_fsfw_param_defs import (
    create_scalar_double_parameter,
    create_scalar_float_parameter,
)
from tmtccmd.pus.s200_fsfw_mode import Mode, pack_mode_command
from tmtccmd.pus.tc.s3_fsfw_hk import (
    disable_periodic_hk_command,
    enable_periodic_hk_command_with_interval,
    generate_one_hk_command,
    make_sid,
)
from tmtccmd.tmtc.queue import DefaultPusQueueHelper

from eive_tmtc.config.object_ids import PWR_CONTROLLER
from eive_tmtc.pus_tm.defs import PrintWrapper

# Logger for this module; handed to ``PrintWrapper.ilog`` when PWR CTRL HK data arrives.
_LOGGER = logging.getLogger(__name__)


class SetId(enum.IntEnum):
    """HK set IDs of the power controller, combined with ``PWR_CONTROLLER`` via ``make_sid``."""

    # Core set: total battery current plus the OCV and coulomb-counter charge values.
    CORE_HK_SET = 0
    # Set carrying the single "PL use allowed" value.
    ENABLE_PL_SET = 1


# Placeholder: no ActionId(enum.IntEnum) is defined for this controller yet.


class ParamId(enum.IntEnum):
    """Unique IDs of the power controller parameters that can be loaded interactively.

    All of them are loaded with domain ID 0. The units below are the ones shown
    in the value prompt when setting the parameter.
    """

    BATTERY_INTERNAL_RESISTANCE = 0  # float, [Ohm]
    BATTERY_MAXIMUM_CAPACITY = 1  # float, [Ah]
    COULOMB_COUNTER_VOLTAGE_UPPER_THRESHOLD = 2  # float, [V]
    MAX_ALLOWED_TIME_DIFF = 3  # double (the only non-float parameter), [s]
    PAYLOAD_OP_LIMIT_ON = 4  # float, prompt unit [1]
    PAYLOAD_OP_LIMIT_LOW = 5  # float, prompt unit [1]
    HIGHER_MODES_LIMIT = 6  # float, prompt unit [1]


class OpCode:
    """Command strings understood by ``pack_power_ctrl_command``.

    Every public attribute becomes a child of the ``pwr_ctrl`` command tree node.
    Its description is looked up as the ``Info`` attribute with the *same name*,
    so ``Info`` must define every attribute name defined here.
    """

    # Mode commands.
    OFF = "off"
    ON = "on"
    NML = "normal"
    # Interactive parameter load (prompts for ID and value).
    SET_PARAMETER = "set_parameter"
    # Core HK set: one-shot request, periodic enable (prompts for interval), disable.
    REQUEST_CORE_HK = "core_hk"
    ENABLE_CORE_HK = "core_enable_hk"
    DISABLE_CORE_HK = "core_disable_hk"
    # Enable-PL HK set: one-shot request, periodic enable (prompts for interval), disable.
    REQUEST_ENABLE_PL_HK = "enable_pl_hk"
    ENABLE_ENABLE_PL_HK = "enable_pl_enable_hk"
    DISABLE_ENABLE_PL_HK = "enable_pl_disable_hk"


class Info:
    """Human-readable descriptions for the ``OpCode`` commands.

    Attribute names mirror ``OpCode`` one-to-one. The strings are used both as
    command tree help text and as queue log entries.
    """

    OFF = "PWR Ctrl Mode to OFF"
    ON = "PWR Ctrl Mode to ON"
    NML = "PWR Ctrl Mode to NORMAL"
    SET_PARAMETER = "Set Parameter"
    REQUEST_CORE_HK = "Request Core HK once"
    ENABLE_CORE_HK = "Enable Core HK Data Generation"
    DISABLE_CORE_HK = "Disable Core HK Data Generation"
    REQUEST_ENABLE_PL_HK = "Request Enable PL HK once"
    ENABLE_ENABLE_PL_HK = "Enable Enable PL HK Data Generation"
    DISABLE_ENABLE_PL_HK = "Disable Enable PL HK Data Generation"


def create_pwr_ctrl_node() -> CmdTreeNode:
    """Build the ``pwr_ctrl`` command tree node.

    One child is added per public ``OpCode`` attribute, in the alphabetical
    attribute-name order produced by ``dir()``. Its help text is the ``Info``
    attribute of the same name.
    """
    # Key both lookups on the attribute *name* so each op code gets its own
    # description regardless of how the two classes are laid out.
    descriptions = {
        getattr(OpCode, attr_name): getattr(Info, attr_name)
        for attr_name in dir(OpCode)
        if not attr_name.startswith("__")
    }
    root = CmdTreeNode("pwr_ctrl", "Power Controller", hide_children_for_print=True)
    for cmd_name, help_text in descriptions.items():
        root.add_child(CmdTreeNode(cmd_name, help_text))
    return root


def pack_power_ctrl_command(q: DefaultPusQueueHelper, cmd_str: str):
    """Queue the log entry and PUS telecommand(s) for one power controller command.

    :param q: Queue helper receiving the log entry and telecommand(s).
    :param cmd_str: One of the ``OpCode`` strings. ``set_parameter`` and the two
        "enable HK" commands prompt on stdin for additional input.

    Unknown command strings queue nothing and emit a warning.
    """
    if cmd_str == OpCode.OFF:
        q.add_log_cmd(Info.OFF)
        q.add_pus_tc(pack_mode_command(PWR_CONTROLLER, Mode.OFF, 0))
    elif cmd_str == OpCode.ON:
        q.add_log_cmd(Info.ON)
        q.add_pus_tc(pack_mode_command(PWR_CONTROLLER, Mode.ON, 0))
    elif cmd_str == OpCode.NML:
        q.add_log_cmd(Info.NML)
        q.add_pus_tc(pack_mode_command(PWR_CONTROLLER, Mode.NORMAL, 0))
    elif cmd_str == OpCode.SET_PARAMETER:
        # Exact match: an ``in`` test here would be a substring check, letting
        # strings like "" or "param" open the interactive parameter prompt.
        q.add_log_cmd(Info.SET_PARAMETER)
        set_pwr_ctrl_param(q)
    elif cmd_str == OpCode.REQUEST_CORE_HK:
        q.add_log_cmd(Info.REQUEST_CORE_HK)
        q.add_pus_tc(
            generate_one_hk_command(make_sid(PWR_CONTROLLER, SetId.CORE_HK_SET))
        )
    elif cmd_str == OpCode.ENABLE_CORE_HK:
        interval = float(input("Please specify interval in floating point seconds: "))
        q.add_log_cmd(Info.ENABLE_CORE_HK)
        cmd_tuple = enable_periodic_hk_command_with_interval(
            False, make_sid(PWR_CONTROLLER, SetId.CORE_HK_SET), interval
        )
        # The helper returns two telecommands; both must be queued.
        q.add_pus_tc(cmd_tuple[0])
        q.add_pus_tc(cmd_tuple[1])
    elif cmd_str == OpCode.DISABLE_CORE_HK:
        q.add_log_cmd(Info.DISABLE_CORE_HK)
        q.add_pus_tc(
            disable_periodic_hk_command(
                False, make_sid(PWR_CONTROLLER, SetId.CORE_HK_SET)
            )
        )
    elif cmd_str == OpCode.REQUEST_ENABLE_PL_HK:
        q.add_log_cmd(Info.REQUEST_ENABLE_PL_HK)
        q.add_pus_tc(
            generate_one_hk_command(make_sid(PWR_CONTROLLER, SetId.ENABLE_PL_SET))
        )
    elif cmd_str == OpCode.ENABLE_ENABLE_PL_HK:
        interval = float(input("Please specify interval in floating point seconds: "))
        q.add_log_cmd(Info.ENABLE_ENABLE_PL_HK)
        cmd_tuple = enable_periodic_hk_command_with_interval(
            False, make_sid(PWR_CONTROLLER, SetId.ENABLE_PL_SET), interval
        )
        # The helper returns two telecommands; both must be queued.
        q.add_pus_tc(cmd_tuple[0])
        q.add_pus_tc(cmd_tuple[1])
    elif cmd_str == OpCode.DISABLE_ENABLE_PL_HK:
        q.add_log_cmd(Info.DISABLE_ENABLE_PL_HK)
        q.add_pus_tc(
            disable_periodic_hk_command(
                False, make_sid(PWR_CONTROLLER, SetId.ENABLE_PL_SET)
            )
        )
    else:
        _LOGGER.warning(f"Unknown power controller command: {cmd_str!r}")


def set_pwr_ctrl_param(q: DefaultPusQueueHelper):
    """Interactively choose a power controller parameter and queue a load command.

    All ``ParamId`` entries are listed on stdout. The user enters a parameter
    number and then a value, and a service 20 load-parameter telecommand
    (domain ID 0, object ``PWR_CONTROLLER``) is appended to ``q``. An unknown
    parameter number queues nothing and emits a warning.

    :param q: Queue helper the load-parameter telecommand is appended to.
    :raises ValueError: if the entered parameter number or value is not numeric.
    """
    # Per parameter: unit shown in the value prompt, and the scalar encoder used
    # for the value. MAX_ALLOWED_TIME_DIFF is the only parameter sent as a double.
    param_specs = {
        ParamId.BATTERY_INTERNAL_RESISTANCE: ("Ohm", create_scalar_float_parameter),
        ParamId.BATTERY_MAXIMUM_CAPACITY: ("Ah", create_scalar_float_parameter),
        ParamId.COULOMB_COUNTER_VOLTAGE_UPPER_THRESHOLD: (
            "V",
            create_scalar_float_parameter,
        ),
        ParamId.MAX_ALLOWED_TIME_DIFF: ("s", create_scalar_double_parameter),
        ParamId.PAYLOAD_OP_LIMIT_ON: ("1", create_scalar_float_parameter),
        ParamId.PAYLOAD_OP_LIMIT_LOW: ("1", create_scalar_float_parameter),
        ParamId.HIGHER_MODES_LIMIT: ("1", create_scalar_float_parameter),
    }
    for val in ParamId:
        print("{:<2}: {:<20}".format(val, val.name))
    param = int(input("Specify parameter to set \n"))
    try:
        param_id = ParamId(param)
    except ValueError:
        _LOGGER.warning(f"Invalid power controller parameter ID {param}")
        return
    unit, make_scalar_param = param_specs[param_id]
    value = float(input(f"Specify parameter value to set [{unit}]: "))
    q.add_pus_tc(
        create_load_param_cmd(
            make_scalar_param(
                object_id=PWR_CONTROLLER,
                domain_id=0,
                unique_id=param_id,
                parameter=value,
            )
        )
    )


def handle_pwr_ctrl_hk_data(
    pw: PrintWrapper,
    set_id: int,
    hk_data: bytes,
    packet_time: datetime.datetime,
):
    pw.ilog(_LOGGER, f"Received PWR CTRL HK with packet time {packet_time}")
    match set_id:
        case SetId.CORE_HK_SET:
            handle_core_hk_data(pw, hk_data)
        case SetId.ENABLE_PL_SET:
            handle_enable_pl_data(pw, hk_data)


def handle_core_hk_data(pw: PrintWrapper, hk_data: bytes):
    """Decode and print the power controller core HK set.

    Layout (network byte order): int16 total battery current [mA], float
    open-circuit-voltage charge, float coulomb-counter charge, then the validity
    buffer for these three variables. Both charge values are multiplied by 100
    and printed with a [%] unit.

    :param pw: Print wrapper used for all output.
    :param hk_data: Raw set payload.
    """
    pw.dlog("Received Core HK Set")
    # "!" means standard sizes, no padding: 2 + 4 + 4 = 10 bytes.
    core_set = struct.Struct("!hff")
    if len(hk_data) < core_set.size:
        pw.dlog("Received HK set too small")
        return
    (
        total_battery_current,
        open_circuit_voltage_charge,
        coulomb_counter_charge,
    ) = core_set.unpack_from(hk_data, 0)
    pw.dlog(f"Total Battery Current: {total_battery_current} [mA]")
    pw.dlog(f"Open Circuit Voltage Charge: {open_circuit_voltage_charge*100:8.3f} [%]")
    pw.dlog(f"Coulomb Counter Charge: {coulomb_counter_charge*100:8.3f} [%]")
    # get_validity_buffer only formats the validity info; it must be printed
    # explicitly or it is lost.
    pw.dlog(
        FsfwTmTcPrinter.get_validity_buffer(hk_data[core_set.size :], num_vars=3)
    )


def handle_enable_pl_data(pw: PrintWrapper, hk_data: bytes):
    """Decode and print the power controller "enable PL" HK set.

    Layout: one unsigned 8-bit "PL use allowed" value (format ``!B``), followed
    by the validity buffer for that single variable.

    :param pw: Print wrapper used for all output.
    :param hk_data: Raw set payload.
    """
    pw.dlog("Received Enable PL HK Set")
    # "!B" is a single unsigned byte (uint8).
    fmt_uint8 = "!B"
    inc_len_uint8 = struct.calcsize(fmt_uint8)
    if len(hk_data) < inc_len_uint8:
        pw.dlog("Received HK set too small")
        return
    current_idx = 0
    (pl_use_allowed,) = struct.unpack_from(fmt_uint8, hk_data, current_idx)
    current_idx += inc_len_uint8
    pw.dlog(f"PL Use Allowed: {pl_use_allowed}")
    # get_validity_buffer only formats the validity info; print it explicitly.
    pw.dlog(FsfwTmTcPrinter.get_validity_buffer(hk_data[current_idx:], num_vars=1))