#!/usr/bin/env python3
from typing import List, Tuple
from spacepackets.ecss.defs import PusService
from spacepackets.ecss.tm import PusTm
from tmtccmd.com import ComInterface
import toml
import struct
import logging
import argparse
import time
import enum
from tmtccmd.com.serial_base import SerialCfg
from tmtccmd.com.serial_cobs import SerialCobsComIF
from tmtccmd.com.ser_utils import prompt_com_port
from crcmod.predefined import PredefinedCrc
from spacepackets.ecss.tc import PusTc
from spacepackets.ecss.pus_verificator import PusVerificator, StatusField
from spacepackets.ecss.pus_1_verification import Service1Tm, UnpackParams
from spacepackets.seqcount import SeqCountProvider
from pathlib import Path
import dataclasses
from elftools.elf.elffile import ELFFile


BAUD_RATE = 115200

BOOTLOADER_START_ADDR = 0x0
BOOTLOADER_END_ADDR = 0x4000
BOOTLOADER_CRC_ADDR = BOOTLOADER_END_ADDR - 4
BOOTLOADER_MAX_SIZE = BOOTLOADER_END_ADDR - BOOTLOADER_START_ADDR - 4

APP_A_START_ADDR = 0x4000
APP_A_END_ADDR = 0x22000
# The actual size of the image which is relevant for CRC calculation.
APP_A_SIZE_ADDR = APP_A_END_ADDR - 8
APP_A_CRC_ADDR = APP_A_END_ADDR - 4
APP_A_MAX_SIZE = APP_A_END_ADDR - APP_A_START_ADDR - 8

APP_B_START_ADDR = 0x22000
APP_B_END_ADDR = 0x40000
# The actual size of the image which is relevant for CRC calculation.
APP_B_SIZE_ADDR = APP_B_END_ADDR - 8
APP_B_CRC_ADDR = APP_B_END_ADDR - 4
APP_B_MAX_SIZE = APP_A_END_ADDR - APP_A_START_ADDR - 8

APP_IMG_SZ = (APP_B_END_ADDR - APP_A_START_ADDR) // 2

CHUNK_SIZE = 896

MEMORY_SERVICE = 6
ACTION_SERVICE = 8

RAW_MEMORY_WRITE_SUBSERVICE = 2
BOOT_NVM_MEMORY_ID = 1
PING_PAYLOAD_SIZE = 0


class ActionId(enum.IntEnum):
    CORRUPT_APP_A = 128
    CORRUPT_APP_B = 129


_LOGGER = logging.getLogger(__name__)
SEQ_PROVIDER = SeqCountProvider(bit_width=14)


@dataclasses.dataclass
class LoadableSegment:
    name: str
    offset: int
    size: int
    data: bytes


class Target(enum.Enum):
    BOOTLOADER = 0
    APP_A = 1
    APP_B = 2


