#!/usr/bin/env python3
import logging
import sys
import time
import traceback
from pathlib import Path
from typing import cast

from spacepackets import SpacePacketHeader, SpacePacket
from spacepackets.ccsds import SPACE_PACKET_HEADER_SIZE
from spacepackets.cfdp import (
    ConditionCode,
    ChecksumType,
    TransmissionMode,
    PduHolder,
    DirectiveType,
    PduFactory,
    PduType,
)
from tmtccmd.cfdp import CfdpUserBase, TransactionId
from tmtccmd.cfdp.defs import CfdpRequestType
from tmtccmd.cfdp.handler import CfdpInCcsdsHandler
from tmtccmd.cfdp.mib import (
    DefaultFaultHandlerBase,
    LocalEntityCfg,
    IndicationCfg,
    RemoteEntityCfg,
)
from tmtccmd.cfdp.user import (
    TransactionFinishedParams,
    MetadataRecvParams,
    FileSegmentRecvdParams,
)
from tmtccmd.tc.handler import SendCbParams

# Friendly error messages when the core dependencies are missing.
# NOTE(review): spacepackets and tmtccmd are already imported at the top of this file,
# so a missing installation fails there before these guards run. Moving those imports
# below the guards would be required for these messages to ever be shown.
try:
    import spacepackets
except ImportError as error:
    print(error)
    print("Python spacepackets module could not be imported")
    print(
        'Install with "cd spacepackets && python3 -m pip install -e ." for editable installation'
    )
    sys.exit(1)

try:
    import tmtccmd
except ImportError:
    # Print the full traceback: the failure may come from a transitive import.
    tb = traceback.format_exc()
    print(tb)
    print("Python tmtccmd submodule could not be imported")
    sys.exit(1)

from spacepackets.ecss import PusVerificator
from tmtccmd import get_console_logger, TcHandlerBase, BackendBase
from tmtccmd.util import FileSeqCountProvider, PusFileSeqCountProvider
from tmtccmd.util.tmtc_printer import FsfwTmTcPrinter

from tmtccmd.logging.pus import (
    RawTmtcTimedLogWrapper,
    RegularTmtcLogWrapper,
    TimedLogWhen,
)
from tmtccmd.pus import VerificationWrapper
from tmtccmd.tm import SpecificApidHandlerBase, GenericApidHandlerBase, CcsdsTmHandler
from tmtccmd.core import BackendRequest
from tmtccmd.logging import get_current_time_string
from tmtccmd.tc import (
    ProcedureWrapper,
    FeedWrapper,
    TcProcedureType,
    TcQueueEntryType,
    DefaultPusQueueHelper,
)
from tmtccmd.config import (
    default_json_path,
    SetupWrapper,
    params_to_procedure_conversion,
)
from tmtccmd.config.args import (
    SetupParams,
    PreArgsParsingWrapper,
    ProcedureParamsWrapper,
)
from config import __version__
from config.definitions import (
    PUS_APID,
    CFDP_APID,
    CFDP_LOCAL_ENTITY_ID,
    CFDP_REMOTE_ENTITY_ID,
)
from config.hook import EiveHookObject
from pus_tm.factory_hook import pus_factory_hook
from pus_tc.procedure_packer import handle_default_procedure

# Console logger used by every handler and callback in this script.
LOGGER = get_console_logger()

# Rotation settings of the raw TM/TC file logger, kept at module level for quick tweaks.
# PER_MINUTE together with an interval of 30 presumably rotates every 30 minutes.
ROTATING_TIMED_LOGGER_INTERVAL_WHEN: TimedLogWhen = TimedLogWhen.PER_MINUTE
ROTATING_TIMED_LOGGER_INTERVAL: int = 30


