#include "FusedRotationEstimation.h"

/**
 * @brief Creates the fused rotation rate estimator.
 *
 * @param acsParameters_ ACS parameter set used by the estimation routines. Only the pointer is
 *                       stored; the parameter set itself is not copied.
 */
FusedRotationEstimation::FusedRotationEstimation(AcsParameters *acsParameters_)
    : acsParameters(acsParameters_) {}

/**
 * @brief Estimates the body rotation rate in safe mode from the sun and magnetic field vectors.
 *
 * The rate is split into two components:
 *  - parallel to the sun vector: the MGM derivative projected onto (mgm x sus), divided by
 *    |mgm x sus|^2, applied along the sun vector;
 *  - orthogonal to the sun vector: (d(sus)/dt x sus) / |sus|^2.
 *
 * If the inputs do not allow this estimate, the function does one of two things:
 *  - it falls back to the gyro based propagation of the last estimate
 *    (estimateFusedRotationRateEclipse), or
 *  - it zeroes and invalidates the output dataset.
 * On every exit path the current gyro rate (if valid) is stored in rotRateOldB for the next
 * gyro based propagation.
 *
 * @param susDataProcessed  Processed sun sensor data (sun vector and its derivative).
 * @param mgmDataProcessed  Processed magnetometer data (field vector and its derivative).
 * @param gyrDataProcessed  Processed gyroscope data.
 * @param fusedRotRateData  Output dataset; its rotRateTotal validity is also read as an input.
 */
void FusedRotationEstimation::estimateFusedRotationRateSafe(
    acsctrl::SusDataProcessed *susDataProcessed, acsctrl::MgmDataProcessed *mgmDataProcessed,
    acsctrl::GyrDataProcessed *gyrDataProcessed, acsctrl::FusedRotRateData *fusedRotRateData) {
  // Store the current gyro rate for the gyro based propagation of the next cycle.
  auto storeGyrRateForNextCycle = [&]() {
    if (gyrDataProcessed->gyrVecTot.isValid()) {
      std::memcpy(rotRateOldB, gyrDataProcessed->gyrVecTot.value, 3 * sizeof(double));
    }
  };

  if ((not mgmDataProcessed->mgmVecTot.isValid() and not susDataProcessed->susVecTot.isValid() and
       not fusedRotRateData->rotRateTotal.isValid()) or
      (not susDataProcessed->susVecTotDerivative.isValid() and
       not mgmDataProcessed->mgmVecTotDerivative.isValid())) {
    {
      PoolReadGuard pg(fusedRotRateData);
      std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC, 3 * sizeof(double));
      std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC, 3 * sizeof(double));
      std::memcpy(fusedRotRateData->rotRateTotal.value, ZERO_VEC, 3 * sizeof(double));
      fusedRotRateData->setValidity(false, true);
    }
    storeGyrRateForNextCycle();
    return;
  }

  // The estimate below reads the sun vector, the magnetic field vector and both derivatives.
  // If any of them is invalid the estimate cannot be formed, so propagate with the gyros instead.
  if (not susDataProcessed->susVecTot.isValid() or
      not susDataProcessed->susVecTotDerivative.isValid() or
      not mgmDataProcessed->mgmVecTot.isValid() or
      not mgmDataProcessed->mgmVecTotDerivative.isValid()) {
    estimateFusedRotationRateEclipse(gyrDataProcessed, fusedRotRateData);
    storeGyrRateForNextCycle();
    return;
  }

  // calculate rotation around the sun
  double magSunCross[3] = {0, 0, 0};
  VectorOperations<double>::cross(mgmDataProcessed->mgmVecTot.value,
                                  susDataProcessed->susVecTot.value, magSunCross);
  const double magSunCrossNorm = VectorOperations<double>::norm(magSunCross, 3);
  const double magNorm = VectorOperations<double>::norm(mgmDataProcessed->mgmVecTot.value, 3);
  // |mgm x sus| / |mgm| is the sine of the angle between both vectors if sus is a unit vector
  // (presumably the case — TODO confirm). Near-parallel vectors yield no usable parallel rate.
  // Written as a positive comparison so that a NaN also takes the fallback path.
  const bool geometryUsable =
      magSunCrossNorm >
      (acsParameters->safeModeControllerParameters.sineLimitSunRotRate * magNorm);
  if (not geometryUsable) {
    estimateFusedRotationRateEclipse(gyrDataProcessed, fusedRotRateData);
    storeGyrRateForNextCycle();
    return;
  }
  const double omegaParallel =
      VectorOperations<double>::dot(mgmDataProcessed->mgmVecTotDerivative.value, magSunCross) /
      (magSunCrossNorm * magSunCrossNorm);
  double fusedRotRateParallel[3] = {0, 0, 0};
  VectorOperations<double>::mulScalar(susDataProcessed->susVecTot.value, omegaParallel,
                                      fusedRotRateParallel, 3);

  // calculate rotation orthogonal to the sun: (d(sus)/dt x sus) / |sus|^2
  double fusedRotRateOrthogonal[3] = {0, 0, 0};
  VectorOperations<double>::cross(susDataProcessed->susVecTotDerivative.value,
                                  susDataProcessed->susVecTot.value, fusedRotRateOrthogonal);
  const double susNorm = VectorOperations<double>::norm(susDataProcessed->susVecTot.value, 3);
  VectorOperations<double>::mulScalar(fusedRotRateOrthogonal, 1.0 / (susNorm * susNorm),
                                      fusedRotRateOrthogonal, 3);

  // calculate total rotation rate
  double fusedRotRateTotal[3] = {0, 0, 0};
  VectorOperations<double>::add(fusedRotRateParallel, fusedRotRateOrthogonal, fusedRotRateTotal);

  {
    PoolReadGuard pg(fusedRotRateData);
    std::memcpy(fusedRotRateData->rotRateOrthogonal.value, fusedRotRateOrthogonal,
                3 * sizeof(double));
    std::memcpy(fusedRotRateData->rotRateParallel.value, fusedRotRateParallel, 3 * sizeof(double));
    std::memcpy(fusedRotRateData->rotRateTotal.value, fusedRotRateTotal, 3 * sizeof(double));
    fusedRotRateData->setValidity(true, true);
  }

  storeGyrRateForNextCycle();
}

