#include "SafeCtrl.h"

#include <fsfw/globalfunctions/math/MatrixOperations.h>
#include <fsfw/globalfunctions/math/QuaternionOperations.h>
#include <fsfw/globalfunctions/math/VectorOperations.h>
#include <math.h>

#include <cstring>

// Stores the pointer to the ACS parameter set. The pointer is only kept here, not dereferenced;
// the control laws read their gains through it later.
SafeCtrl::SafeCtrl(AcsParameters *acsParameters_) {
  this->acsParameters = acsParameters_;
}

// Nothing is released on destruction.
SafeCtrl::~SafeCtrl() = default;

// Selects the safe mode control strategy from sensor validity flags and enable switches.
//
// Decision order (first match wins):
//   1. no valid magnetic field                       -> CTRL_NO_MAG_FIELD_FOR_CONTROL
//   2. MEKF enabled and valid                        -> SAFECTRL_MEKF
//   3. sun direction valid:
//        gyros enabled, rotation rate valid          -> SAFECTRL_GYR
//        gyros disabled, fused total rate valid      -> SAFECTRL_SUSMGM
//   4. sun direction invalid, damping enabled:
//        gyros enabled, rotation rate valid          -> SAFECTRL_ECLIPSE_DAMPING_GYR
//        gyros disabled, rotation rate and fused
//        total rate valid                            -> SAFECTRL_ECLIPSE_DAMPING_SUSMGM
//   5. sun direction invalid, damping disabled,
//      rotation rate valid                           -> SAFECTRL_ECLIPSE_IDELING
//   6. anything else                                 -> CTRL_NO_SENSORS_FOR_CONTROL
//
// @param magFieldValid        magnetic field measurement is valid
// @param mekfValid            MEKF solution is valid
// @param satRotRateValid      satellite rotation rate is valid
// @param sunDirValid          sun direction is valid
// @param fusedRateTotalValid  fused total rotation rate is valid
// @param mekfEnabled          non-zero if the MEKF may be used
// @param gyrEnabled           non-zero if the gyroscopes may be used
// @param dampingEnabled       non-zero if rate damping without sun direction is allowed
// @return the selected strategy, or one of the CTRL_NO_* values if no strategy is possible
acs::ControlModeStrategy SafeCtrl::safeCtrlStrategy(
    const bool magFieldValid, const bool mekfValid, const bool satRotRateValid,
    const bool sunDirValid, const bool fusedRateTotalValid, const uint8_t mekfEnabled,
    const uint8_t gyrEnabled, const uint8_t dampingEnabled) {
  // every strategy needs the magnetic field
  if (not magFieldValid) {
    return acs::ControlModeStrategy::CTRL_NO_MAG_FIELD_FOR_CONTROL;
  }

  // the MEKF takes precedence over all other sensor combinations
  if (mekfEnabled and mekfValid) {
    return acs::ControlModeStrategy::SAFECTRL_MEKF;
  }

  if (sunDirValid) {
    if (gyrEnabled and satRotRateValid) {
      return acs::ControlModeStrategy::SAFECTRL_GYR;
    }
    if (not gyrEnabled and fusedRateTotalValid) {
      return acs::ControlModeStrategy::SAFECTRL_SUSMGM;
    }
    return acs::ControlModeStrategy::CTRL_NO_SENSORS_FOR_CONTROL;
  }

  // from here on the sun direction is invalid
  if (dampingEnabled) {
    if (gyrEnabled and satRotRateValid) {
      return acs::ControlModeStrategy::SAFECTRL_ECLIPSE_DAMPING_GYR;
    }
    if (not gyrEnabled and satRotRateValid and fusedRateTotalValid) {
      return acs::ControlModeStrategy::SAFECTRL_ECLIPSE_DAMPING_SUSMGM;
    }
    return acs::ControlModeStrategy::CTRL_NO_SENSORS_FOR_CONTROL;
  }

  // damping disabled
  if (satRotRateValid) {
    return acs::ControlModeStrategy::SAFECTRL_ECLIPSE_IDELING;
  }
  return acs::ControlModeStrategy::CTRL_NO_SENSORS_FOR_CONTROL;
}

