#include "PtgCtrl.h"

/**
 * @brief Creates the pointing controller.
 * @param acsParameters_ Parameter set read by all control laws. Only the pointer is stored; the
 *                       destructor does not delete it.
 */
PtgCtrl::PtgCtrl(AcsParameters *acsParameters_) {
  this->acsParameters = acsParameters_;
}

// Nothing to release: the controller does not own the parameter set it points to.
PtgCtrl::~PtgCtrl() = default;

/**
 * @brief Selects the pointing control strategy from the currently valid sensor products.
 *
 * Priority order: a valid magnetic field is mandatory, then star tracker, then MEKF, then QUEST.
 * Star tracker and QUEST based control additionally require a valid fused rate whose source
 * value is greater than acs::rotrate::Source::SUSMGM.
 *
 * @return The chosen strategy, or a CTRL_NO_* strategy if no usable combination exists.
 */
acs::ControlModeStrategy PtgCtrl::pointingCtrlStrategy(
    const bool magFieldValid, const bool mekfValid, const bool strValid, const bool questValid,
    const bool fusedRateValid, const uint8_t rotRateSource, const uint8_t mekfEnabled) {
  if (!magFieldValid) {
    return acs::ControlModeStrategy::CTRL_NO_MAG_FIELD_FOR_CONTROL;
  }

  // Shared precondition of the STR and QUEST branches.
  const bool fusedRateUsable = fusedRateValid && rotRateSource > acs::rotrate::Source::SUSMGM;

  if (strValid && fusedRateUsable) {
    return acs::ControlModeStrategy::PTGCTRL_STR;
  }
  if (mekfEnabled && mekfValid) {
    return acs::ControlModeStrategy::PTGCTRL_MEKF;
  }
  if (questValid && fusedRateUsable) {
    return acs::ControlModeStrategy::PTGCTRL_QUEST;
  }
  return acs::ControlModeStrategy::CTRL_NO_SENSORS_FOR_CONTROL;
}

/**
 * @brief Pointing control law: computes reaction wheel torque commands.
 *
 * Builds a gain matrix K = diag(gainVector) * J (J = deployed inertia matrix), where gainVector is
 * proportional to the per-axis quaternion error magnitude floored at qiMin. It then builds a matrix
 * P from the inverted diagonal of K, J and kInt = 2*om^2. The body torque is
 * -K * sat(P * qError) - cInt * J * deltaRate with cInt = 2*om*zeta. That torque is mapped onto the
 * four wheels through rwPseudoInv and negated.
 *
 * @param pointingLawParameters Tuning values om, zeta, qiMin and omMax.
 * @param errorQuat   Error quaternion; only elements [0..2] are read.
 * @param deltaRate   Rate error, 3 elements.
 * @param rwPseudoInv 4x3 matrix (row-major) mapping body torque to wheel torques.
 * @param torqueRws   Output: 4 wheel torque commands.
 */
