#include "ArduinoComIF.h"

#include <fsfw/globalfunctions/CRC.h>
#include <fsfw/globalfunctions/DleEncoder.h>
#include <fsfw/serviceinterface/ServiceInterfaceStream.h>

#include "ArduinoCookie.h"

// This only works on Linux
#ifdef LINUX
#include <fcntl.h>
#include <termios.h>
#include <unistd.h>
#elif WIN32
#include <strsafe.h>
#include <windows.h>
#endif

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <vector>

/**
 * Opens and configures the serial link to the Arduino.
 *
 * LINUX: opens @p serialDevice, or "/dev/ttyUSB0" when it is nullptr. It then puts the port
 * into raw 8N1 mode at 9600 baud with non-blocking reads (VMIN = VTIME = 0). `initialized`
 * only becomes true when every step succeeds. Failures are printed and leave the object
 * uninitialized.
 *
 * WIN32: obtains a COM port handle, either by prompting on stdin (@p promptComIF) or from
 * @p serialDevice. If a handle was obtained, it applies `baudRate`, 8N1 framing and read
 * timeouts to it.
 *
 * @param setObjectId   Object ID passed on to SystemObject.
 * @param promptComIF   WIN32 only: ask the user interactively for the COM port name.
 * @param serialDevice  Device path (LINUX) or COM port name (WIN32). nullptr selects the
 *                      default described above.
 */