// Safe mode control law based on the MEKF attitude solution.
//
// The modelled sun direction is rotated into the body frame with quatBI. The commanded torque is
// the sum of an alignment term (towards sunDirRefB) and two rate damping terms (parallel and
// orthogonal to the sun direction), using the *Mekf gains from the safe mode controller
// parameters. The torque is then converted into a magnetic dipole moment.
//
// @param magFieldB     magnetic field in the body frame [uT], 3 elements
// @param satRotRateB   satellite rotation rate in the body frame, 3 elements
// @param sunDirModelI  modelled sun direction, rotated by quatBI into the body frame, 3 elements
// @param quatBI        quaternion used to rotate sunDirModelI into the body frame
// @param sunDirRefB    reference (target) sun direction in the body frame, 3 elements
// @param magMomB       [out] commanded magnetic dipole moment, 3 elements
// @param errorAngle    [out] angle between sunDirRefB and the sun direction [rad]
void SafeCtrl::safeMekf(const double *magFieldB, const double *satRotRateB,
                        const double *sunDirModelI, const double *quatBI, const double *sunDirRefB,
                        double *magMomB, double &errorAngle) {
  // convert magFieldB from uT to T
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);

  // convert sunDirModel to body rf
  double sunDirB[3] = {0, 0, 0};
  QuaternionOperations::multiplyVector(quatBI, sunDirModelI, sunDirB);

  // calculate angle alpha between sunDirRef and sunDir
  // NOTE(review): acos(dot) is the enclosed angle only for unit vectors; presumably both inputs
  // are normalized by the caller - confirm.
  double dotSun = VectorOperations<double>::dot(sunDirRefB, sunDirB);
  // Rounding can push the dot product marginally outside [-1, 1], where acos returns NaN. Limit
  // it to the valid domain. A NaN dot product fails both comparisons and still propagates.
  if (dotSun > 1.0) {
    dotSun = 1.0;
  } else if (dotSun < -1.0) {
    dotSun = -1.0;
  }
  errorAngle = acos(dotSun);

  splitRotationalRate(satRotRateB, sunDirB);
  calculateRotationalRateTorque(acsParameters->safeModeControllerParameters.k_parallelMekf,
                                acsParameters->safeModeControllerParameters.k_orthoMekf);
  calculateAngleErrorTorque(sunDirB, sunDirRefB,
                            acsParameters->safeModeControllerParameters.k_alignMekf);

  // sum of all torques
  for (uint8_t i = 0; i < 3; i++) {
    cmdTorque[i] = cmdAlign[i] + cmdOrtho[i] + cmdParallel[i];
  }

  calculateMagneticMoment(magMomB);
}

// Safe mode control law based on a measured sun direction and the satellite rotation rate.
//
// The commanded torque is the sum of an alignment term (towards sunDirRefB) and two rate damping
// terms (parallel and orthogonal to the sun direction), using the *Gyr gains from the safe mode
// controller parameters. The torque is then converted into a magnetic dipole moment.
//
// @param magFieldB    magnetic field in the body frame [uT], 3 elements
// @param satRotRateB  satellite rotation rate in the body frame, 3 elements
// @param sunDirB      sun direction in the body frame, 3 elements
// @param sunDirRefB   reference (target) sun direction in the body frame, 3 elements
// @param magMomB      [out] commanded magnetic dipole moment, 3 elements
// @param errorAngle   [out] angle between sunDirRefB and sunDirB [rad]
void SafeCtrl::safeGyr(const double *magFieldB, const double *satRotRateB, const double *sunDirB,
                       const double *sunDirRefB, double *magMomB, double &errorAngle) {
  // convert magFieldB from uT to T
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);

  // calculate error angle between sunDirRef and sunDir
  // NOTE(review): acos(dot) is the enclosed angle only for unit vectors; presumably both inputs
  // are normalized by the caller - confirm.
  double dotSun = VectorOperations<double>::dot(sunDirRefB, sunDirB);
  // Rounding can push the dot product marginally outside [-1, 1], where acos returns NaN. Limit
  // it to the valid domain. A NaN dot product fails both comparisons and still propagates.
  if (dotSun > 1.0) {
    dotSun = 1.0;
  } else if (dotSun < -1.0) {
    dotSun = -1.0;
  }
  errorAngle = acos(dotSun);

  splitRotationalRate(satRotRateB, sunDirB);
  calculateRotationalRateTorque(acsParameters->safeModeControllerParameters.k_parallelGyr,
                                acsParameters->safeModeControllerParameters.k_orthoGyr);
  calculateAngleErrorTorque(sunDirB, sunDirRefB,
                            acsParameters->safeModeControllerParameters.k_alignGyr);

  // sum of all torques
  for (uint8_t i = 0; i < 3; i++) {
    cmdTorque[i] = cmdAlign[i] + cmdOrtho[i] + cmdParallel[i];
  }

  calculateMagneticMoment(magMomB);
}

