import enum
import socket
import struct
from socket import AF_INET
from typing import Tuple, Optional

from config.definitions import CustomServiceList
from config.object_ids import ACS_CONTROLLER
from pus_tm.defs import PrintWrapper
from tmtccmd import get_console_logger
from tmtccmd.config.tmtc import (
    tmtc_definitions_provider,
    TmtcDefinitionWrapper,
    OpCodeEntry,
)
from tmtccmd.tc import service_provider
from tmtccmd.tc.decorator import ServiceProviderParams
from tmtccmd.tc.pus_3_fsfw_hk import (
    generate_one_hk_command,
    make_sid,
    enable_periodic_hk_command_with_interval,
    disable_periodic_hk_command,
)
from tmtccmd.util.tmtc_printer import FsfwTmTcPrinter

# Module-wide console logger obtained from tmtccmd; used below to report
# unknown op codes in pack_acs_ctrl_command.
LOGGER = get_console_logger()


class SetIds(enum.IntEnum):
    """Housekeeping set IDs of the ACS controller, combined with its object ID via make_sid."""

    # Set containing the MGM readings decoded by handle_acs_ctrl_mgm_data.
    MGM_SET = 0


class OpCodes:
    """Op code aliases accepted by the ACS controller service.

    Each entry is a list of equivalent keys (a numeric shortcut and a
    descriptive name); an op code matches an entry if it is contained in it.
    Keys must be unique across entries, otherwise the first matching branch
    in pack_acs_ctrl_command shadows the later ones.
    """

    REQUEST_MGM_HK = ["0", "req-mgm-hk"]
    ENABLE_MGM_HK = ["1", "enable-mgm-hk"]
    # Previously also "1", which made the numeric shortcut unreachable for
    # disabling because ENABLE_MGM_HK is checked first.
    DISABLE_MGM_HK = ["2", "disable-mgm-hk"]


class Info:
    """Human-readable descriptions for each ACS controller op code.

    Used both as help text when registering the op codes and as the log
    entry queued before the corresponding telecommands.
    """

    # One-shot housekeeping request.
    REQUEST_MGM_HK = "Request MGM HK once"
    # Periodic housekeeping on / off.
    ENABLE_MGM_HK = "Enable MGM HK data generation"
    DISABLE_MGM_HK = "Disable MGM HK data generation"


# Set to True to forward MGM readings to the Helmholtz testbench calibration
# server over TCP (see perform_mgm_calibration).
PERFORM_MGM_CALIBRATION = False
CALIBRATION_SOCKET_HOST = "localhost"
CALIBRATION_SOCKET_PORT = 6677
CALIBRATION_ADDR = (CALIBRATION_SOCKET_HOST, CALIBRATION_SOCKET_PORT)

# Connected at import time only when calibration is enabled; otherwise it is
# None, so the name always exists at module level.
CALIBR_SOCKET: Optional[socket.socket] = None
if PERFORM_MGM_CALIBRATION:
    CALIBR_SOCKET = socket.socket(AF_INET, socket.SOCK_STREAM)
    # settimeout() overrides any earlier setblocking(False) (which is just
    # settimeout(0.0)), so only the 0.2 s timeout is configured here.
    CALIBR_SOCKET.settimeout(0.2)
    CALIBR_SOCKET.connect(CALIBRATION_ADDR)


@tmtc_definitions_provider
def acs_cmd_defs(defs: TmtcDefinitionWrapper):
    """Register the ACS controller service and its op codes.

    :param defs: Definition wrapper the "ACS Controller" service is added to.
    """
    op_code_entry = OpCodeEntry()
    # Registration order matches the declaration order in OpCodes / Info.
    for keys, info in (
        (OpCodes.REQUEST_MGM_HK, Info.REQUEST_MGM_HK),
        (OpCodes.ENABLE_MGM_HK, Info.ENABLE_MGM_HK),
        (OpCodes.DISABLE_MGM_HK, Info.DISABLE_MGM_HK),
    ):
        op_code_entry.add(keys=keys, info=info)
    defs.add_service(
        name=CustomServiceList.ACS_CTRL.value,
        info="ACS Controller",
        op_code_entry=op_code_entry,
    )