/**
 * @brief Propagates the last fused rotation rate estimate with the gyroscope measurements.
 *
 * The new total rate is the previous total rate plus the change of the gyro rate relative to
 * rotRateOldB. Only the total rate is produced; the parallel and orthogonal components are zeroed
 * and marked invalid.
 *
 * All outputs are zeroed and invalidated instead when any of these holds:
 *  - propagation is disabled by onBoardParams.fusedRateSafeDuringEclipse;
 *  - the gyro rate is invalid;
 *  - there is no previous estimate (the total rate is invalid or zero).
 *
 * NOTE(review): rotRateOldB is only updated by callers when the gyro rate was valid, so after a
 * cycle with invalid gyro data the increment may span more than one cycle — confirm acceptable.
 *
 * @param gyrDataProcessed  Processed gyroscope data.
 * @param fusedRotRateData  Dataset holding the previous estimate; overwritten with the result.
 */
void FusedRotationEstimation::estimateFusedRotationRateEclipse(
    acsctrl::GyrDataProcessed *gyrDataProcessed, acsctrl::FusedRotRateData *fusedRotRateData) {
  // Check, read, modify and commit the estimate under a single guard so the previous estimate
  // used for the propagation is the one read from the pool in this cycle.
  PoolReadGuard pg(fusedRotRateData);

  const bool hasPreviousEstimate =
      fusedRotRateData->rotRateTotal.isValid() and
      VectorOperations<double>::norm(fusedRotRateData->rotRateTotal.value, 3) != 0;
  if (not acsParameters->onBoardParams.fusedRateSafeDuringEclipse or
      not gyrDataProcessed->gyrVecTot.isValid() or not hasPreviousEstimate) {
    std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC, 3 * sizeof(double));
    std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC, 3 * sizeof(double));
    std::memcpy(fusedRotRateData->rotRateTotal.value, ZERO_VEC, 3 * sizeof(double));
    fusedRotRateData->setValidity(false, true);
    return;
  }

  // Change of the gyro rate since the stored gyro rate (a rate increment, not divided by time).
  double rotRateIncrementB[3] = {0, 0, 0};
  VectorOperations<double>::subtract(gyrDataProcessed->gyrVecTot.value, rotRateOldB,
                                     rotRateIncrementB, 3);
  double fusedRotRateTotal[3] = {0, 0, 0};
  VectorOperations<double>::add(fusedRotRateData->rotRateTotal.value, rotRateIncrementB,
                                fusedRotRateTotal, 3);

  std::memcpy(fusedRotRateData->rotRateOrthogonal.value, ZERO_VEC, 3 * sizeof(double));
  fusedRotRateData->rotRateOrthogonal.setValid(false);
  std::memcpy(fusedRotRateData->rotRateParallel.value, ZERO_VEC, 3 * sizeof(double));
  fusedRotRateData->rotRateParallel.setValid(false);
  std::memcpy(fusedRotRateData->rotRateTotal.value, fusedRotRateTotal, 3 * sizeof(double));
  fusedRotRateData->rotRateTotal.setValid(true);
}