class EiveCfdpFaultHandler(DefaultFaultHandlerBase):
    """CFDP fault handler which reports every fault notification on the console.

    The callbacks previously discarded the condition code silently. They now log it as
    a warning so that faults during a transfer are visible to the operator. No further
    recovery action is taken here.
    """

    def notice_of_suspension_cb(self, cond: ConditionCode):
        LOGGER.warning(f"CFDP fault handler: notice of suspension, condition {cond!r}")

    def notice_of_cancellation_cb(self, cond: ConditionCode):
        LOGGER.warning(
            f"CFDP fault handler: notice of cancellation, condition {cond!r}"
        )

    def abandoned_cb(self, cond: ConditionCode):
        LOGGER.warning(f"CFDP fault handler: transaction abandoned, condition {cond!r}")

    def ignore_cb(self, cond: ConditionCode):
        LOGGER.warning(f"CFDP fault handler: ignoring fault with condition {cond!r}")


class EiveCfdpUser(CfdpUserBase):
    """CFDP user implementation for the EIVE TMTC client.

    Transaction start, EOF-sent and transaction-finished indications are logged at info
    level. Suspension, fault and abandonment are logged as warnings so problems during a
    transfer are not lost. Receive-side indications are currently ignored.
    """

    def transaction_indication(self, transaction_id: TransactionId):
        LOGGER.info(f"CFDP User: Start of File {transaction_id}")

    def eof_sent_indication(self, transaction_id: TransactionId):
        LOGGER.info(f"CFDP User: EOF sent for {transaction_id}")

    def transaction_finished_indication(self, params: TransactionFinishedParams):
        LOGGER.info(f"CFDP User: {params.transaction_id} finished")

    def metadata_recv_indication(self, params: MetadataRecvParams):
        # Receive side is not handled by this client yet.
        pass

    def file_segment_recv_indication(self, params: FileSegmentRecvdParams):
        # Receive side is not handled by this client yet.
        pass

    # `object` instead of the builtin function `any`, which is not a type.
    def report_indication(self, transaction_id: TransactionId, status_report: object):
        pass

    def suspended_indication(
        self, transaction_id: TransactionId, cond_code: ConditionCode
    ):
        LOGGER.warning(
            f"CFDP User: {transaction_id} suspended with condition {cond_code!r}"
        )

    def resumed_indication(self, transaction_id: TransactionId, progress: int):
        LOGGER.info(f"CFDP User: {transaction_id} resumed at progress {progress}")

    def fault_indication(
        self, transaction_id: TransactionId, cond_code: ConditionCode, progress: int
    ):
        LOGGER.warning(
            f"CFDP User: fault for {transaction_id} with condition {cond_code!r} "
            f"at progress {progress}"
        )

    def abandoned_indication(
        self, transaction_id: TransactionId, cond_code: ConditionCode, progress: int
    ):
        LOGGER.warning(
            f"CFDP User: {transaction_id} abandoned with condition {cond_code!r} "
            f"at progress {progress}"
        )

    def eof_recv_indication(self, transaction_id: TransactionId):
        # Receive side is not handled by this client yet.
        pass


class PusHandler(SpecificApidHandlerBase):
    """Telemetry handler for packets on the PUS APID.

    Every packet is forwarded to ``pus_factory_hook`` together with the verification
    wrapper, the TM/TC printer and the raw TM/TC logger.
    """

    def __init__(
        self,
        wrapper: VerificationWrapper,
        printer: FsfwTmTcPrinter,
        raw_logger: RawTmtcTimedLogWrapper,
    ):
        # Second argument: presumably the user args handed back to handle_tm; unused.
        super().__init__(PUS_APID, None)
        self.printer = printer
        self.verif_wrapper = wrapper
        self.raw_logger = raw_logger

    def handle_tm(self, packet: bytes, _user_args: object):
        pus_factory_hook(packet, self.verif_wrapper, self.printer, self.raw_logger)


class UnknownApidHandler(GenericApidHandlerBase):
    """Generic fallback handler: warns about packets whose APID has no dedicated handler."""

    def handle_tm(self, apid: int, _packet: bytes, _user_args: object):
        LOGGER.warning(f"Packet with unknown APID {apid} detected")


