#include "FusedRotationEstimation.h"

/**
 * Creates the fused rotation rate estimator.
 *
 * @param acsParameters_ ACS parameter set read by the estimation methods (source selection flags,
 *                       SUS/MGM sine limit, eclipse propagation flag). Only the pointer is stored,
 *                       so the parameter set must outlive this object.
 */
FusedRotationEstimation::FusedRotationEstimation(AcsParameters *acsParameters_)
    : acsParameters(acsParameters_) {}

/**
 * Runs all rotation rate source estimators, then selects one source and publishes it as the
 * fused rotation rate.
 *
 * Selection priority (first match wins):
 *   1. STR rate, if valid and enabled via onBoardParams.fusedRateFromStr
 *   2. QUEST rate, if valid and enabled via onBoardParams.fusedRateFromQuest
 *   3. SUS/MGM rate, if valid (parallel/orthogonal components are forwarded as well)
 *   4. otherwise all rates are zeroed, the set is invalidated and the source is NONE
 *
 * @param susDataProcessed        processed sun sensor data (SUS/MGM estimator input)
 * @param mgmDataProcessed        processed magnetometer data (SUS/MGM estimator input)
 * @param gyrDataProcessed        processed gyroscope data (eclipse propagation input)
 * @param sensorValues            raw sensor values; the STR quaternion is read from here
 * @param attitudeEstimationData  attitude estimation output; the QUEST quaternion is read here
 * @param timeDelta               divisor for the quaternion-difference rates; presumably the time
 *                                since the previous call (TODO confirm units against the caller)
 * @param fusedRotRateSourcesData output: per-source rotation rates
 * @param fusedRotRateData        output: selected rotation rate and its source
 */
void FusedRotationEstimation::estimateFusedRotationRate(
    acsctrl::SusDataProcessed *susDataProcessed, acsctrl::MgmDataProcessed *mgmDataProcessed,
    acsctrl::GyrDataProcessed *gyrDataProcessed, ACS::SensorValues *sensorValues,
    acsctrl::AttitudeEstimationData *attitudeEstimationData, const double timeDelta,
    acsctrl::FusedRotRateSourcesData *fusedRotRateSourcesData,
    acsctrl::FusedRotRateData *fusedRotRateData) {
  estimateFusedRotationRateStr(sensorValues, timeDelta, fusedRotRateSourcesData);
  estimateFusedRotationRateQuest(attitudeEstimationData, timeDelta, fusedRotRateSourcesData);
  estimateFusedRotationRateSusMgm(susDataProcessed, mgmDataProcessed, gyrDataProcessed,
                                  fusedRotRateSourcesData);

  if (fusedRotRateSourcesData->rotRateTotalStr.isValid() and
      acsParameters->onBoardParams.fusedRateFromStr) {
    PoolReadGuard pg(fusedRotRateData);
    if (pg.getReadResult() == returnvalue::OK) {
      // Quaternion-difference sources only provide a total rate, no decomposition.
      std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC3, 3 * sizeof(double));
      fusedRotRateData->rotRateOrthogonal.setValid(false);
      std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC3, 3 * sizeof(double));
      fusedRotRateData->rotRateParallel.setValid(false);
      std::memcpy(fusedRotRateData->rotRateTotal.value,
                  fusedRotRateSourcesData->rotRateTotalStr.value, 3 * sizeof(double));
      fusedRotRateData->rotRateTotal.setValid(true);
      fusedRotRateData->rotRateSource.value = acs::rotrate::Source::STR;
      fusedRotRateData->rotRateSource.setValid(true);
    }
  } else if (fusedRotRateSourcesData->rotRateTotalQuest.isValid() and
             acsParameters->onBoardParams.fusedRateFromQuest) {
    PoolReadGuard pg(fusedRotRateData);
    if (pg.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC3, 3 * sizeof(double));
      fusedRotRateData->rotRateOrthogonal.setValid(false);
      std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC3, 3 * sizeof(double));
      fusedRotRateData->rotRateParallel.setValid(false);
      std::memcpy(fusedRotRateData->rotRateTotal.value,
                  fusedRotRateSourcesData->rotRateTotalQuest.value, 3 * sizeof(double));
      fusedRotRateData->rotRateTotal.setValid(true);
      fusedRotRateData->rotRateSource.value = acs::rotrate::Source::QUEST;
      fusedRotRateData->rotRateSource.setValid(true);
    }
  } else if (fusedRotRateSourcesData->rotRateTotalSusMgm.isValid()) {
    // Like every other branch, the output set may only be written while holding its lock.
    PoolReadGuard pg(fusedRotRateData);
    if (pg.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateData->rotRateOrthogonal.value,
                  fusedRotRateSourcesData->rotRateOrthogonalSusMgm.value, 3 * sizeof(double));
      fusedRotRateData->rotRateOrthogonal.setValid(
          fusedRotRateSourcesData->rotRateOrthogonalSusMgm.isValid());
      std::memcpy(fusedRotRateData->rotRateParallel.value,
                  fusedRotRateSourcesData->rotRateParallelSusMgm.value, 3 * sizeof(double));
      fusedRotRateData->rotRateParallel.setValid(
          fusedRotRateSourcesData->rotRateParallelSusMgm.isValid());
      std::memcpy(fusedRotRateData->rotRateTotal.value,
                  fusedRotRateSourcesData->rotRateTotalSusMgm.value, 3 * sizeof(double));
      fusedRotRateData->rotRateTotal.setValid(true);
      fusedRotRateData->rotRateSource.value = acs::rotrate::Source::SUSMGM;
      fusedRotRateData->rotRateSource.setValid(true);
    }
  } else {
    PoolReadGuard pg(fusedRotRateData);
    if (pg.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC3, 3 * sizeof(double));
      std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC3, 3 * sizeof(double));
      std::memcpy(fusedRotRateData->rotRateTotal.value, ZERO_VEC3, 3 * sizeof(double));
      fusedRotRateData->setValidity(false, true);
      // The source entry itself stays valid so consumers can see that no source was available.
      fusedRotRateData->rotRateSource.value = acs::rotrate::Source::NONE;
      fusedRotRateData->rotRateSource.setValid(true);
    }
  }
}