ArduinoComIF::ArduinoComIF(object_id_t setObjectId, bool promptComIF, const char *serialDevice)
    : SystemObject(setObjectId), rxBuffer(MAX_PACKET_SIZE * MAX_NUMBER_OF_SPI_DEVICES * 10, true) {
#ifdef LINUX
  initialized = false;
  // Honour the caller's device; fall back to the previously hard-coded default.
  const char *devicePath = (serialDevice != nullptr) ? serialDevice : "/dev/ttyUSB0";
  // O_NOCTTY: the serial port must never become this process' controlling terminal.
  serialPort = ::open(devicePath, O_RDWR | O_NOCTTY);

  if (serialPort < 0) {
    // configuration error
    printf("Error %i from open: %s\n", errno, strerror(errno));
    return;
  }

  struct termios tty;
  memset(&tty, 0, sizeof tty);

  // Read in existing settings, and handle any error
  if (tcgetattr(serialPort, &tty) != 0) {
    printf("Error %i from tcgetattr: %s\n", errno, strerror(errno));
    return;
  }

  // Control modes: 8N1, no hardware flow control, receiver enabled, modem lines ignored.
  tty.c_cflag &= ~PARENB;  // Clear parity bit, disabling parity
  tty.c_cflag &= ~CSTOPB;  // Clear stop field, only one stop bit used in communication
  tty.c_cflag &= ~CSIZE;   // Clear the character size bits before selecting CS8
  tty.c_cflag |= CS8;      // 8 bits per byte
  tty.c_cflag &= ~CRTSCTS;  // Disable RTS/CTS hardware flow control
  tty.c_cflag |= CREAD | CLOCAL;

  // Local modes: raw, non-canonical input. With ECHO on, every received byte would be sent
  // back to the Arduino. With ISIG on, control bytes in the binary stream would raise signals.
  tty.c_lflag &= ~(ICANON | ECHO | ECHOE | ECHONL | ISIG | IEXTEN);

  // Input modes: XON/XOFF and CR/NL bytes are ordinary protocol data here. Do not interpret,
  // translate or strip them.
  tty.c_iflag &= ~(IXON | IXOFF | IXANY | IGNBRK | BRKINT | PARMRK | ISTRIP | INLCR | IGNCR | ICRNL);

  tty.c_oflag &= ~OPOST;  // Prevent special interpretation of output bytes (e.g. newline chars)
  tty.c_oflag &= ~ONLCR;  // Prevent conversion of newline to carriage return/line feed

  // VMIN = VTIME = 0: read() returns immediately with whatever is available (possibly 0 bytes).
  tty.c_cc[VTIME] = 0;
  tty.c_cc[VMIN] = 0;

  // Set both directions; cfsetispeed alone leaves the output baud rate unchanged.
  cfsetispeed(&tty, B9600);
  cfsetospeed(&tty, B9600);

  if (tcsetattr(serialPort, TCSANOW, &tty) != 0) {
    printf("Error %i from tcsetattr: %s\n", errno, strerror(errno));
    return;
  }

  initialized = true;
#elif WIN32
  // Opens a COM port by name (narrow string, matching CreateFileA). Returns
  // INVALID_HANDLE_VALUE on failure.
  auto openComPort = [](const char *portName) {
    return CreateFileA(portName,                      // port name
                       GENERIC_READ | GENERIC_WRITE,  // Read/Write
                       0,                             // No Sharing
                       NULL,                          // No Security
                       OPEN_EXISTING,                 // Open existing port only
                       0,                             // Non Overlapped I/O
                       NULL);                         // Null for Comm Devices
  };

  if (promptComIF) {
    // we need to ask the COM port from the user.
    sif::info << "Please enter the COM port (c to cancel): " << std::flush;
    std::string comPort;
    while (hCom == INVALID_HANDLE_VALUE) {
      // Stop on end of input as well, otherwise the loop would spin forever.
      if (!std::getline(std::cin, comPort) || comPort[0] == 'c') {
        break;
      }
      hCom = openComPort(comPort.c_str());

      if (hCom == INVALID_HANDLE_VALUE) {
        // Capture the error right away: later API calls may overwrite the last-error value.
        const DWORD errorCode = GetLastError();
        if (errorCode == ERROR_FILE_NOT_FOUND) {
          sif::error << "COM Port does not found!" << std::endl;
        } else {
          char err[128] = {0};
          FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM, NULL, errorCode,
                         MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), err, sizeof(err), NULL);
          //  Handle the error.
          sif::info << "CreateFileA Error code: " << errorCode << std::endl;
          sif::error << err << std::flush;
        }
        sif::info << "Please enter a valid COM port: " << std::flush;
      }
    }
  } else if (serialDevice != nullptr) {
    hCom = openComPort(serialDevice);
  }

  if (hCom == INVALID_HANDLE_VALUE) {
    // No port was opened (prompt cancelled, no device given, or open failed).
    return;
  }

  DCB serialParams = {0};
  serialParams.DCBlength = sizeof(serialParams);
  // Start from the port's current settings so fields not set below (e.g. fBinary, which
  // Windows requires to be TRUE) keep valid values.
  GetCommState(hCom, &serialParams);
  // The CBR_* constants equal their numeric rate, so the stored rate can be used directly.
  serialParams.BaudRate = baudRate;
  serialParams.ByteSize = 8;
  serialParams.Parity = NOPARITY;
  serialParams.StopBits = ONESTOPBIT;
  SetCommState(hCom, &serialParams);

  COMMTIMEOUTS timeout = {0};
  // This will set the read operation to be blocking until data is received
  // and then read continuously until there is a gap of one millisecond.
  timeout.ReadIntervalTimeout = 1;
  timeout.ReadTotalTimeoutConstant = 0;
  timeout.ReadTotalTimeoutMultiplier = 0;
  timeout.WriteTotalTimeoutConstant = 0;
  timeout.WriteTotalTimeoutMultiplier = 0;
  SetCommTimeouts(hCom, &timeout);
  // Serial port should now be ready for operations.
#endif
}

/**
 * Releases the serial port handle, but only if the constructor actually opened one.
 */
ArduinoComIF::~ArduinoComIF() {
#ifdef LINUX
  // open() may have failed in the constructor, leaving a negative descriptor behind.
  if (serialPort >= 0) {
    ::close(serialPort);
  }
#elif WIN32
  // CloseHandle on an invalid handle raises an exception when running under a debugger.
  if (hCom != INVALID_HANDLE_VALUE) {
    CloseHandle(hCom);
  }
#endif
}
/**
 * Per-cookie setup hook. Nothing is done here because the serial port is opened and
 * configured in the constructor.
 *
 * NOTE(review): nothing in the code visible here inserts into spiMap, which handlePacket()
 * uses to route replies. Confirm where cookies get registered.
 *
 * @return always returnvalue::OK
 */
ReturnValue_t ArduinoComIF::initializeInterface(CookieIF *cookie) {
  static_cast<void>(cookie);
  return returnvalue::OK;
}

