// [[Rcpp::depends(Rcpp)]]
#define RCPP_NO_RTTI
#define RCPP_NO_SUGAR

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

#include <Rcpp.h>
using namespace Rcpp;

// Most-likely state path for a model with explicit state durations
// (Viterbi-style dynamic programme followed by a backtrace).
//
// All scores are combined by addition, so the inputs are presumably
// log-probabilities -- NOTE(review): confirm against the R caller.
//
// Inputs (flat buffers, indexed exactly as shown):
//   transitionMatrix     [current * numberOfStates + previous]  score of moving previous -> current
//   initialProbs         [state]                                score of starting in `state`
//   emissionProbs        [state * sequenceLength + time]        emission score
//   durationParams       [state * maxDurations[state] + (d-1)]  score of a run of length d
//   finalDurationParams  same layout as durationParams; used only for the last time step
//   sequenceLength       number of time steps (must be >= 1)
//   numberOfStates       number of states (must be >= 1)
//   maxDurations         [state] longest run length considered for `state` (must be >= 0)
//
// Outputs (written in place through the vectors' storage):
//   forwardProbs         [state * sequenceLength + time]  best score of a run of `state` ending at `time`
//   optimalDurations     [state * sequenceLength + time]  run length achieving that score
//   optimalPrevStates    [state * sequenceLength + time]  best predecessor for a run starting at `time`
//                                                         (entries for time 0 are not written)
//   optimalPath          [time]                           decoded state index, 0-based
//
// Throws std::invalid_argument (surfaced to R as an error by the generated
// wrapper) when a dimension is non-positive, a maximum duration is negative
// or missing, or a buffer is too short for the accesses performed here.
// [[Rcpp::export]]
void viterbi(NumericVector transitionMatrix, NumericVector initialProbs, NumericVector emissionProbs,
             NumericVector durationParams, NumericVector finalDurationParams, int sequenceLength,
             int numberOfStates, IntegerVector maxDurations, NumericVector forwardProbs,
             IntegerVector optimalPath, IntegerVector optimalPrevStates,
             IntegerVector optimalDurations){

  const int timeSteps = sequenceLength;
  const int stateCount = numberOfStates;
  const double MIN_SCORE = -1e300;

  // The backtrace unconditionally writes optimalPath[timeSteps - 1] and reads
  // the tables of state 0, so empty dimensions cannot be processed.
  // (An NA integer from R is INT_MIN and is rejected here as well.)
  if(timeSteps < 1) {
    throw std::invalid_argument("viterbi: 'sequenceLength' must be at least 1");
  }
  if(stateCount < 1) {
    throw std::invalid_argument("viterbi: 'numberOfStates' must be at least 1");
  }

  // Everything below goes through raw pointers, so check sizes up front.
  const auto requireLength = [](std::ptrdiff_t actual, std::ptrdiff_t needed, const char* name) {
    if(actual < needed) {
      throw std::invalid_argument(std::string("viterbi: '") + name +
                                  "' is too short for the given dimensions");
    }
  };

  // Offsets are computed in a wide type so state * timeSteps cannot overflow int.
  const std::ptrdiff_t wideStateCount = stateCount;
  const std::ptrdiff_t wideTimeSteps = timeSteps;
  const std::ptrdiff_t gridSize = wideStateCount * wideTimeSteps;

  requireLength(transitionMatrix.size(), wideStateCount * wideStateCount, "transitionMatrix");
  requireLength(initialProbs.size(), wideStateCount, "initialProbs");
  requireLength(maxDurations.size(), wideStateCount, "maxDurations");
  requireLength(emissionProbs.size(), gridSize, "emissionProbs");
  requireLength(forwardProbs.size(), gridSize, "forwardProbs");
  requireLength(optimalDurations.size(), gridSize, "optimalDurations");
  requireLength(optimalPrevStates.size(), gridSize, "optimalPrevStates");
  requireLength(optimalPath.size(), wideTimeSteps, "optimalPath");

  // stateTransitionScores[s][t]: best score of entering state s at time t.
  // Column 0 is never read (runs starting at time 0 use initialProbs instead).
  std::vector<std::vector<double>> stateTransitionScores(stateCount, std::vector<double>(timeSteps));
  std::vector<const double*> emissions(stateCount), durations(stateCount), finalDurations(stateCount);
  std::vector<double*> pathScores(stateCount);
  std::vector<int*> bestDurations(stateCount), bestPrevStates(stateCount);
  std::vector<int> stateMaxDuration(stateCount);

  for(int state = 0; state < stateCount; ++state) {
    const int maxDuration = maxDurations[state];
    if(maxDuration < 0) {
      throw std::invalid_argument("viterbi: 'maxDurations' must not contain negative or missing values");
    }
    stateMaxDuration[state] = maxDuration;

    const std::ptrdiff_t gridOffset = state * wideTimeSteps;
    const std::ptrdiff_t durationOffset = static_cast<std::ptrdiff_t>(state) * maxDuration;

    // durationParams is read for time < timeSteps - 1 (run lengths up to
    // min(timeSteps - 1, maxDuration)); finalDurationParams only at the last
    // time step (run lengths up to min(timeSteps, maxDuration)).
    requireLength(durationParams.size(),
                  durationOffset + std::min(timeSteps - 1, maxDuration), "durationParams");
    requireLength(finalDurationParams.size(),
                  durationOffset + std::min(timeSteps, maxDuration), "finalDurationParams");

    emissions[state] = REAL(emissionProbs) + gridOffset;
    durations[state] = REAL(durationParams) + durationOffset;
    finalDurations[state] = REAL(finalDurationParams) + durationOffset;
    pathScores[state] = REAL(forwardProbs) + gridOffset;
    bestDurations[state] = INTEGER(optimalDurations) + gridOffset;
    bestPrevStates[state] = INTEGER(optimalPrevStates) + gridOffset;
  }

  const double* transitions = REAL(transitionMatrix);

  for(int time = 0; time < timeSteps; ++time) {
    const bool isLastStep = (time == timeSteps - 1);

    for(int state = 0; state < stateCount; ++state) {
      // The last time step scores run lengths with the "final" duration table;
      // every other step uses the regular one. The recursion is otherwise identical.
      const double* durationTable = isLastStep ? finalDurations[state] : durations[state];
      const int maxPossibleDuration = std::min(time + 1, stateMaxDuration[state]);

      // Sum of emissions for the run's earlier time points (time-1 ... time-duration+1);
      // the emission at `time` itself is added once after the search.
      double emissionSum = 0.0;
      // Stays at MIN_SCORE only when no run length is admissible (maxPossibleDuration == 0);
      // in that case optimalDurations is left untouched for this cell.
      double bestScore = MIN_SCORE;

      for(int duration = 1; duration <= maxPossibleDuration; ++duration) {
        // duration == time + 1 means the run began at time 0, so it is entered
        // via the initial distribution rather than via a transition.
        const bool startsAtOrigin = (duration > time);
        const double entryScore = startsAtOrigin
          ? initialProbs[state]
          : stateTransitionScores[state][time - duration + 1];

        const double currentScore = emissionSum + durationTable[duration - 1] + entryScore;
        if(duration == 1 || bestScore < currentScore) {
          bestScore = currentScore;
          bestDurations[state][time] = duration;
        }

        if(!startsAtOrigin) {
          emissionSum += emissions[state][time - duration];
        }
      }
      pathScores[state][time] = bestScore + emissions[state][time];
    }

    if(!isLastStep) {
      // Best way to enter each state at time + 1, given runs ending at `time`.
      for(int currentState = 0; currentState < stateCount; ++currentState) {
        // NOTE(review): the search is seeded from state 0 even when currentState == 0,
        // whereas all other self-transitions are skipped below. This is harmless if the
        // transition scores exclude self-transitions (e.g. -Inf diagonal) -- confirm.
        double bestEntry = transitions[static_cast<std::ptrdiff_t>(currentState) * stateCount] +
          pathScores[0][time];
        int bestPrev = 0;

        for(int prevState = 1; prevState < stateCount; ++prevState) {
          if(prevState == currentState) {
            continue;
          }
          const double transitionScore =
            transitions[static_cast<std::ptrdiff_t>(currentState) * stateCount + prevState] +
            pathScores[prevState][time];
          // "<=" means ties go to the highest-numbered predecessor.
          if(bestEntry <= transitionScore) {
            bestEntry = transitionScore;
            bestPrev = prevState;
          }
        }

        stateTransitionScores[currentState][time + 1] = bestEntry;
        bestPrevStates[currentState][time + 1] = bestPrev;
      }
    }
  }

  // Backtrace: start from the best-scoring state at the last time step
  // (strict "<" keeps the lowest-numbered state on ties).
  int* statePath = INTEGER(optimalPath);
  const int lastTime = timeSteps - 1;
  statePath[lastTime] = 0;

  for(int state = 1; state < stateCount; ++state) {
    if(pathScores[statePath[lastTime]][lastTime] < pathScores[state][lastTime]) {
      statePath[lastTime] = state;
    }
  }

  // durationCounter counts how many steps of the current run have been emitted;
  // t + durationCounter is therefore the time index at which that run ends.
  int durationCounter = 1;
  for(int t = timeSteps - 2; t >= 0; --t) {
    const int runEnd = t + durationCounter;
    const int runState = statePath[runEnd];

    if(durationCounter < bestDurations[runState][runEnd]) {
      // Still inside the same run: carry the state backwards.
      statePath[t] = runState;
      durationCounter++;
    } else {
      // Run exhausted: jump to the predecessor recorded for the run's first time step.
      statePath[t] = bestPrevStates[runState][runEnd];
      durationCounter = 1;
    }
  }
}