class ImageLoader:
    def __init__(self, com_if: ComInterface, verificator: PusVerificator) -> None:
        self.com_if = com_if
        self.verificator = verificator

    def handle_ping_cmd(self):
        _LOGGER.info("Sending ping command")
        ping_tc = PusTc(
            apid=0x00,
            service=PusService.S17_TEST,
            subservice=1,
            seq_count=SEQ_PROVIDER.get_and_increment(),
            app_data=bytes(PING_PAYLOAD_SIZE),
        )
        self.verificator.add_tc(ping_tc)
        self.com_if.send(bytes(ping_tc.pack()))

        data_available = self.com_if.data_available(0.4)
        if not data_available:
            _LOGGER.warning("no ping reply received")
        for reply in self.com_if.receive():
            result = self.verificator.add_tm(
                Service1Tm.from_tm(PusTm.unpack(reply, 0), UnpackParams(0))
            )
            if result is not None and result.completed:
                _LOGGER.info("received ping completion reply")

    def handle_corruption_cmd(self, target: Target):

        if target == Target.BOOTLOADER:
            _LOGGER.error("can not corrupt bootloader")
        if target == Target.APP_A:
            self.send_tc(
                PusTc(
                    apid=0,
                    service=ACTION_SERVICE,
                    subservice=ActionId.CORRUPT_APP_A,
                ),
            )
        if target == Target.APP_B:
            self.send_tc(
                PusTc(
                    apid=0,
                    service=ACTION_SERVICE,
                    subservice=ActionId.CORRUPT_APP_B,
                ),
            )

    def handle_flash_cmd(self, target: Target, file_path: Path) -> int:
        loadable_segments = []
        _LOGGER.info("Parsing ELF file for loadable sections")
        total_size = 0
        loadable_segments, total_size = create_loadable_segments(target, file_path)
        segments_info_str(target, loadable_segments, total_size, file_path)
        result = self._perform_flashing_algorithm(loadable_segments)
        if result != 0:
            return result
        self._crc_and_app_size_postprocessing(target, total_size, loadable_segments)
        return 0

    def _perform_flashing_algorithm(
        self,
        loadable_segments: List[LoadableSegment],
    ) -> int:
        # Perform the flashing algorithm.
        for segment in loadable_segments:
            segment_end = segment.offset + segment.size
            current_addr = segment.offset
            pos_in_segment = 0
            while pos_in_segment < segment.size:
                next_chunk_size = min(segment_end - current_addr, CHUNK_SIZE)
                data = segment.data[pos_in_segment : pos_in_segment + next_chunk_size]
                next_packet = pack_memory_write_command(current_addr, data)
                _LOGGER.info(
                    f"Sending memory write command for address {current_addr:#08x} and data with "
                    f"length {len(data)}"
                )
                self.verificator.add_tc(next_packet)
                self.com_if.send(bytes(next_packet.pack()))
                current_addr += next_chunk_size
                pos_in_segment += next_chunk_size
                start_time = time.time()
                while True:
                    if time.time() - start_time > 1.0:
                        _LOGGER.error("Timeout while waiting for reply")
                        return -1
                    data_available = self.com_if.data_available(0.1)
                    done = False
                    if not data_available:
                        continue
                    replies = self.com_if.receive()
                    for reply in replies:
                        tm = PusTm.unpack(reply, 0)
                        if tm.service != 1:
                            continue
                        service_1_tm = Service1Tm.from_tm(tm, UnpackParams(0))
                        check_result = self.verificator.add_tm(service_1_tm)
                        # We could send after we have received the step reply, but that can
                        # somehow lead to overrun errors. I think it's okay to do it like
                        # this as long as the flash loader only uses polling..
                        if (
                            check_result is not None
                            and check_result.status.completed == StatusField.SUCCESS
                        ):
                            done = True

                        # This is an optimized variant, but I think the small delay is not an issue.
                        """
                        if (
                            check_result is not None
                            and check_result.status.step == StatusField.SUCCESS
                            and len(check_result.status.step_list) == 1
                        ):
                            done = True
                        """
                    self.verificator.remove_completed_entries()
                    if done:
                        break
        return 0

    def _crc_and_app_size_postprocessing(
        self,
        target: Target,
        total_size: int,
        loadable_segments: List[LoadableSegment],
    ):
        if target == Target.BOOTLOADER:
            _LOGGER.info("Blanking the bootloader checksum")
            # Blank the checksum. For the bootloader, the bootloader will calculate the
            # checksum itself on the initial run.
            checksum_write_packet = pack_memory_write_command(
                BOOTLOADER_CRC_ADDR, bytes([0x00, 0x00, 0x00, 0x00])
            )
            self.send_tc(checksum_write_packet)
        else:
            crc_addr = None
            size_addr = None
            if target == Target.APP_A:
                crc_addr = APP_A_CRC_ADDR
                size_addr = APP_A_SIZE_ADDR
            elif target == Target.APP_B:
                crc_addr = APP_B_CRC_ADDR
                size_addr = APP_B_SIZE_ADDR
            assert crc_addr is not None
            assert size_addr is not None
            _LOGGER.info(f"Writing app size {total_size} at address {size_addr:#08x}")
            size_write_packet = pack_memory_write_command(
                size_addr, struct.pack("!I", total_size)
            )
            self.com_if.send(bytes(size_write_packet.pack()))
            time.sleep(0.2)
            crc_calc = PredefinedCrc("crc-32")
            for segment in loadable_segments:
                crc_calc.update(segment.data)
            checksum = crc_calc.digest()
            _LOGGER.info(
                f"Writing checksum 0x[{checksum.hex(sep=',')}] at address {crc_addr:#08x}"
            )
            self.send_tc(pack_memory_write_command(crc_addr, checksum))

    def send_tc(self, tc: PusTc):
        self.com_if.send(bytes(tc.pack()))


def main() -> int:
    print("Python VA416XX Image Loader Application")
    logging.basicConfig(
        format="[%(asctime)s] [%(levelname)s] %(message)s", level=logging.DEBUG
    )
    parser = argparse.ArgumentParser(
        prog="image-loader", description="Python VA416XX Image Loader Application"
    )
    parser.add_argument("-p", "--ping", action="store_true", help="Send ping command")
    parser.add_argument("-c", "--corrupt", action="store_true", help="Corrupt a target")
    parser.add_argument(
        "-t",
        "--target",
        choices=["bl", "a", "b"],
        help="Target (Bootloader or slot A or B)",
    )
    parser.add_argument(
        "path", nargs="?", default=None, help="Path to the App to flash"
    )
    args = parser.parse_args()
    serial_port = None
    if Path("loader.toml").exists():
        with open("loader.toml", "r") as toml_file:
            parsed_toml = toml.loads(toml_file.read())
            if "serial_port" in parsed_toml:
                serial_port = parsed_toml["serial_port"]
    if serial_port is None:
        serial_port = prompt_com_port()
    serial_cfg = SerialCfg(
        com_if_id="ser_cobs",
        serial_port=serial_port,
        baud_rate=BAUD_RATE,
        serial_timeout=0.1,
    )
    verificator = PusVerificator()
    com_if = SerialCobsComIF(serial_cfg)
    com_if.open()
    target = None
    if args.target == "bl":
        target = Target.BOOTLOADER
    elif args.target == "a":
        target = Target.APP_A
    elif args.target == "b":
        target = Target.APP_B
    image_loader = ImageLoader(com_if, verificator)
    file_path = None
    result = -1
    if args.ping:
        image_loader.handle_ping_cmd()
        com_if.close()
        return 0
    if target:
        if not args.corrupt:
            if not args.path:
                _LOGGER.error("App Path needs to be specified for the flash process")
            file_path = Path(args.path)
            if not file_path.exists():
                _LOGGER.error("File does not exist")
    if args.corrupt:
        if not target:
            _LOGGER.error("target for corruption command required")
            com_if.close()
            return -1
        image_loader.handle_corruption_cmd(target)
    else:
        assert file_path is not None
        assert target is not None
        result = image_loader.handle_flash_cmd(target, file_path)

    com_if.close()
    return result