@service_provider(CustomServiceList.ACS_CTRL.value)
def pack_acs_ctrl_command(p: ServiceProviderParams):
    """Queue the telecommands for the requested ACS controller op code.

    All commands target the MGM HK set of the ACS controller. Unknown op codes
    are only logged; nothing is queued for them.

    :param p: Service provider parameters carrying the op code and queue helper.
    """
    queue = p.queue_helper
    requested = p.op_code
    mgm_sid = make_sid(ACS_CONTROLLER, SetIds.MGM_SET)

    if requested in OpCodes.REQUEST_MGM_HK:
        queue.add_log_cmd(Info.REQUEST_MGM_HK)
        queue.add_pus_tc(generate_one_hk_command(mgm_sid))
        return

    if requested in OpCodes.ENABLE_MGM_HK:
        queue.add_log_cmd(Info.ENABLE_MGM_HK)
        # Interval 2.0 (presumably seconds — TODO confirm); the helper returns
        # several TCs, the first two of which are queued in order.
        interval_tcs = enable_periodic_hk_command_with_interval(False, mgm_sid, 2.0)
        queue.add_pus_tc(interval_tcs[0])
        queue.add_pus_tc(interval_tcs[1])
        return

    if requested in OpCodes.DISABLE_MGM_HK:
        queue.add_log_cmd(Info.DISABLE_MGM_HK)
        queue.add_pus_tc(disable_periodic_hk_command(False, mgm_sid))
        return

    LOGGER.info(f"Unknown op code {requested}")


def handle_acs_ctrl_mgm_data(printer: FsfwTmTcPrinter, hk_data: bytes):
    """Decode the ACS controller MGM HK set and log it in readable form.

    Expected layout (network byte order, 61 bytes in total):

    * four [X, Y, Z] ``float32`` triplets for ACS board MGMs 0..3,
    * one [X, Y, Z] ``float32`` triplet for the IMTQ MGM, divided by 1000
      for display (presumably transmitted in nT — TODO confirm),
    * one byte of IMTQ actuation status, printed as a raw integer.

    Data shorter than expected is logged as raw hex and otherwise ignored;
    trailing extra bytes are ignored. When ``PERFORM_MGM_CALIBRATION`` is set,
    the MGM 0 reading is forwarded to ``perform_mgm_calibration``.

    :param printer: Printer used (via ``PrintWrapper``) for all log output.
    :param hk_data: Raw HK payload of the MGM set.
    """
    pw = PrintWrapper(printer)
    float_triplet = struct.Struct("!fff")
    # Five float triplets followed by the one-byte IMTQ actuation status.
    expected_len = 5 * float_triplet.size + 1

    if len(hk_data) < expected_len:
        pw.dlog(
            f"ACS CTRL HK: MGM HK data with length {len(hk_data)} shorter than expected {expected_len} bytes"
        )
        pw.dlog(f"Raw Data: {hk_data.hex(sep=',')}")
        return

    def unpack_float_tuple(idx: int) -> Tuple[Tuple[float, ...], int]:
        """Unpack one triplet at ``idx``; return it and the index just past it."""
        return float_triplet.unpack_from(hk_data, idx), idx + float_triplet.size

    current_idx = 0
    mgm_0_lis3_floats_ut, current_idx = unpack_float_tuple(current_idx)
    mgm_1_rm3100_floats_ut, current_idx = unpack_float_tuple(current_idx)
    mgm_2_lis3_floats_ut, current_idx = unpack_float_tuple(current_idx)
    mgm_3_rm3100_floats_ut, current_idx = unpack_float_tuple(current_idx)
    isis_floats_nt, current_idx = unpack_float_tuple(current_idx)
    imtq_mgm_ut = tuple(val / 1000.0 for val in isis_floats_nt)
    imtq_actuation_status = hk_data[current_idx]
    current_idx += 1

    pw.dlog("ACS CTRL HK: MGM values [X,Y,Z] in floating point uT: ")
    # Field width 8 characters per value, precision 3
    float_str_fmt = "[{:8.3f}, {:8.3f}, {:8.3f}]"
    rows = [
        ("ACS Board MGM 0 LIS3MDL", float_str_fmt.format(*mgm_0_lis3_floats_ut)),
        ("ACS Board MGM 1 RM3100", float_str_fmt.format(*mgm_1_rm3100_floats_ut)),
        ("ACS Board MGM 2 LIS3MDL", float_str_fmt.format(*mgm_2_lis3_floats_ut)),
        ("ACS Board MGM 3 RM3100", float_str_fmt.format(*mgm_3_rm3100_floats_ut)),
        ("IMTQ Actuation Status:", imtq_actuation_status),
        ("IMTQ MGM:", float_str_fmt.format(*imtq_mgm_ut)),
    ]
    for label, value in rows:
        pw.dlog(f"{label.ljust(28)}: {value}")
    if PERFORM_MGM_CALIBRATION:
        perform_mgm_calibration(pw, mgm_0_lis3_floats_ut)
    # Internal consistency check: the parser consumed exactly the expected layout.
    assert current_idx == expected_len