/**
 * DeviceComIF entry point: sends @p data to the device described by an ArduinoCookie.
 *
 * @return INVALID_COOKIE_TYPE if @p cookie is not an ArduinoCookie; otherwise the result of
 *         the framing/transmit overload, called with the cookie's command and address.
 */
ReturnValue_t ArduinoComIF::sendMessage(CookieIF *cookie, const uint8_t *data, size_t len) {
  auto *const target = dynamic_cast<ArduinoCookie *>(cookie);
  if (target != nullptr) {
    return sendMessage(target->command, target->address, data, len);
  }
  return INVALID_COOKIE_TYPE;
}

/**
 * Reports the outcome of the last send. sendMessage() already returns the result of the
 * write, so no deferred state is kept and this always reports success.
 *
 * @return always returnvalue::OK
 */
ReturnValue_t ArduinoComIF::getSendSuccess(CookieIF *cookie) {
  static_cast<void>(cookie);
  return returnvalue::OK;
}

/**
 * No explicit read request is needed: incoming data is polled from the serial port inside
 * readReceivedMessage().
 *
 * @return always returnvalue::OK
 */
ReturnValue_t ArduinoComIF::requestReceiveMessage(CookieIF *cookie, size_t requestLen) {
  static_cast<void>(cookie);
  static_cast<void>(requestLen);
  return returnvalue::OK;
}

/**
 * Polls the serial port, then hands out the reply stored in @p cookie.
 *
 * The port is polled before the cookie type is checked, so received frames are processed
 * even when the cookie is rejected.
 *
 * NOTE(review): handlePacket() writes into the ArduinoCookie held in spiMap, while this
 * returns the buffer of the cookie passed in. Unless those are the same object, replies
 * never reach the caller. Also, receivedDataLen is never reset here, so the same reply is
 * returned again until a new one arrives. Confirm both are intended.
 *
 * @param[out] buffer  Set to the cookie's reply buffer.
 * @param[out] size    Set to the cookie's receivedDataLen.
 * @return INVALID_COOKIE_TYPE if @p cookie is not an ArduinoCookie, otherwise returnvalue::OK.
 */
ReturnValue_t ArduinoComIF::readReceivedMessage(CookieIF *cookie, uint8_t **buffer, size_t *size) {
  handleSerialPortRx();

  auto *const source = dynamic_cast<ArduinoCookie *>(cookie);
  if (source == nullptr) {
    return INVALID_COOKIE_TYPE;
  }

  *size = source->receivedDataLen;
  *buffer = source->replyBuffer.data();
  return returnvalue::OK;
}

/**
 * Frames one command and writes it to the serial port.
 *
 * Wire format: a raw STX, then DLE-encoded [command, address, payload length (2 bytes, big
 * endian), payload, CRC-16/CCITT over the unencoded command..payload (2 bytes, big endian)],
 * then a raw ETX.
 *
 * @param command  Command byte.
 * @param address  Device address on the Arduino side.
 * @param data     Payload; may only be nullptr when @p dataLen is 0.
 * @param dataLen  Payload length, at most UINT16_MAX.
 * @return returnvalue::OK if the complete frame was written.
 *         TOO_MUCH_DATA if @p dataLen exceeds 16 bits.
 *         The DleEncoder error if encoding fails.
 *         returnvalue::FAILED on a null payload, a write error, or a partial write.
 */