/**
 * Estimates the rotation rate from two consecutive star tracker (STR) attitude quaternions.
 *
 * The delta quaternion q_new * q_old^-1 is formed. The rate vector is the normalized vector part
 * of that delta, scaled by getAngle(delta) / timeDelta, and is published in rotRateTotalStr.
 *
 * - Any STR quaternion component invalid: the rate is zeroed and invalidated, and the stored
 *   previous quaternion is reset to zero.
 * - No previous quaternion stored (all zero) or timeDelta == 0: the rate is zeroed and
 *   invalidated. The current quaternion is stored for the next call.
 * - Delta quaternion with a zero vector part: a zero rate is published as valid.
 *
 * @param sensorValues            source of the STR quaternion (strSet.caliQx/y/z/w)
 * @param timeDelta               divisor for the rotation angle; presumably the time since the
 *                                previous call (TODO confirm units against the caller)
 * @param fusedRotRateSourcesData output data set; only rotRateTotalStr is written
 */
void FusedRotationEstimation::estimateFusedRotationRateStr(
    ACS::SensorValues *sensorValues, const double timeDelta,
    acsctrl::FusedRotRateSourcesData *fusedRotRateSourcesData) {
  // Writes the STR rate entry while holding the data set lock.
  auto publishStrRate = [fusedRotRateSourcesData](const double *rate, bool valid) {
    PoolReadGuard guard(fusedRotRateSourcesData);
    if (guard.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateSourcesData->rotRateTotalStr.value, rate, 3 * sizeof(double));
      fusedRotRateSourcesData->rotRateTotalStr.setValid(valid);
    }
  };

  const bool strQuatValid =
      sensorValues->strSet.caliQw.isValid() and sensorValues->strSet.caliQx.isValid() and
      sensorValues->strSet.caliQy.isValid() and sensorValues->strSet.caliQz.isValid();
  if (not strQuatValid) {
    publishStrRate(ZERO_VEC3, false);
    std::memcpy(quatOldStr, ZERO_VEC4, sizeof(quatOldStr));
    return;
  }

  // Stored in (x, y, z, w) order.
  double quatCurrent[4] = {sensorValues->strSet.caliQx.value, sensorValues->strSet.caliQy.value,
                           sensorValues->strSet.caliQz.value, sensorValues->strSet.caliQw.value};

  const double *rateToPublish = ZERO_VEC3;
  bool rateIsValid = false;
  double rateVector[3] = {0, 0, 0};
  const bool havePreviousQuat = VectorOperations<double>::norm(quatOldStr, 4) != 0;
  if (havePreviousQuat and timeDelta != 0) {
    double quatPreviousInverse[4] = {0, 0, 0, 0};
    double quatChange[4] = {0, 0, 0, 0};
    QuaternionOperations::inverse(quatOldStr, quatPreviousInverse);
    QuaternionOperations::multiply(quatCurrent, quatPreviousInverse, quatChange);
    if (VectorOperations<double>::norm(quatChange, 4) != 0.0) {
      QuaternionOperations::normalize(quatChange);
    }
    double rotationAngle = QuaternionOperations::getAngle(quatChange);

    // A vanishing vector part means no rotation: the zero rate is still a valid result.
    rateIsValid = true;
    if (VectorOperations<double>::norm(quatChange, 3) != 0.0) {
      VectorOperations<double>::normalize(quatChange, rateVector, 3);
      VectorOperations<double>::mulScalar(rateVector, rotationAngle / timeDelta, rateVector, 3);
      rateToPublish = rateVector;
    }
  }

  publishStrRate(rateToPublish, rateIsValid);
  std::memcpy(quatOldStr, quatCurrent, sizeof(quatOldStr));
}