void PtgCtrl::ptgLaw(AcsParameters::PointingLawParameters *pointingLawParameters,
                     const double *errorQuat, const double *deltaRate, const double *rwPseudoInv,
                     double *torqueRws) {
  // ---- tuning ------------------------------------------------------------------------------------
  const double om = pointingLawParameters->om;
  const double zeta = pointingLawParameters->zeta;
  const double qErrorMin = pointingLawParameters->qiMin;
  const double omMax = pointingLawParameters->omMax;
  const double cInt = 2 * om * zeta;
  const double kInt = 2 * om * om;

  const double *const inertia = *(acsParameters->inertiaEIVE.inertiaMatrixDeployed);

  // Vector part of the error quaternion.
  double qError[3] = {errorQuat[0], errorQuat[1], errorQuat[2]};

  // ---- gain matrix K -----------------------------------------------------------------------------
  // Per-axis error magnitude, never smaller than qErrorMin.
  double qErrorFloored[3] = {0, 0, 0};
  for (int axis = 0; axis < 3; axis++) {
    const double magnitude = std::abs(qError[axis]);
    qErrorFloored[axis] = (magnitude < qErrorMin) ? qErrorMin : magnitude;
  }

  const double gainScale = cInt * omMax / VectorOperations<double>::norm(qErrorFloored, 3);
  double gainVector[3] = {0, 0, 0};
  VectorOperations<double>::mulScalar(qErrorFloored, gainScale, gainVector, 3);

  // K = diag(gainVector) * J
  double gainDiagonal[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
  for (int axis = 0; axis < 3; axis++) {
    gainDiagonal[axis][axis] = gainVector[axis];
  }
  double gainMatrix[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
  MatrixOperations<double>::multiply(*gainDiagonal, inertia, *gainMatrix, 3, 3, 3);

  // ---- P matrix ----------------------------------------------------------------------------------
  // Only the diagonal entries of K are inverted; off-diagonal terms are ignored here.
  double gainDiagonalInverse[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
  for (int axis = 0; axis < 3; axis++) {
    gainDiagonalInverse[axis][axis] = 1. / gainMatrix[axis][axis];
  }
  // P = kInt * diag(1 / K_ii) * J
  double pMatrix[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
  MatrixOperations<double>::multiply(*gainDiagonalInverse, inertia, *pMatrix, 3, 3, 3);
  MatrixOperations<double>::multiplyScalar(*pMatrix, kInt, *pMatrix, 3, 3);

  // ---- reaction wheel torques --------------------------------------------------------------------
  double pError[3] = {0, 0, 0};
  MatrixOperations<double>::multiply(*pMatrix, qError, pError, 3, 3, 1);

  // Components with magnitude above 1 are replaced by sign(component).
  double pErrorSaturated[3] = {0, 0, 0};
  for (int axis = 0; axis < 3; axis++) {
    double component = pError[axis];
    if (std::abs(component) > 1) {
      component = sign(component);
    }
    pErrorSaturated[axis] = component;
  }

  // Quaternion-error torque: -K * sat(P * qError)
  double torqueQuat[3] = {0, 0, 0};
  MatrixOperations<double>::multiply(*gainMatrix, pErrorSaturated, torqueQuat, 3, 3, 1);
  VectorOperations<double>::mulScalar(torqueQuat, -1, torqueQuat, 3);

  // Rate-error torque: -cInt * J * deltaRate
  double torqueRate[3] = {0, 0, 0};
  MatrixOperations<double>::multiply(inertia, deltaRate, torqueRate, 3, 3, 1);
  VectorOperations<double>::mulScalar(torqueRate, cInt, torqueRate, 3);
  VectorOperations<double>::mulScalar(torqueRate, -1, torqueRate, 3);

  // Sum in body frame, distribute over the four wheels and negate.
  double torqueBody[3] = {0, 0, 0};
  VectorOperations<double>::add(torqueRate, torqueQuat, torqueBody, 3);
  MatrixOperations<double>::multiply(rwPseudoInv, torqueBody, torqueRws, 4, 3, 1);
  VectorOperations<double>::mulScalar(torqueRws, -1, torqueRws, 4);
}

void PtgCtrl::ptgNullspace(const bool allRwAvabilable,
                           AcsParameters::PointingLawParameters *pointingLawParameters,
                           const int32_t speedRw0, const int32_t speedRw1, const int32_t speedRw2,
                           const int32_t speedRw3, double *rwTrqNs) {
  if (not allRwAvabilable) {
    return;
  }
  // concentrate RW speeds as vector and convert to double
  double speedRws[4] = {static_cast<double>(speedRw0), static_cast<double>(speedRw1),
                        static_cast<double>(speedRw2), static_cast<double>(speedRw3)};
  VectorOperations<double>::mulScalar(speedRws, 1e-1, speedRws, 4);
  VectorOperations<double>::mulScalar(speedRws, RPM_TO_RAD_PER_SEC, speedRws, 4);

  // calculate RPM offset utilizing the nullspace
  double rpmOffset[4] = {0, 0, 0, 0};
  double rpmOffsetSpeed = pointingLawParameters->nullspaceSpeed / 10 * RPM_TO_RAD_PER_SEC;
  VectorOperations<double>::mulScalar(acsParameters->rwMatrices.nullspaceVector, rpmOffsetSpeed,
                                      rpmOffset, 4);

  // calculate resulting angular momentum
  double rwAngMomentum[4] = {0, 0, 0, 0}, diffRwSpeed[4] = {0, 0, 0, 0};
  VectorOperations<double>::subtract(speedRws, rpmOffset, diffRwSpeed, 4);
  VectorOperations<double>::mulScalar(diffRwSpeed, acsParameters->rwHandlingParameters.inertiaWheel,
                                      rwAngMomentum, 4);

  // calculate resulting torque
  double nullspaceMatrix[4][4] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
  MatrixOperations<double>::multiply(acsParameters->rwMatrices.nullspaceVector,
                                     acsParameters->rwMatrices.nullspaceVector, *nullspaceMatrix, 4,
                                     1, 4);
  MatrixOperations<double>::multiply(*nullspaceMatrix, rwAngMomentum, rwTrqNs, 4, 4, 1);
  VectorOperations<double>::mulScalar(rwTrqNs, -1 * pointingLawParameters->gainNullspace, rwTrqNs,
                                      4);
}

/**
 * @brief Computes a magnetorquer dipole command that desaturates the reaction wheels.
 *
 * Wheel momentum is taken relative to the nullspace reference speed and mapped to the body frame
 * through the alignment matrix. desatMomentumRef is then subtracted. The dipole is
 * deSatGainFactor / |B| * (h x B), with B in Tesla.
 *
 * @param allRwAvailable        If false, rwAvail selects which wheels get a zero reference speed.
 * @param rwAvail               Per-wheel availability flags; only read when allRwAvailable is false.
 * @param pointingLawParameters Provides desatOn, nullspaceSpeed, desatMomentumRef and
 *                              deSatGainFactor.
 * @param magFieldB             Magnetic field in uT, 3 elements.
 * @param magFieldBValid        Validity of magFieldB.
 * @param speedRw0..speedRw3    Measured wheel speeds in 0.1 RPM.
 * @param mgtDpDes              Output: 3 dipole components. Always written; zero when
 *                              desaturation is disabled, the field is invalid, or the field norm is
 *                              zero/NaN.
 */
void PtgCtrl::ptgDesaturation(const bool allRwAvailable, const acsctrl::RwAvail *rwAvail,
                              AcsParameters::PointingLawParameters *pointingLawParameters,
                              const double *magFieldB, const bool magFieldBValid,
                              const int32_t speedRw0, const int32_t speedRw1,
                              const int32_t speedRw2, const int32_t speedRw3, double *mgtDpDes) {
  if (not magFieldBValid or not pointingLawParameters->desatOn) {
    // Command a zero dipole explicitly instead of leaving the caller's buffer untouched.
    for (uint8_t i = 0; i < 3; i++) {
      mgtDpDes[i] = 0.0;
    }
    return;
  }

  // concatenate RW speeds as vector and convert to double
  double speedRws[4] = {static_cast<double>(speedRw0), static_cast<double>(speedRw1),
                        static_cast<double>(speedRw2), static_cast<double>(speedRw3)};

  // convert magFieldB from uT to T
  double magFieldBT[3] = {0, 0, 0};
  VectorOperations<double>::mulScalar(magFieldB, 1e-6, magFieldBT, 3);
  const double magFieldBTNorm = VectorOperations<double>::norm(magFieldBT, 3);
  if (not(magFieldBTNorm > 0.0)) {
    // A zero (or NaN) field norm would turn the gain factor into inf/NaN and produce an invalid
    // dipole command; no desaturation is possible in that case.
    for (uint8_t i = 0; i < 3; i++) {
      mgtDpDes[i] = 0.0;
    }
    return;
  }

  // calculate angular momentum of the reaction wheels with respect to the nullspace RW speed
  // relocate RW speed zero to nullspace RW speed
  double refSpeedRws[4] = {0, 0, 0, 0};
  VectorOperations<double>::mulScalar(acsParameters->rwMatrices.nullspaceVector,
                                      pointingLawParameters->nullspaceSpeed, refSpeedRws, 4);
  if (not allRwAvailable) {
    // Each wheel is checked independently, so every unavailable wheel gets a zero reference speed,
    // not only the first one found.
    if (not rwAvail->rw1avail) {
      refSpeedRws[0] = 0.0;
    }
    if (not rwAvail->rw2avail) {
      refSpeedRws[1] = 0.0;
    }
    if (not rwAvail->rw3avail) {
      refSpeedRws[2] = 0.0;
    }
    if (not rwAvail->rw4avail) {
      refSpeedRws[3] = 0.0;
    }
  }
  VectorOperations<double>::subtract(speedRws, refSpeedRws, speedRws, 4);

  // convert speed from 0.1 RPM to RPM
  VectorOperations<double>::mulScalar(speedRws, 1e-1, speedRws, 4);
  // convert to rad/s
  VectorOperations<double>::mulScalar(speedRws, RPM_TO_RAD_PER_SEC, speedRws, 4);
  // calculate angular momentum of each RW
  double angMomentumRwU[4] = {0, 0, 0, 0};
  VectorOperations<double>::mulScalar(speedRws, acsParameters->rwHandlingParameters.inertiaWheel,
                                      angMomentumRwU, 4);
  // convert RW angular momentum to body RF
  double angMomentumRw[3] = {0, 0, 0};
  MatrixOperations<double>::multiply(*(acsParameters->rwMatrices.alignmentMatrix), angMomentumRwU,
                                     angMomentumRw, 3, 4, 1);

  // calculate total angular momentum
  double angMomentumTotal[3] = {0, 0, 0};
  VectorOperations<double>::subtract(angMomentumRw, pointingLawParameters->desatMomentumRef,
                                     angMomentumTotal, 3);

  // resulting magnetic dipole command
  double crossAngMomentumMagField[3] = {0, 0, 0};
  VectorOperations<double>::cross(angMomentumTotal, magFieldBT, crossAngMomentumMagField);
  double factor = pointingLawParameters->deSatGainFactor / magFieldBTNorm;
  VectorOperations<double>::mulScalar(crossAngMomentumMagField, factor, mgtDpDes, 3);
}

/**
 * @brief Post-processes commanded wheel speeds to avoid the stiction region.
 *
 * For each wheel:
 *  - if the wheel is not available (state value not set or not valid), the command becomes 0;
 *  - if the command is non-zero and strictly inside (-stictionSpeed, stictionSpeed), it is replaced
 *    by +stictionSpeed when above the current speed, by -stictionSpeed when below it, and by 0
 *    when equal to it;
 *  - otherwise the command is left unchanged.
 *
 * @param sensorValues Source of wheel state and current speed for RW1..RW4.
 * @param rwCmdSpeeds  In/out: 4 commanded wheel speeds, modified in place.
 */
void PtgCtrl::rwAntistiction(ACS::SensorValues *sensorValues, int32_t *rwCmdSpeeds) {
  bool rwAvailable[4] = {
      (sensorValues->rw1Set.state.value && sensorValues->rw1Set.state.isValid()),
      (sensorValues->rw2Set.state.value && sensorValues->rw2Set.state.isValid()),
      (sensorValues->rw3Set.state.value && sensorValues->rw3Set.state.isValid()),
      (sensorValues->rw4Set.state.value && sensorValues->rw4Set.state.isValid())};
  int32_t currRwSpeed[4] = {
      sensorValues->rw1Set.currSpeed.value, sensorValues->rw2Set.currSpeed.value,
      sensorValues->rw3Set.currSpeed.value, sensorValues->rw4Set.currSpeed.value};

  const auto stiction = acsParameters->rwHandlingParameters.stictionSpeed;

  for (int wheel = 0; wheel < 4; wheel++) {
    if (!rwAvailable[wheel]) {
      rwCmdSpeeds[wheel] = 0;
      continue;
    }

    const int32_t cmd = rwCmdSpeeds[wheel];
    const bool insideStictionBand = (cmd != 0) && (cmd > -stiction) && (cmd < stiction);
    if (!insideStictionBand) {
      continue;
    }

    if (cmd > currRwSpeed[wheel]) {
      rwCmdSpeeds[wheel] = stiction;
    } else if (cmd < currRwSpeed[wheel]) {
      rwCmdSpeeds[wheel] = -stiction;
    } else {
      rwCmdSpeeds[wheel] = 0;
    }
  }
}