ReturnValue_t ArduinoComIF::sendMessage(uint8_t command, uint8_t address, const uint8_t *data,
                                        size_t dataLen) {
  if (dataLen > UINT16_MAX) {
    return TOO_MUCH_DATA;
  }
  if (data == nullptr && dataLen > 0) {
    return returnvalue::FAILED;
  }

  // Unencoded header. The length fits in 16 bits (checked above).
  const uint8_t header[4] = {command, address, static_cast<uint8_t>(dataLen >> 8),
                             static_cast<uint8_t>(dataLen & 0xFF)};

  // The CRC register carries over between calls, so this equals one CRC over header+payload.
  uint16_t crc = CRC::crc16ccitt(header, sizeof(header));
  crc = CRC::crc16ccitt(data, dataLen, crc);
  const uint8_t crcField[2] = {static_cast<uint8_t>(crc >> 8), static_cast<uint8_t>(crc & 0xFF)};

  // Worst case, every byte is escaped (2 output bytes each), plus the raw STX and ETX:
  // 1 + 2 * (4 + dataLen + 2) + 1 == (dataLen + 6) * 2 + 2.
  // Heap-allocated: a stack VLA is non-standard and could reach ~131 KB here.
  std::vector<uint8_t> sendBuffer((dataLen + 6) * 2 + 2);
  size_t position = 0;
  sendBuffer[position++] = DleEncoder::STX_CHAR;

  // DLE-encodes one chunk (without STX/ETX) right behind what has been written so far.
  auto appendEncoded = [&sendBuffer, &position](const uint8_t *chunk,
                                                size_t chunkLen) -> ReturnValue_t {
    size_t encodedLen = 0;
    const ReturnValue_t encodeResult =
        DleEncoder::encode(chunk, chunkLen, sendBuffer.data() + position,
                           sendBuffer.size() - position, &encodedLen, false);
    if (encodeResult == returnvalue::OK) {
      position += encodedLen;  // DleEncoder never reports more than the space it was given
    }
    return encodeResult;
  };

  ReturnValue_t result = appendEncoded(header, sizeof(header));
  if (result == returnvalue::OK) {
    result = appendEncoded(data, dataLen);
  }
  if (result == returnvalue::OK) {
    result = appendEncoded(crcField, sizeof(crcField));
  }
  if (result != returnvalue::OK) {
    return result;
  }

  if (position >= sendBuffer.size()) {
    // Unreachable with the worst-case sizing above. Checked so nothing is ever written past
    // the end of the buffer.
    return returnvalue::FAILED;
  }
  sendBuffer[position++] = DleEncoder::ETX_CHAR;

#ifdef LINUX
  const ssize_t writtenLen = ::write(serialPort, sendBuffer.data(), position);
  if (writtenLen < 0) {
    // we could try to find out what happened...
    return returnvalue::FAILED;
  }
  if (static_cast<size_t>(writtenLen) != position) {
    // the OS failed us, we do not try to block until everything is written, as
    // we can not block the whole system here
    return returnvalue::FAILED;
  }
  return returnvalue::OK;
#elif WIN32
  DWORD writtenLen = 0;
  if (!WriteFile(hCom, sendBuffer.data(), static_cast<DWORD>(position), &writtenLen, NULL)) {
    return returnvalue::FAILED;
  }
  return (writtenLen == position) ? returnvalue::OK : returnvalue::FAILED;
#else
  // No serial backend for this platform: nothing was sent.
  return returnvalue::FAILED;
#endif
}

/**
 * Moves pending bytes from the serial port into rxBuffer and dispatches at most one
 * DLE-framed packet to handlePacket().
 *
 * Bytes before the first STX are discarded. A successfully decoded frame is removed from
 * rxBuffer. If decoding fails, the frame is kept so the rest can arrive, unless another STX
 * already follows it. In that case it can never complete, so it is dropped and the stream
 * resynchronises on the next STX. Does nothing on WIN32.
 */