class CfdpInCcsdsWrapper(SpecificApidHandlerBase):
    """Telemetry handler for CFDP PDUs carried inside space packets on the CFDP APID.

    The space packet header is stripped, the contained PDU is parsed and logged, and the
    parsed PDU is passed on to the wrapped ``CfdpInCcsdsHandler``.
    """

    def __init__(self, cfdp_in_ccsds_handler: CfdpInCcsdsHandler):
        super().__init__(CFDP_APID, None)
        self.handler = cfdp_in_ccsds_handler

    def handle_tm(self, packet: bytes, _user_args: object):
        # The space packet header only exists so CFDP can share the CCSDS transport
        # with its own APID. If this function is called, the APID is correct, so the
        # header carries no further information and is skipped.
        if len(packet) <= SPACE_PACKET_HEADER_SIZE:
            # Nothing after the header: there is no PDU to parse.
            LOGGER.warning(
                f"Received CFDP space packet with length {len(packet)}, too short to "
                f"contain a PDU. Ignoring it"
            )
            return
        pdu_base = PduFactory.from_raw(packet[SPACE_PACKET_HEADER_SIZE:])
        if pdu_base.pdu_type == PduType.FILE_DATA:
            LOGGER.info("Received File Data PDU TM")
        elif pdu_base.directive_type == DirectiveType.FINISHED_PDU:
            LOGGER.info("Received Finished PDU TM")
        else:
            LOGGER.info(
                f"Received File Directive PDU with type {pdu_base.directive_type!r} TM"
            )
        self.handler.pass_pdu_packet(pdu_base)