def create_loadable_segments(
    target: Target, file_path: Path
) -> Tuple[List[LoadableSegment], int]:
    loadable_segments = []
    total_size = 0
    with open(file_path, "rb") as app_file:
        elf_file = ELFFile(app_file)

        for idx, segment in enumerate(elf_file.iter_segments("PT_LOAD")):
            if segment.header.p_filesz == 0:
                continue
            # Basic validity checks of the base addresses.
            if idx == 0:
                if (
                    target == Target.BOOTLOADER
                    and segment.header.p_paddr != BOOTLOADER_START_ADDR
                ):
                    raise ValueError(
                        f"detected possibly invalid start address {segment.header.p_paddr:#08x} for "
                        f"bootloader, expected {BOOTLOADER_START_ADDR}"
                    )
                if (
                    target == Target.APP_A
                    and segment.header.p_paddr != APP_A_START_ADDR
                ):
                    raise ValueError(
                        f"detected possibly invalid start address {segment.header.p_paddr:#08x} for "
                        f"App A, expected {APP_A_START_ADDR}"
                    )
                if (
                    target == Target.APP_B
                    and segment.header.p_paddr != APP_B_START_ADDR
                ):
                    raise ValueError(
                        f"detected possibly invalid start address {segment.header.p_paddr:#08x} for "
                        f"App B, expected {APP_B_START_ADDR}"
                    )
            name = None
            for section in elf_file.iter_sections():
                if (
                    section.header.sh_offset == segment.header.p_offset
                    and section.header.sh_size > 0
                ):
                    name = section.name
            if name is None:
                _LOGGER.warning("no fitting section found for segment")
                continue
            # print(f"Segment Addr: {segment.header.p_paddr}")
            # print(f"Segment Offset: {segment.header.p_offset}")
            # print(f"Segment Filesize: {segment.header.p_filesz}")
            loadable_segments.append(
                LoadableSegment(
                    name=name,
                    offset=segment.header.p_paddr,
                    size=segment.header.p_filesz,
                    data=segment.data(),
                )
            )
            total_size += segment.header.p_filesz
    return loadable_segments, total_size


def segments_info_str(
    target: Target,
    loadable_segments: List[LoadableSegment],
    total_size: int,
    file_path: Path,
):
    # Set context string and perform basic sanity checks.
    if target == Target.BOOTLOADER:
        if total_size > BOOTLOADER_MAX_SIZE:
            _LOGGER.error(
                f"provided bootloader app larger than allowed {total_size} bytes"
            )
            return -1
        context_str = "Bootloader"
    elif target == Target.APP_A:
        if total_size > APP_A_MAX_SIZE:
            _LOGGER.error(f"provided App A larger than allowed {total_size} bytes")
            return -1
        context_str = "App Slot A"
    elif target == Target.APP_B:
        if total_size > APP_B_MAX_SIZE:
            _LOGGER.error(f"provided App B larger than allowed {total_size} bytes")
            return -1
        context_str = "App Slot B"
    _LOGGER.info(f"Flashing {context_str} with image {file_path} (size {total_size})")
    for idx, segment in enumerate(loadable_segments):
        _LOGGER.info(
            f"Loadable section {idx} {segment.name} with offset {segment.offset:#08x} and "
            f"size {segment.size}"
        )


def pack_memory_write_command(addr: int, data: bytes) -> PusTc:
    app_data = bytearray()
    app_data.append(BOOT_NVM_MEMORY_ID)
    # N parameter is always 1 here.
    app_data.append(1)
    app_data.extend(struct.pack("!I", addr))
    app_data.extend(struct.pack("!I", len(data)))
    app_data.extend(data)
    return PusTc(
        apid=0,
        service=MEMORY_SERVICE,
        subservice=RAW_MEMORY_WRITE_SUBSERVICE,
        seq_count=SEQ_PROVIDER.get_and_increment(),
        app_data=bytes(app_data),
    )


if __name__ == "__main__":
    main()