/**
 * Estimates the rotation rate from two consecutive QUEST attitude quaternions.
 *
 * The delta quaternion q_new * q_old^-1 is formed. The rate vector is the normalized vector part
 * of that delta, scaled by getAngle(delta) / timeDelta, and is published in rotRateTotalQuest.
 *
 * - QUEST quaternion invalid: the rate is zeroed and invalidated, and the stored previous
 *   quaternion is reset to zero. The function returns immediately, so the invalid quaternion is
 *   never stored as history.
 * - No previous quaternion stored (all zero) or timeDelta == 0: the rate is zeroed and
 *   invalidated. The current quaternion is stored for the next call.
 * - Delta quaternion with a zero vector part: a zero rate is published as valid.
 *
 * @param attitudeEstimationData  source of the QUEST quaternion (quatQuest)
 * @param timeDelta               divisor for the rotation angle; presumably the time since the
 *                                previous call (TODO confirm units against the caller)
 * @param fusedRotRateSourcesData output data set; only rotRateTotalQuest is written
 */
void FusedRotationEstimation::estimateFusedRotationRateQuest(
    acsctrl::AttitudeEstimationData *attitudeEstimationData, const double timeDelta,
    acsctrl::FusedRotRateSourcesData *fusedRotRateSourcesData) {
  // Writes the QUEST rate entry while holding the data set lock.
  auto publishQuestRate = [fusedRotRateSourcesData](const double *rate, bool valid) {
    PoolReadGuard pg(fusedRotRateSourcesData);
    if (pg.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateSourcesData->rotRateTotalQuest.value, rate, 3 * sizeof(double));
      fusedRotRateSourcesData->rotRateTotalQuest.setValid(valid);
    }
  };

  if (not attitudeEstimationData->quatQuest.isValid()) {
    publishQuestRate(ZERO_VEC3, false);
    std::memcpy(quatOldQuest, ZERO_VEC4, sizeof(quatOldQuest));
    // Must return here. Falling through would copy the invalid quaternion into quatOldQuest, and
    // the next valid sample would be differenced against it, producing a bogus "valid" rate.
    return;
  }

  const double *rateToPublish = ZERO_VEC3;
  bool rateIsValid = false;
  double rotVec[3] = {0, 0, 0};
  if (VectorOperations<double>::norm(quatOldQuest, 4) != 0 and timeDelta != 0) {
    double quatOldInv[4] = {0, 0, 0, 0};
    double quatDelta[4] = {0, 0, 0, 0};
    QuaternionOperations::inverse(quatOldQuest, quatOldInv);
    QuaternionOperations::multiply(attitudeEstimationData->quatQuest.value, quatOldInv, quatDelta);
    if (VectorOperations<double>::norm(quatDelta, 4) != 0.0) {
      QuaternionOperations::normalize(quatDelta);
    }
    double angle = QuaternionOperations::getAngle(quatDelta);

    // A vanishing vector part means no rotation: the zero rate is still a valid result.
    rateIsValid = true;
    if (VectorOperations<double>::norm(quatDelta, 3) != 0.0) {
      VectorOperations<double>::normalize(quatDelta, rotVec, 3);
      VectorOperations<double>::mulScalar(rotVec, angle / timeDelta, rotVec, 3);
      rateToPublish = rotVec;
    }
  }

  publishQuestRate(rateToPublish, rateIsValid);
  std::memcpy(quatOldQuest, attitudeEstimationData->quatQuest.value, sizeof(quatOldQuest));
}