class TcHandler(TcHandlerBase):
    """Telecommand handler for default (PUS) and CFDP procedures.

    It fills the TC queue for fed procedures, and stamps, logs and sends queued PUS TCs.
    It forwards queued CFDP space packets and reports when a started CFDP put request
    is no longer pending.
    """

    def __init__(
        self,
        seq_count_provider: FileSeqCountProvider,
        cfdp_in_ccsds_wrapper: CfdpInCcsdsWrapper,
        pus_verificator: PusVerificator,
        high_level_file_logger: logging.Logger,
        raw_pus_logger: RawTmtcTimedLogWrapper,
        gui: bool,
    ):
        super().__init__()
        # Set when a put request is issued; cleared by cfdp_done() once it completes.
        self.cfdp_handler_started = False
        self.cfdp_dest_id = CFDP_REMOTE_ENTITY_ID
        self.seq_count_provider = seq_count_provider
        self.pus_verificator = pus_verificator
        self.high_level_file_logger = high_level_file_logger
        self.pus_raw_logger = raw_pus_logger
        self.gui = gui
        # The queue wrapper is only known once a procedure is fed (see feed_cb).
        self.queue_helper = DefaultPusQueueHelper(
            queue_wrapper=None,
            pus_apid=PUS_APID,
            seq_cnt_provider=seq_count_provider,
            pus_verificator=pus_verificator,
        )
        self.cfdp_in_ccsds_wrapper = cfdp_in_ccsds_wrapper

    def cfdp_done(self) -> bool:
        """Return True once after a started CFDP put request is no longer pending.

        The started flag is reset when True is returned, so later calls return False
        until a new put request is started.
        """
        if self.cfdp_handler_started:
            if not self.cfdp_in_ccsds_wrapper.handler.put_request_pending():
                self.cfdp_handler_started = False
                return True
        return False

    def feed_cb(self, info: ProcedureWrapper, wrapper: FeedWrapper):
        """Fill the TC queue provided by ``wrapper`` for the given procedure."""
        self.queue_helper.queue_wrapper = wrapper.queue_wrapper
        if info.proc_type == TcProcedureType.DEFAULT:
            handle_default_procedure(self, info.to_def_procedure(), self.queue_helper)
        elif info.proc_type == TcProcedureType.CFDP:
            self.handle_cfdp_procedure(info)

    def send_cb(self, send_params: SendCbParams):
        """Process one queue entry.

        - PUS TCs get the next sequence count and PUS_APID, are registered with the
          verificator, logged and sent.
        - CCSDS TCs (CFDP PDUs in space packets) are sent unchanged.
        - Log entries go to the console and the high-level file logger.
        """
        entry_helper = send_params.entry
        if entry_helper.is_tc:
            if entry_helper.entry_type == TcQueueEntryType.PUS_TC:
                pus_tc_wrapper = entry_helper.to_pus_tc_entry()
                pus_tc_wrapper.pus_tc.seq_count = (
                    self.seq_count_provider.get_and_increment()
                )
                pus_tc_wrapper.pus_tc.apid = PUS_APID
                # Add TC after sequence count stamping so verification matches the sent TC
                self.pus_verificator.add_tc(pus_tc_wrapper.pus_tc)
                raw_tc = pus_tc_wrapper.pus_tc.pack()
                self.pus_raw_logger.log_tc(pus_tc_wrapper.pus_tc)
                tc_info_string = f"Sent {pus_tc_wrapper.pus_tc}"
                LOGGER.info(tc_info_string)
                self.high_level_file_logger.info(
                    f"{get_current_time_string(True)}: {tc_info_string}"
                )
                send_params.com_if.send(raw_tc)
            elif entry_helper.entry_type == TcQueueEntryType.CCSDS_TC:
                cfdp_packet_in_ccsds = entry_helper.to_space_packet_entry()
                send_params.com_if.send(cfdp_packet_in_ccsds.space_packet.pack())
                # TODO: Log raw CFDP packets similarly to how PUS packets are logged.
                #        - Log full raw format including space packet wrapper
                #        - Log context information: Transaction ID, and PDU type and directive
                # Could re-use file logger. Should probably do that
        elif entry_helper.entry_type == TcQueueEntryType.LOG:
            log_entry = entry_helper.to_log_entry()
            LOGGER.info(log_entry.log_str)
            self.high_level_file_logger.info(log_entry.log_str)

    def handle_cfdp_procedure(self, info: ProcedureWrapper):
        """Start a CFDP put request and queue the resulting source packets.

        Only PUT requests are supported. A new put request is started only when no put
        request is pending and none was started before. Each source packet yielded by
        the CFDP handler is queued as a CCSDS TC, preceded by a log entry describing
        the PDU.
        """
        cfdp_procedure = info.to_cfdp_procedure()
        if cfdp_procedure.cfdp_request_type != CfdpRequestType.PUT:
            LOGGER.warning(
                f"CFDP: Request type {cfdp_procedure.cfdp_request_type!r} is not "
                f"supported, ignoring it"
            )
            return
        handler = self.cfdp_in_ccsds_wrapper.handler
        if handler.put_request_pending() or self.cfdp_handler_started:
            LOGGER.warning(
                "CFDP: A put request is still in progress, ignoring new put request"
            )
            return
        put_req = cfdp_procedure.request_wrapper.to_put_request()
        put_req.cfg.destination_id = self.cfdp_dest_id
        LOGGER.info(f"CFDP: Starting file put request with parameters:\n{put_req}")
        handler.cfdp_handler.put_request(put_req)
        self.cfdp_handler_started = True

        for source_pair, _dest_pair in handler:
            pdu, sp = source_pair
            self._queue_source_pdu_log(cast(PduHolder, pdu))
            self.queue_helper.add_ccsds_tc(sp)
            handler.confirm_source_packet_sent()
        handler.source_handler.state_machine()

    def _queue_source_pdu_log(self, pdu: PduHolder):
        """Queue a log entry describing a CFDP source PDU that is about to be sent."""
        if pdu.is_file_directive:
            if pdu.pdu_directive_type == DirectiveType.METADATA_PDU:
                metadata = pdu.to_metadata_pdu()
                self.queue_helper.add_log_cmd(
                    f"CFDP Source: Sending Metadata PDU for file with size "
                    f"{metadata.file_size}"
                )
            elif pdu.pdu_directive_type == DirectiveType.EOF_PDU:
                self.queue_helper.add_log_cmd("CFDP Source: Sending EOF PDU")
        else:
            fd_pdu = pdu.to_file_data_pdu()
            self.queue_helper.add_log_cmd(
                f"CFDP Source: Sending File Data PDU for segment at offset "
                f"{fd_pdu.offset} with length {len(fd_pdu.file_data)}"
            )

    def queue_finished_cb(self, info: ProcedureWrapper):
        """Log that the queue of the given procedure has been fully processed."""
        if info is None:
            return
        # Compare against the procedure type enum. The former comparison against
        # TcQueueEntryType.PUS_TC never matched, so default procedures were never logged.
        if info.proc_type == TcProcedureType.DEFAULT:
            def_proc = info.to_def_procedure()
            LOGGER.info(
                f"Finished queue for service {def_proc.service} and op code {def_proc.op_code}"
            )
        elif info.proc_type == TcProcedureType.CFDP:
            LOGGER.info("Finished CFDP queue")