// Safe mode control law based on a measured sun direction and externally provided rotation
// rates.
//
// If both rotRateParallelB and rotRateOrthogonalB have a non-zero norm they are used directly as
// the parallel and orthogonal rate components. Otherwise rotRateTotalB is split along sunDirB.
// The commanded torque is the sum of an alignment term (towards sunDirRefB) and the two rate
// damping terms, using the *SusMgm gains from the safe mode controller parameters. The torque is
// then converted into a magnetic dipole moment.
//
// @param magFieldB           magnetic field in the body frame [uT], 3 elements
// @param rotRateTotalB       total rotation rate, used if the split rates are not both non-zero
// @param rotRateParallelB    rotation rate component parallel to the sun direction, 3 elements
// @param rotRateOrthogonalB  rotation rate component orthogonal to the sun direction, 3 elements
// @param sunDirB             sun direction in the body frame, 3 elements
// @param sunDirRefB          reference (target) sun direction in the body frame, 3 elements
// @param magMomB             [out] commanded magnetic dipole moment, 3 elements
// @param errorAngle          [out] angle between sunDirRefB and sunDirB [rad]
void SafeCtrl::safeSusMgm(const double *magFieldB, const double *rotRateTotalB,
                          const double *rotRateParallelB, const double *rotRateOrthogonalB,
                          const double *sunDirB, const double *sunDirRefB, double *magMomB,
                          double &errorAngle) {
  // convert magFieldB from uT to T
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);

  // calculate error angle between sunDirRef and sunDir
  // NOTE(review): acos(dot) is the enclosed angle only for unit vectors; presumably both inputs
  // are normalized by the caller - confirm.
  double dotSun = VectorOperations<double>::dot(sunDirRefB, sunDirB);
  // Rounding can push the dot product marginally outside [-1, 1], where acos returns NaN. Limit
  // it to the valid domain. A NaN dot product fails both comparisons and still propagates.
  if (dotSun > 1.0) {
    dotSun = 1.0;
  } else if (dotSun < -1.0) {
    dotSun = -1.0;
  }
  errorAngle = acos(dotSun);

  // NOTE(review): an exactly zero norm presumably marks a component as "not provided" - confirm
  // against the caller.
  if (VectorOperations<double>::norm(rotRateParallelB, 3) != 0 and
      VectorOperations<double>::norm(rotRateOrthogonalB, 3) != 0) {
    // both components were supplied, take them over unchanged
    std::memcpy(satRotRateParallelB, rotRateParallelB, sizeof(satRotRateParallelB));
    std::memcpy(satRotRateOrthogonalB, rotRateOrthogonalB, sizeof(satRotRateOrthogonalB));
  } else {
    // derive the components from the total rate instead
    splitRotationalRate(rotRateTotalB, sunDirB);
  }

  calculateRotationalRateTorque(acsParameters->safeModeControllerParameters.k_parallelSusMgm,
                                acsParameters->safeModeControllerParameters.k_orthoSusMgm);
  calculateAngleErrorTorque(sunDirB, sunDirRefB,
                            acsParameters->safeModeControllerParameters.k_alignSusMgm);

  // sum of all torques
  for (uint8_t i = 0; i < 3; i++) {
    cmdTorque[i] = cmdAlign[i] + cmdOrtho[i] + cmdParallel[i];
  }

  calculateMagneticMoment(magMomB);
}

// Pure rate damping with the *Gyr gain set, used when no sun direction is available (eclipse).
//
// The rotation rate is split along the reference sun direction sunDirRefB. Only the two damping
// torques are commanded; there is no alignment term. errorAngle is always set to NAN.
//
// @param magFieldB    magnetic field in the body frame [uT], 3 elements
// @param satRotRateB  satellite rotation rate in the body frame, 3 elements
// @param sunDirRefB   reference sun direction in the body frame, 3 elements
// @param magMomB      [out] commanded magnetic dipole moment, 3 elements
// @param errorAngle   [out] set to NAN
void SafeCtrl::safeRateDampingGyr(const double *magFieldB, const double *satRotRateB,
                                  const double *sunDirRefB, double *magMomB, double &errorAngle) {
  const auto &ctrlParams = acsParameters->safeModeControllerParameters;

  // magnetic field in T instead of uT
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);

  // without a sun direction there is no pointing error to report
  errorAngle = NAN;

  // damping torques for the rate components along and across the reference sun direction
  splitRotationalRate(satRotRateB, sunDirRefB);
  calculateRotationalRateTorque(ctrlParams.k_parallelGyr, ctrlParams.k_orthoGyr);

  // commanded torque = parallel damping + orthogonal damping
  VectorOperations<double>::add(cmdParallel, cmdOrtho, cmdTorque, 3);

  // turn the torque into the dipole moment to command
  calculateMagneticMoment(magMomB);
}