/**
 * Estimates the rotation rate from sun sensor (SUS) and magnetometer (MGM) vectors and their
 * time derivatives.
 *
 * The rate is split into two parts:
 * - parallel to the sun vector: the MGM derivative projected onto (MGM x SUS), divided by
 *   |MGM x SUS|^2, times the SUS vector;
 * - orthogonal to the sun vector: (dSUS/dt x SUS) / |SUS|^2.
 *
 * If there is neither vector nor previous-rate data, or neither derivative is valid, all SUS/MGM
 * rates are zeroed and invalidated. If the sun vector is invalid, or |MGM x SUS| does not exceed
 * sineLimitSunRotRate * |MGM|, estimateFusedRotationRateEclipse() is used instead. On every path
 * the current gyro rate (if valid) is afterwards stored in rotRateOldB, which the eclipse
 * propagation reads on its next call.
 *
 * NOTE(review): mgmVecTot / mgmVecTotDerivative / susVecTotDerivative values are used in the
 * computation without checking their individual validity; confirm this is intended.
 */
void FusedRotationEstimation::estimateFusedRotationRateSusMgm(
    acsctrl::SusDataProcessed *susDataProcessed, acsctrl::MgmDataProcessed *mgmDataProcessed,
    acsctrl::GyrDataProcessed *gyrDataProcessed,
    acsctrl::FusedRotRateSourcesData *fusedRotRateSourcesData) {
  const bool cannotEstimate =
      (not mgmDataProcessed->mgmVecTot.isValid() and not susDataProcessed->susVecTot.isValid() and
       not fusedRotRateSourcesData->rotRateTotalSusMgm.isValid()) or
      (not susDataProcessed->susVecTotDerivative.isValid() and
       not mgmDataProcessed->mgmVecTotDerivative.isValid());

  if (cannotEstimate) {
    PoolReadGuard guard(fusedRotRateSourcesData);
    if (guard.getReadResult() == returnvalue::OK) {
      std::memcpy(fusedRotRateSourcesData->rotRateOrthogonalSusMgm.value, ZERO_VEC3,
                  3 * sizeof(double));
      fusedRotRateSourcesData->rotRateOrthogonalSusMgm.setValid(false);
      std::memcpy(fusedRotRateSourcesData->rotRateParallelSusMgm.value, ZERO_VEC3,
                  3 * sizeof(double));
      fusedRotRateSourcesData->rotRateParallelSusMgm.setValid(false);
      std::memcpy(fusedRotRateSourcesData->rotRateTotalSusMgm.value, ZERO_VEC3,
                  3 * sizeof(double));
      fusedRotRateSourcesData->rotRateTotalSusMgm.setValid(false);
    }
  } else if (not susDataProcessed->susVecTot.isValid()) {
    // No sun vector available: fall back to gyro-based propagation.
    estimateFusedRotationRateEclipse(gyrDataProcessed, fusedRotRateSourcesData);
  } else {
    double mgmCrossSus[3] = {0, 0, 0};
    VectorOperations<double>::cross(mgmDataProcessed->mgmVecTot.value,
                                    susDataProcessed->susVecTot.value, mgmCrossSus);
    double mgmCrossSusNorm = VectorOperations<double>::norm(mgmCrossSus, 3);
    double mgmNorm = VectorOperations<double>::norm(mgmDataProcessed->mgmVecTot.value, 3);

    if (mgmCrossSusNorm >
        (acsParameters->safeModeControllerParameters.sineLimitSunRotRate * mgmNorm)) {
      // Component around the sun vector.
      double rateParallel[3] = {0, 0, 0};
      double omegaAlongSun =
          VectorOperations<double>::dot(mgmDataProcessed->mgmVecTotDerivative.value,
                                        mgmCrossSus) *
          pow(mgmCrossSusNorm, -2);
      VectorOperations<double>::mulScalar(susDataProcessed->susVecTot.value, omegaAlongSun,
                                          rateParallel, 3);

      // Component orthogonal to the sun vector.
      double rateOrthogonal[3] = {0, 0, 0};
      VectorOperations<double>::cross(susDataProcessed->susVecTotDerivative.value,
                                      susDataProcessed->susVecTot.value, rateOrthogonal);
      double inverseSusNormSquared =
          pow(VectorOperations<double>::norm(susDataProcessed->susVecTot.value, 3), -2);
      VectorOperations<double>::mulScalar(rateOrthogonal, inverseSusNormSquared, rateOrthogonal,
                                          3);

      double rateTotal[3] = {0, 0, 0};
      VectorOperations<double>::add(rateParallel, rateOrthogonal, rateTotal);

      PoolReadGuard guard(fusedRotRateSourcesData);
      if (guard.getReadResult() == returnvalue::OK) {
        std::memcpy(fusedRotRateSourcesData->rotRateOrthogonalSusMgm.value, rateOrthogonal,
                    3 * sizeof(double));
        fusedRotRateSourcesData->rotRateOrthogonalSusMgm.setValid(true);
        std::memcpy(fusedRotRateSourcesData->rotRateParallelSusMgm.value, rateParallel,
                    3 * sizeof(double));
        fusedRotRateSourcesData->rotRateParallelSusMgm.setValid(true);
        std::memcpy(fusedRotRateSourcesData->rotRateTotalSusMgm.value, rateTotal,
                    3 * sizeof(double));
        fusedRotRateSourcesData->rotRateTotalSusMgm.setValid(true);
      }
    } else {
      // MGM and SUS are too close to (anti-)parallel for the sun-axis component to be observed.
      estimateFusedRotationRateEclipse(gyrDataProcessed, fusedRotRateSourcesData);
    }
  }

  // Remember the current gyro rate; the eclipse propagation differences against it next time.
  if (gyrDataProcessed->gyrVecTot.isValid()) {
    std::memcpy(rotRateOldB, gyrDataProcessed->gyrVecTot.value, 3 * sizeof(double));
  }
}