def _send_calibration_command(
    pw: PrintWrapper, line: str, cmd_name: str, reject_msg: str
) -> bool:
    """Send one command line over CALIBR_SOCKET and check the 2-byte reply.

    :param pw: Log output wrapper.
    :param line: Command text without the trailing newline.
    :param cmd_name: Command name used in the invalid-length log message.
    :param reject_msg: Message logged when the server rejects the command.
    :return: True if the reply has the expected length and is not a rejection.
    :raises OSError: Socket errors propagate to the caller.
    """
    CALIBR_SOCKET.sendall(f"{line}\n".encode())
    reply = CALIBR_SOCKET.recv(1024)
    if len(reply) != 2:
        pw.dlog(
            f"MGM calibration: Reply received command {cmd_name} has invalid length {len(reply)}"
        )
        return False
    # reply[0] is an int, so it must be compared against byte values, not the
    # string "0". Both a NUL byte and an ASCII "0" are treated as a rejection
    # (assumes a text reply such as b"0\n" — TODO confirm protocol).
    if reply[0] in (0, ord("0")):
        pw.dlog(reject_msg)
        return False
    return True


def perform_mgm_calibration(pw: PrintWrapper, mgm_tuple: Tuple):
    """Forward one MGM reading to the calibration server via CALIBR_SOCKET.

    Declares API version 2, then sends ``magnetometer_field X Y Z`` with each
    component divided by 1e6 (presumably uT -> T — TODO confirm). Invalid
    input, rejected commands and socket errors are logged via ``pw``; they are
    not raised.

    :param pw: Log output wrapper.
    :param mgm_tuple: [X, Y, Z] magnetometer reading; must have exactly 3 items.
    """
    # Validate before touching the socket; a short tuple would otherwise raise
    # an uncaught IndexError when building the command.
    if len(mgm_tuple) != 3:
        pw.dlog(f"MGM tuple has invalid length {len(mgm_tuple)}")
        return
    try:
        declare_api_cmd = "declare_api_version 2"
        if not _send_calibration_command(
            pw,
            declare_api_cmd,
            declare_api_cmd,
            "MGM calibration: API version 2 was not accepted",
        ):
            return
        mgm_list = [mgm / 1e6 for mgm in mgm_tuple]
        if not _send_calibration_command(
            pw,
            f"magnetometer_field {mgm_list[0]} {mgm_list[1]} {mgm_list[2]}",
            "magnetometer_field",
            "MGM calibration: magnetmeter field format was not accepted",
        ):
            return
        pw.dlog(f"Sent data {mgm_list} to Helmholtz Testbench successfully")
    except socket.timeout:
        pw.dlog("Socket timeout")
    except BlockingIOError as e:
        pw.dlog(f"Error {e}")
    except ConnectionResetError:
        pw.dlog("Socket was closed")
    # Tuple form: `A or B` evaluates to A alone, so OSError was never caught.
    # The more specific OSError subclasses above keep their own handlers.
    except (ConnectionRefusedError, OSError) as e:
        pw.dlog(
            f"Connecting to Calibration Socket on address {CALIBRATION_ADDR} failed: {e}"
        )