def setup_params() -> SetupWrapper:
    """Build the SetupWrapper from the command line arguments.

    Prints the EIVE TMTC and spacepackets versions, then parses the arguments using the
    EIVE hook object. Setup and procedure parameters are filled interactively (with
    prompts) unless the GUI was requested. The PUS APID is always set to PUS_APID.
    """
    print(f"-- eive tmtc v{__version__} --")
    print(f"-- spacepackets v{spacepackets.__version__} --")
    hook = EiveHookObject(default_json_path())
    params = SetupParams()
    pre_parsing = PreArgsParsingWrapper()
    pre_parsing.create_default_parent_parser()
    pre_parsing.create_default_parser()
    pre_parsing.add_def_proc_and_cfdp_as_subparsers()
    post_parsing = pre_parsing.parse(hook)
    gui_mode = post_parsing.use_gui
    tmtccmd.init_printout(gui_mode)
    proc_params = ProcedureParamsWrapper()
    # Prompting only makes sense for the CLI; the GUI supplies its own parameters.
    fill_params = (
        post_parsing.set_params_without_prompts
        if gui_mode
        else post_parsing.set_params_with_prompts
    )
    fill_params(params, proc_params)
    params.apid = PUS_APID
    return SetupWrapper(
        hook_obj=hook, setup_params=params, proc_param_wrapper=proc_params
    )


def setup_cfdp_handler() -> CfdpInCcsdsWrapper:
    """Create the CFDP-in-CCSDS handler and wrap it as a TM handler for CFDP_APID.

    The local entity uses EiveCfdpFaultHandler and a default IndicationCfg. A single
    remote entity, CFDP_REMOTE_ENTITY_ID, is configured for unacknowledged transfers
    with 1024 byte file segments. Transaction and CCSDS sequence counts are persisted
    in their own files.
    """
    fault_handler = EiveCfdpFaultHandler()
    local_entity = LocalEntityCfg(
        local_entity_id=CFDP_LOCAL_ENTITY_ID,
        indication_cfg=IndicationCfg(),
        default_fault_handlers=fault_handler,
    )
    remote_entity = RemoteEntityCfg(
        closure_requested=False,
        entity_id=CFDP_REMOTE_ENTITY_ID,
        max_file_segment_len=1024,
        check_limit=None,
        crc_on_transmission=False,
        crc_type=ChecksumType.CRC_32,
        default_transmission_mode=TransmissionMode.UNACKNOWLEDGED,
    )
    transaction_seq_cnt = FileSeqCountProvider(
        max_bit_width=16, file_name=Path("seqcnt_cfdp_transaction.txt")
    )
    ccsds_seq_cnt = PusFileSeqCountProvider(file_name=Path("seqcnt_cfdp_ccsds_.txt"))
    in_ccsds_handler = CfdpInCcsdsHandler(
        cfg=local_entity,
        remote_cfgs=[remote_entity],
        ccsds_apid=CFDP_APID,
        ccsds_seq_cnt_provider=ccsds_seq_cnt,
        cfdp_seq_cnt_provider=transaction_seq_cnt,
        user=EiveCfdpUser(),
    )
    return CfdpInCcsdsWrapper(in_ccsds_handler)