// Pure rate damping with the *SusMgm gain set, used when no sun direction is available
// (eclipse).
//
// The rotation rate is split along the reference sun direction sunDirRefB. Only the two damping
// torques are commanded; there is no alignment term. errorAngle is always set to NAN.
//
// @param magFieldB    magnetic field in the body frame [uT], 3 elements
// @param satRotRateB  satellite rotation rate in the body frame, 3 elements
// @param sunDirRefB   reference sun direction in the body frame, 3 elements
// @param magMomB      [out] commanded magnetic dipole moment, 3 elements
// @param errorAngle   [out] set to NAN
void SafeCtrl::safeRateDampingSusMgm(const double *magFieldB, const double *satRotRateB,
                                     const double *sunDirRefB, double *magMomB,
                                     double &errorAngle) {
  const auto &ctrlParams = acsParameters->safeModeControllerParameters;

  // magnetic field in T instead of uT
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);

  // without a sun direction there is no pointing error to report
  errorAngle = NAN;

  // damping torques for the rate components along and across the reference sun direction
  splitRotationalRate(satRotRateB, sunDirRefB);
  calculateRotationalRateTorque(ctrlParams.k_parallelSusMgm, ctrlParams.k_orthoSusMgm);

  // commanded torque = parallel damping + orthogonal damping
  VectorOperations<double>::add(cmdParallel, cmdOrtho, cmdTorque, 3);

  // turn the torque into the dipole moment to command
  calculateMagneticMoment(magMomB);
}

// Decomposes satRotRateB into its projection onto sunDirB (written to satRotRateParallelB) and
// the remainder (written to satRotRateOrthogonalB).
//
// @param satRotRateB  rotation rate to decompose, 3 elements
// @param sunDirB      direction to project onto, 3 elements; does not need unit length because
//                     the projection is divided by its squared norm
void SafeCtrl::splitRotationalRate(const double *satRotRateB, const double *sunDirB) {
  // projection factor: (rate . sun) / |sun|^2
  const double sunNorm = VectorOperations<double>::norm(sunDirB, 3);
  const double rateDotSun = VectorOperations<double>::dot(satRotRateB, sunDirB);
  const double projectionFactor = rateDotSun * pow(sunNorm, -2);

  // parallel part = factor * sun direction
  VectorOperations<double>::mulScalar(sunDirB, projectionFactor, satRotRateParallelB, 3);

  // orthogonal part = total rate - parallel part
  VectorOperations<double>::subtract(satRotRateB, satRotRateParallelB, satRotRateOrthogonalB, 3);
}

// Computes the two rate damping torques from the previously stored rate components:
//   cmdParallel = -gainParallel * satRotRateParallelB
//   cmdOrtho    = -gainOrtho    * satRotRateOrthogonalB
//
// @param gainParallel  gain applied to the parallel rate component
// @param gainOrtho     gain applied to the orthogonal rate component
void SafeCtrl::calculateRotationalRateTorque(const double gainParallel, const double gainOrtho) {
  // the negative sign makes both torques oppose the rotation
  const double parallelFactor = -gainParallel;
  const double orthoFactor = -gainOrtho;

  VectorOperations<double>::mulScalar(satRotRateParallelB, parallelFactor, cmdParallel, 3);
  VectorOperations<double>::mulScalar(satRotRateOrthogonalB, orthoFactor, cmdOrtho, 3);
}

// Computes the alignment torque cmdAlign = gainAlign * (sunDirRefB x sunDirB).
//
// @param sunDirB     current sun direction, 3 elements
// @param sunDirRefB  reference sun direction, 3 elements
// @param gainAlign   gain applied to the cross product
void SafeCtrl::calculateAngleErrorTorque(const double *sunDirB, const double *sunDirRefB,
                                         const double gainAlign) {
  // the cross product of reference and current direction is the axis of the pointing error
  double errorAxis[3] = {};
  VectorOperations<double>::cross(sunDirRefB, sunDirB, errorAxis);

  // scale the axis with the alignment gain
  VectorOperations<double>::mulScalar(errorAxis, gainAlign, cmdAlign, 3);
}

// Converts the commanded torque cmdTorque into a magnetic dipole moment:
//   magMomB = (magFieldBT x cmdTorque) / |magFieldBT|^2
//
// magFieldBT and cmdTorque must have been set by the calling control law beforehand.
//
// @param magMomB  [out] resulting magnetic dipole moment, 3 elements
void SafeCtrl::calculateMagneticMoment(double *magMomB) {
  // 1 / |B|^2
  const double fieldNorm = VectorOperations<double>::norm(magFieldBT, 3);
  const double inverseFieldNormSquared = pow(fieldNorm, -2);

  // B x T
  double fieldCrossTorque[3] = {};
  VectorOperations<double>::cross(magFieldBT, cmdTorque, fieldCrossTorque);

  VectorOperations<double>::mulScalar(fieldCrossTorque, inverseFieldNormSquared, magMomB, 3);
}