/**
 * Propagates the last SUS/MGM rotation rate with the gyroscope while no sun-based estimate is
 * possible.
 *
 * The new total rate is rotRateTotalSusMgm + (gyrVecTot - rotRateOldB). This is the previous fused
 * rate corrected by the change the gyros measured since rotRateOldB was stored. Parallel and
 * orthogonal components are not available in this mode, so they are zeroed and invalidated.
 *
 * If eclipse propagation is disabled (onBoardParams.fusedRateSafeDuringEclipse), the gyro vector
 * is invalid, or the previous total rate is the zero vector, all three SUS/MGM rates are zeroed
 * and invalidated.
 */
void FusedRotationEstimation::estimateFusedRotationRateEclipse(
    acsctrl::GyrDataProcessed *gyrDataProcessed,
    acsctrl::FusedRotRateSourcesData *fusedRotRateSourcesData) {
  const bool canPropagate =
      acsParameters->onBoardParams.fusedRateSafeDuringEclipse and
      gyrDataProcessed->gyrVecTot.isValid() and
      VectorOperations<double>::norm(fusedRotRateSourcesData->rotRateTotalSusMgm.value, 3) != 0;

  double propagatedRate[3] = {0, 0, 0};
  const double *totalRate = ZERO_VEC3;
  if (canPropagate) {
    // Change of the measured body rate since rotRateOldB was stored.
    double gyrRateChange[3] = {0, 0, 0};
    VectorOperations<double>::subtract(gyrDataProcessed->gyrVecTot.value, rotRateOldB,
                                       gyrRateChange, 3);
    VectorOperations<double>::add(fusedRotRateSourcesData->rotRateTotalSusMgm.value,
                                  gyrRateChange, propagatedRate, 3);
    totalRate = propagatedRate;
  }

  PoolReadGuard guard(fusedRotRateSourcesData);
  if (guard.getReadResult() == returnvalue::OK) {
    std::memcpy(fusedRotRateSourcesData->rotRateOrthogonalSusMgm.value, ZERO_VEC3,
                3 * sizeof(double));
    fusedRotRateSourcesData->rotRateOrthogonalSusMgm.setValid(false);
    std::memcpy(fusedRotRateSourcesData->rotRateParallelSusMgm.value, ZERO_VEC3,
                3 * sizeof(double));
    fusedRotRateSourcesData->rotRateParallelSusMgm.setValid(false);
    std::memcpy(fusedRotRateSourcesData->rotRateTotalSusMgm.value, totalRate,
                3 * sizeof(double));
    fusedRotRateSourcesData->rotRateTotalSusMgm.setValid(canPropagate);
  }
}