void ArduinoComIF::handleSerialPortRx() {
#ifdef LINUX
  const uint32_t availableSpace = rxBuffer.availableWriteSpace();
  if (availableSpace > 0) {
    // Heap buffers: a zero-length or very large stack VLA would be UB / a stack risk.
    std::vector<uint8_t> dataFromSerial(availableSpace);
    const ssize_t bytesRead = ::read(serialPort, dataFromSerial.data(), dataFromSerial.size());
    if (bytesRead < 0) {
      return;
    }
    rxBuffer.writeData(dataFromSerial.data(), static_cast<size_t>(bytesRead));
  }

  std::vector<uint8_t> dataReceivedSoFar(rxBuffer.getMaxSize());
  uint32_t dataLenReceivedSoFar = 0;
  // NOTE(review): these arguments fit a readData(data, amount, readRemaining, trueAmount)
  // signature. With a (data, amount, incrementReadPtr, readRemaining, trueAmount) signature,
  // the pointer would bind to a bool and the length would stay 0. Confirm against the
  // rxBuffer type.
  rxBuffer.readData(dataReceivedSoFar.data(), dataReceivedSoFar.size(), true,
                    &dataLenReceivedSoFar);

  const size_t rawLen = std::min<size_t>(dataLenReceivedSoFar, dataReceivedSoFar.size());
  const uint8_t *const rawBegin = dataReceivedSoFar.data();
  const uint8_t *const rawEnd = rawBegin + rawLen;
  const uint8_t stx = DleEncoder::STX_CHAR;

  // look for STX. Compare against rawEnd: the old element check read past the valid data.
  const uint8_t *const frameStart = std::find(rawBegin, rawEnd, stx);
  if (frameStart == rawEnd) {
    // there is no STX in our data, throw it away...
    rxBuffer.deleteData(rawLen);
    return;
  }
  const size_t stxOffset = static_cast<size_t>(frameStart - rawBegin);

  uint8_t packet[MAX_PACKET_SIZE];
  size_t packetLen = 0;
  size_t readSize = 0;

  const ReturnValue_t result = DleEncoder::decode(frameStart, rawLen - stxOffset, &readSize,
                                                  packet, sizeof(packet), &packetLen);

  size_t toDelete = stxOffset;
  if (result == returnvalue::OK) {
    handlePacket(packet, packetLen);

    // after handling the packet, we can delete it from the raw stream,
    // it has been copied to packet
    toDelete += readSize;
  } else {
    // sendMessage() emits a raw STX only at a frame start; presumably the peer encodes the
    // same way. A later STX therefore means the frame at frameStart was cut off or corrupted
    // and can never decode. Drop it instead of retrying it forever.
    const uint8_t *const nextFrameStart = std::find(frameStart + 1, rawEnd, stx);
    if (nextFrameStart != rawEnd) {
      toDelete = static_cast<size_t>(nextFrameStart - rawBegin);
    }
  }

  // remove Data which was processed
  rxBuffer.deleteData(toDelete);
#elif WIN32
#endif
}

/**
 * Stores the serial baud rate.
 *
 * Only the WIN32 branch of the constructor reads this value; the LINUX branch always uses
 * B9600.
 * NOTE(review): nothing in this file re-applies the value after the constructor has
 * configured the port, so calling this on a constructed object does not change the port's
 * baud rate.
 */
void ArduinoComIF::setBaudrate(uint32_t baudRate) {
  ArduinoComIF::baudRate = baudRate;
}

/**
 * Validates one DLE-decoded frame and copies its payload into the matching SPI cookie.
 *
 * Expected layout: [0] command, [1] address, [2..3] payload length (big endian),
 * [4 .. 4 + length) payload, then two CRC bytes. This presumably mirrors what sendMessage()
 * produces. Running CRC-16/CCITT over the whole frame, CRC included, yields 0 for an intact
 * frame.
 *
 * The frame is silently dropped if it:
 *  - is shorter than the 6-byte overhead,
 *  - fails the CRC,
 *  - carries a length field that does not match its size, or
 *  - has an unknown command or SPI address.
 * Payloads longer than the cookie's maxReplySize are truncated.
 */
void ArduinoComIF::handlePacket(uint8_t *packet, size_t packetLen) {
  // command + address + 2 length bytes + 2 CRC bytes
  constexpr size_t FRAME_OVERHEAD = 6;
  // Check the size before touching the header: a 2-byte frame can still pass the CRC check.
  if (packet == nullptr || packetLen < FRAME_OVERHEAD) {
    return;
  }

  if (CRC::crc16ccitt(packet, packetLen) != 0) {
    // CRC error
    return;
  }

  const uint8_t command = packet[0];
  const uint8_t address = packet[1];
  const size_t payloadLen = (static_cast<size_t>(packet[2]) << 8) | packet[3];

  if (payloadLen != packetLen - FRAME_OVERHEAD) {
    // Invalid Length
    return;
  }

  switch (command) {
    case ArduinoCookie::SPI: {
      auto findIter = spiMap.find(address);
      if (findIter == spiMap.end()) {
        // we do no know this address
        return;
      }
      ArduinoCookie &cookie = findIter->second;
      const size_t copyLen = std::min<size_t>(payloadLen, cookie.maxReplySize);
      std::memcpy(cookie.replyBuffer.data(), packet + 4, copyLen);
      cookie.receivedDataLen = copyLen;
      break;
    }
    default:
      break;
  }
}