def setup_tmtc_handlers(
    verificator: PusVerificator,
    printer: FsfwTmTcPrinter,
    raw_logger: RawTmtcTimedLogWrapper,
    gui: bool,
) -> "tuple[CcsdsTmHandler, TcHandler]":
    """Create the CCSDS TM handler and the TC handler.

    The TM handler routes PUS_APID packets to PusHandler and CFDP_APID packets to the
    CFDP wrapper. Everything else goes to UnknownApidHandler. The TC handler shares the
    same CFDP wrapper so it can issue put requests.

    Returns:
        A ``(ccsds_handler, tc_handler)`` tuple.
    """
    # The return annotation is quoted so the PEP 585 generic is not evaluated at
    # definition time; the former `(A, B)` tuple literal was not a valid type hint.
    cfdp_in_ccsds_wrapper = setup_cfdp_handler()
    verification_wrapper = VerificationWrapper(verificator, LOGGER, printer.file_logger)
    pus_handler = PusHandler(verification_wrapper, printer, raw_logger)
    ccsds_handler = CcsdsTmHandler(generic_handler=UnknownApidHandler(None))
    ccsds_handler.add_apid_handler(pus_handler)
    ccsds_handler.add_apid_handler(cfdp_in_ccsds_wrapper)
    seq_count_provider = PusFileSeqCountProvider()
    tc_handler = TcHandler(
        seq_count_provider=seq_count_provider,
        pus_verificator=verificator,
        high_level_file_logger=printer.file_logger,
        raw_pus_logger=raw_logger,
        gui=gui,
        cfdp_in_ccsds_wrapper=cfdp_in_ccsds_wrapper,
    )
    return ccsds_handler, tc_handler


def setup_backend(
    setup_wrapper: SetupWrapper,
    tc_handler: TcHandler,
    ccsds_handler: CcsdsTmHandler,
) -> BackendBase:
    """Create the default tmtccmd backend, start it and return it.

    The initial procedure is converted from the parsed procedure parameters.
    """
    backend = tmtccmd.create_default_tmtc_backend(
        setup_wrapper=setup_wrapper,
        tm_handler=ccsds_handler,
        tc_handler=tc_handler,
        init_procedure=params_to_procedure_conversion(
            setup_wrapper.proc_param_wrapper
        ),
    )
    tmtccmd.start(tmtc_backend=backend, hook_obj=setup_wrapper.hook_obj)
    return backend


def main():
    """Set up parameters, loggers and handlers, then run the backend loop.

    The loop runs until the backend requests termination, a CFDP transaction finishes
    while the backend is in listener mode, or the user presses Ctrl+C.
    """
    try:
        setup_wrapper = setup_params()
    except KeyboardInterrupt as e:
        LOGGER.info(f"{e}. Exiting")
        sys.exit(0)
    tmtc_logger = RegularTmtcLogWrapper()
    printer = FsfwTmTcPrinter(tmtc_logger.logger)
    raw_logger = RawTmtcTimedLogWrapper(
        when=ROTATING_TIMED_LOGGER_INTERVAL_WHEN,
        interval=ROTATING_TIMED_LOGGER_INTERVAL,
    )
    pus_verificator = PusVerificator()
    ccsds_handler, tc_handler = setup_tmtc_handlers(
        pus_verificator, printer, raw_logger, setup_wrapper.params.use_gui
    )

    tmtccmd.setup(setup_wrapper)
    tmtc_backend = setup_backend(
        setup_wrapper=setup_wrapper, ccsds_handler=ccsds_handler, tc_handler=tc_handler
    )
    try:
        while True:
            state = tmtc_backend.periodic_op(None)
            # The CFDP handler is serviced on every cycle, whatever the backend requests.
            tc_handler.cfdp_in_ccsds_wrapper.handler.fsm()
            if state.request == BackendRequest.TERMINATION_NO_ERROR:
                sys.exit(0)
            elif state.request == BackendRequest.DELAY_IDLE:
                LOGGER.info("TMTC Client in IDLE mode")
                time.sleep(3.0)
            elif state.request == BackendRequest.DELAY_LISTENER:
                if tc_handler.cfdp_done():
                    LOGGER.info("CFDP transaction done, closing client")
                    sys.exit(0)
                time.sleep(0.5)
            elif state.request == BackendRequest.DELAY_CUSTOM:
                # Sleep for the requested delay, but at most 0.5 s so the loop keeps
                # servicing the backend and CFDP handler. Clamp at zero because
                # time.sleep raises ValueError for negative durations.
                time.sleep(max(0.0, min(state.next_delay.total_seconds(), 0.5)))
            elif state.request == BackendRequest.CALL_NEXT:
                pass
    except KeyboardInterrupt:
        sys.exit(0)


if __name__ == "__main__":
    main()