#include <RcppArmadillo.h>
#include <cmath>
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;
using namespace arma;

// [[Rcpp::export]]
double log_sum_cpp(arma::mat log_vec) {
  // Numerically stable log(sum(exp(x))) over ALL elements of `log_vec`
  // ("log-sum-exp"): the largest element is factored out before exponentiating
  // so that exp() cannot overflow.
  //
  // Edge cases:
  //   * empty input        -> -Inf  (log of an empty sum, i.e. log(0))
  //   * all elements -Inf  -> -Inf
  //   * any element +Inf   -> +Inf
  //   * any other NaN result (e.g. NaN in the input) is mapped to -Inf.

  if (log_vec.is_empty()) {
    return -arma::datum::inf;
  }

  const double mmax = log_vec.max();

  // Shifting by an infinite maximum would compute Inf - Inf = NaN; in that case
  // the log-sum-exp is the maximum itself.
  if (std::isinf(mmax)) {
    return mmax;
  }

  double res = std::log(arma::accu(arma::exp(log_vec - mmax))) + mmax;

  if (std::isnan(res)) {
    res = -arma::datum::inf;
  }

  return res;
}

// [[Rcpp::export]]
arma::vec log_rowSums_cpp(arma::mat log_mat) {
  // Row-wise log-sum-exp: element i of the result is
  //   log(sum_j exp(log_mat(i, j))),
  // computed stably by subtracting each row's maximum before exponentiating.

  arma::colvec shift = arma::max(log_mat, 1);

  // A row whose maximum is +/-Inf would give Inf - Inf = NaN after shifting.
  // A zero shift for such rows yields the exact answer instead:
  // all -Inf -> -Inf, any +Inf -> +Inf.
  for (double& s : shift) {
    if (!std::isfinite(s)) {
      s = 0.0;
    }
  }

  arma::vec res = arma::log(arma::sum(arma::exp(log_mat.each_col() - shift), 1)) + shift;

  // Remaining NaNs (only possible when the input itself contains NaN) are
  // mapped to -Inf.
  for (double& r : res) {
    if (std::isnan(r)) {
      r = -arma::datum::inf;
    }
  }

  return res;
}


// [[Rcpp::export]]
arma::vec log_colSums_cpp(const arma::mat& log_mat) {
  // Column-wise log-sum-exp, returned as a column vector: element j is
  //   log(sum_i exp(log_mat(i, j))),
  // computed stably by subtracting each column's maximum before exponentiating.

  arma::rowvec shift = arma::max(log_mat, 0);

  // A column whose maximum is +/-Inf would give Inf - Inf = NaN after shifting.
  // A zero shift for such columns yields the exact answer instead:
  // all -Inf -> -Inf, any +Inf -> +Inf.
  for (double& s : shift) {
    if (!std::isfinite(s)) {
      s = 0.0;
    }
  }

  arma::vec res = arma::log(arma::sum(arma::exp(log_mat.each_row() - shift), 0).t()) + shift.t();

  // Remaining NaNs (only possible when the input itself contains NaN) are
  // mapped to -Inf.
  for (double& r : res) {
    if (std::isnan(r)) {
      r = -arma::datum::inf;
    }
  }

  return res;
}

// [[Rcpp::export]]
arma::mat log_forward_algorithm_cpp(arma::vec log_ppi, arma::mat logA, arma::mat logB, arma::vec xlabeled) {
  // Forward pass of a hidden Markov model, carried out in log space.
  //
  // log_ppi  : log initial-state probabilities, one entry per state.
  // logA     : log transition matrix; logA(i, j) is the log weight of moving from
  //            state i at time t - 1 to state j at time t (see the recursion).
  // logB     : num_obs x num_states matrix of log emission densities.
  // xlabeled : one entry per observation; NA means unlabeled, otherwise the
  //            1-based state the observation is known to occupy. Every other
  //            state at a labelled time point is forced to -Inf.
  //
  // Returns log_alpha (num_obs x num_states).
  const int num_states = log_ppi.size();
  const int num_obs = logB.n_rows;

  arma::mat log_alpha(num_obs, num_states, fill::zeros);
  if (num_obs == 0) {
    return log_alpha;
  }

  // Every time point reads xlabeled(t); refuse to read past its end.
  if (static_cast<int>(xlabeled.n_elem) < num_obs) {
    Rcpp::stop("log_forward_algorithm_cpp: 'xlabeled' must have one entry per row of 'logB'");
  }

  // A labelled first observation pins the initial state: log-probability 0 for
  // the labelled state, -Inf for all others.
  if (!NumericVector::is_na(xlabeled(0))) {
    for (int i = 0; i < num_states; ++i) {
      log_ppi(i) = ((i + 1) == xlabeled(0)) ? 0.0 : -arma::datum::inf;
    }
  }

  // Initialization (t = 1)
  log_alpha.row(0) = logB.row(0) + log_ppi.t();

  // Recursion (t = 2 to num_obs):
  //   log_alpha(t, j) = logB(t, j) + log sum_i exp(logA(i, j) + log_alpha(t - 1, i))
  for (int t = 1; t < num_obs; ++t) {
    log_alpha.row(t) = logB.row(t) + log_colSums_cpp(logA.each_col() + log_alpha.row(t - 1).t()).t();
    if (!NumericVector::is_na(xlabeled(t))) {
      for (int i = 0; i < num_states; ++i) {
        if ((i + 1) != xlabeled(t)) {
          log_alpha(t, i) = -arma::datum::inf;
        }
      }
    }
  }

  return log_alpha;
}




// [[Rcpp::export]]
arma::mat log_backward_algorithm_cpp(arma::mat logA, arma::mat logB, arma::vec xlabeled) {
  // Backward pass of a hidden Markov model, carried out in log space.
  //
  // logA     : log transition matrix; logA(i, j) is the log weight of moving from
  //            state i at time t to state j at time t + 1.
  // logB     : num_obs x num_states matrix of log emission densities.
  // xlabeled : one entry per observation; NA means unlabeled, otherwise the
  //            1-based state the observation is known to occupy.
  //
  // Returns log_beta (num_obs x num_states) with
  //   log_beta(t, i) = log sum_j exp(logA(i, j) + logB(t + 1, j) + log_beta(t + 1, j)),
  // and, at every labelled time point (including the last one), -Inf for all
  // states other than the labelled one.
  const int num_obs = logB.n_rows;
  const int num_states = logB.n_cols;
  arma::mat log_beta(num_obs, num_states, fill::zeros);

  if (num_obs == 0) {
    return log_beta;
  }

  // Every time point reads xlabeled(t); refuse to read past its end.
  if (static_cast<int>(xlabeled.n_elem) < num_obs) {
    Rcpp::stop("log_backward_algorithm_cpp: 'xlabeled' must have one entry per row of 'logB'");
  }

  // Forces every state other than the labelled one at time t to -Inf.
  auto mask_to_label = [&](int t) {
    if (NumericVector::is_na(xlabeled(t))) {
      return;
    }
    for (int i = 0; i < num_states; ++i) {
      if ((i + 1) != xlabeled(t)) {
        log_beta(t, i) = -arma::datum::inf;
      }
    }
  };

  // Terminal condition: log_beta(T, .) = 0, restricted to the labelled state.
  // Without this, the label of the last observation never reaches beta (and hence
  // gamma and xi at earlier times), unlike the forward pass, which enforces it.
  mask_to_label(num_obs - 1);

  // Recursion (t = num_obs - 1 to 1)
  for (int t = num_obs - 2; t >= 0; --t) {
    log_beta.row(t) = log_rowSums_cpp(logA.each_row() + (log_beta.row(t + 1) + logB.row(t + 1))).t();
    mask_to_label(t);
  }

  return log_beta;
}



// [[Rcpp::export]]
arma::mat log_gamma_cpp(arma::mat log_alpha, arma::mat log_beta) {
  // Per-time-point posterior state log-probabilities:
  //   log_gamma(t, i) = log_alpha(t, i) + log_beta(t, i) - log sum_k exp(log_alpha(t, k) + log_beta(t, k)),
  // so that each row of exp(log_gamma) sums to one.
  arma::mat unnormalised = log_alpha + log_beta;
  const arma::vec log_norm = log_rowSums_cpp(unnormalised);
  unnormalised.each_col() -= log_norm;
  return unnormalised;
}


// [[Rcpp::export]]
arma::cube log_xi_cpp(arma::mat log_alpha, arma::mat logAhat, arma::mat log_beta, arma::mat logB) {
  // Log joint posterior of consecutive states. For each time step t in 0 .. TT - 2:
  //   log_xi(t, i, j) = log_alpha(t, i) + logAhat(i, j) + log_beta(t + 1, j) + logB(t + 1, j)
  // minus the log-sum-exp of all (i, j) entries at that t, so that the entries of
  // exp(log_xi) at each time step sum to one.
  // Result dimensions: (TT - 1) x nstates x nstates.
  const int TT = log_alpha.n_rows;
  const int nstates = log_alpha.n_cols;
  arma::cube log_xi(TT - 1, nstates, nstates, fill::zeros);

  int step = 0;
  while (step < TT - 1) {
    for (int from = 0; from < nstates; ++from) {
      const double head = log_alpha(step, from);
      for (int to = 0; to < nstates; ++to) {
        log_xi(step, from, to) = head + logAhat(from, to) + log_beta(step + 1, to) + logB(step + 1, to);
      }
    }

    // Normalise the (from, to) slab belonging to this time step.
    const double log_norm = log_sum_cpp(log_xi(span(step), span::all, span::all));
    log_xi(span(step), span::all, span::all) -= log_norm;

    ++step;
  }

  return log_xi;
}


// [[Rcpp::export]]
List expectation_step(arma::vec log_ppi, arma::mat logB, arma::mat logA, arma::vec xlabeled) {
  // E-step: runs the forward and backward passes and derives the state and
  // transition posteriors from them, all in log space.
  // Returns a named List with elements "log_alpha", "log_beta", "log_gamma" and
  // "log_xi", in that order.
  const arma::mat fwd = log_forward_algorithm_cpp(log_ppi, logA, logB, xlabeled);
  const arma::mat bwd = log_backward_algorithm_cpp(logA, logB, xlabeled);
  const arma::mat state_post = log_gamma_cpp(fwd, bwd);
  const arma::cube pair_post = log_xi_cpp(fwd, logA, bwd, logB);

  List out;
  out["log_alpha"] = fwd;
  out["log_beta"] = bwd;
  out["log_gamma"] = state_post;
  out["log_xi"] = pair_post;
  return out;
}


// [[Rcpp::export]]
List maximization_step(arma::mat y, arma::mat log_gamma, arma::cube log_xi) {
  // M-step of Baum-Welch for an HMM with multivariate normal emissions.
  //
  // y         : observations, one per row.
  // log_gamma : per-time-point posterior state log-probabilities (one column per state).
  // log_xi    : log posterior of consecutive state pairs, indexed (t, from, to).
  //
  // Returns a named List:
  //   "logAhat"         - row-normalised log transition matrix,
  //   "means_hat"       - list of posterior-weighted means, one per state,
  //   "covariances_hat" - list of posterior-weighted covariances, one per state,
  //   "log_pi_hat"      - first row of log_gamma, as a column vector.
  const int K = log_gamma.n_cols;

  // ---- Transition matrix ----
  const arma::vec log_occupancy = log_colSums_cpp(log_gamma);
  arma::mat logAhat(K, K, fill::zeros);
  for (int from = 0; from < K; ++from) {
    for (int to = 0; to < K; ++to) {
      const arma::vec pair_series = log_xi(span::all, span(from), span(to));
      logAhat(from, to) = log_sum_cpp(pair_series) - log_occupancy(from);
    }
  }
  // Row-normalise so each row of exp(logAhat) sums to one. (This also removes the
  // per-row log_occupancy offset subtracted above.)
  const arma::vec row_norm = log_rowSums_cpp(logAhat);
  logAhat.each_col() -= row_norm;

  // ---- Emission parameters ----
  const arma::mat weights = arma::exp(log_gamma);
  const arma::rowvec total_weight = arma::sum(weights, 0);

  List means_hat(K);
  List covariances_hat(K);
  for (int k = 0; k < K; ++k) {
    const arma::colvec w = weights.col(k);
    const arma::vec mu = (arma::sum(y.each_col() % w, 0) / total_weight[k]).t();
    const arma::mat centred = y.each_row() - mu.t();
    means_hat[k] = mu;
    covariances_hat[k] = (centred.t() * (centred.each_col() % w)) / total_weight[k];
  }

  // ---- Initial distribution ----
  const arma::vec initial_log_probs = log_gamma.row(0).t();

  List result;
  result["logAhat"] = logAhat;
  result["means_hat"] = means_hat;
  result["covariances_hat"] = covariances_hat;
  result["log_pi_hat"] = initial_log_probs;
  return result;
}



// [[Rcpp::export]]
arma::vec log_multivariate_normal_density(const arma::mat& data, const arma::vec& mean, const arma::mat& covariance) {
  // Log-density of N(mean, covariance), evaluated at every row of `data`:
  //   -d/2 log(2 pi) - 1/2 log|covariance| - 1/2 (x - mean)' covariance^{-1} (x - mean)
  // data: n x d (one observation per row); mean: length d; covariance: d x d.
  // Returns a vector of length n.
  const double num_dimensions = static_cast<double>(data.n_cols);

  const arma::mat cov_inv = arma::inv(covariance);

  // log|covariance| via log_det() instead of log(det()): det() underflows to 0 or
  // overflows for moderately high-dimensional or badly scaled covariances, which
  // would turn every density into +/-Inf.
  double log_det_cov = 0.0;
  double det_sign = 0.0;
  arma::log_det(log_det_cov, det_sign, covariance);
  if (det_sign < 0) {
    // Negative determinant: not a valid covariance; log(det) would be NaN.
    log_det_cov = arma::datum::nan;
  }

  const double constant_term =
      -0.5 * num_dimensions * std::log(2.0 * arma::datum::pi) - 0.5 * log_det_cov;

  // Squared Mahalanobis distance of all rows at once: rowSums((X C^{-1}) % X).
  const arma::mat centered = data.each_row() - mean.t();
  const arma::vec mahalanobis = arma::sum((centered * cov_inv) % centered, 1);

  return constant_term - 0.5 * mahalanobis;
}




// [[Rcpp::export]]
List baum_welch(arma::mat y,
                arma::vec xlabeled,
                int nstates,
                arma::vec ppi_start,
                arma::mat A_start,
                List mean_start,
                List covariances_start,
                int max_iter = 200,
                double tol = 1e-6,
                bool silent = true) {
  // Baum-Welch (EM) estimation for an HMM with multivariate normal emissions and
  // optional partial state labels.
  //
  // y                  : observations, one per row.
  // xlabeled           : per-observation label; NA = unlabeled, else the 1-based state.
  // nstates            : number of hidden states.
  // ppi_start, A_start : starting initial-state / transition probabilities on the
  //                      natural scale (their logs are taken here).
  // mean_start, covariances_start : lists with one starting mean vector and one
  //                      covariance matrix per state.
  // max_iter, tol      : stop after max_iter iterations, or once the absolute change
  //                      in log-likelihood between iterations is <= tol.
  // silent             : if false, print the log-likelihood at every iteration.
  //
  // Returned list:
  //   "Estep", "Mstep"  - results of the last E- and M-step (empty lists if no
  //                       iteration ran),
  //   "logB"            - emission log-densities used by the last E-step,
  //   "log_lik"         - the starting sentinel -1e30 (kept for compatibility),
  //   "log_lik_vec"     - the sentinel followed by one log-likelihood per iteration,
  //                       each computed from the parameters used in that E-step,
  //   "n_iterations", "new_logLik" (last log-likelihood), "loglik_change".

  // Initialize
  arma::vec log_ppi_hat = arma::log(ppi_start);
  arma::mat logAhat = arma::log(A_start);
  List means_hat(nstates);
  List covariances_hat(nstates);

  for (int ii = 0; ii < nstates; ii++) {
    arma::vec mean_ii = mean_start[ii];
    arma::mat cov_ii = covariances_start[ii];
    means_hat[ii] = mean_ii;
    covariances_hat[ii] = cov_ii;
  }

  const int num_obs = y.n_rows;
  // Zero-initialised so a call with max_iter <= 0 returns defined values.
  arma::mat logB(num_obs, nstates, fill::zeros);

  List Estep;
  List Mstep;

  const double log_lik = -1e30;  // sentinel "previous" log-likelihood
  double new_logLik = 0;
  double loglik_change = 1e5;

  // Trace of log-likelihoods: the sentinel plus at most max_iter values.
  // It is preallocated and trimmed at the end rather than regrown every iteration.
  const int max_records = (max_iter > 0 ? max_iter : 0) + 1;
  arma::vec log_lik_vec(max_records, fill::zeros);
  log_lik_vec(0) = log_lik;

  int iter = 0;
  while (iter < max_iter && std::abs(loglik_change) > tol) {
    iter = iter + 1;
    for (int ii = 0; ii < nstates; ii++) {
      logB.col(ii) = log_multivariate_normal_density(y, means_hat(ii), covariances_hat(ii));
    }

    Estep = expectation_step(log_ppi_hat, logB, logAhat, xlabeled);
    Mstep = maximization_step(y, Estep["log_gamma"], Estep["log_xi"]);

    arma::vec updated_log_ppi_hat = Mstep["log_pi_hat"];
    log_ppi_hat = updated_log_ppi_hat;
    List updated_means_hat = Mstep["means_hat"];
    means_hat = updated_means_hat;
    List updated_covariances_hat = Mstep["covariances_hat"];
    covariances_hat = updated_covariances_hat;
    arma::mat updated_logAhat = Mstep["logAhat"];
    logAhat = updated_logAhat;

    // Log-likelihood of the parameters used in this iteration's E-step.
    arma::mat log_alpha = Estep["log_alpha"];
    new_logLik = log_sum_cpp(log_alpha.row(num_obs - 1));
    if (!silent) {
      Rprintf("log-likelihood value: %f\n", new_logLik);
    }

    log_lik_vec(iter) = new_logLik;
    loglik_change = log_lik_vec(iter) - log_lik_vec(iter - 1);
  }

  // Drop the unused preallocated slots.
  log_lik_vec.resize(iter + 1);

  List result;

  result["Estep"] = Estep;
  result["Mstep"] = Mstep;
  result["logB"] = logB;
  result["log_lik"] = log_lik;
  result["log_lik_vec"] = log_lik_vec;
  result["n_iterations"] = iter;
  result["new_logLik"] = new_logLik;
  result["loglik_change"] = loglik_change;

  return(result);
}


// [[Rcpp::export]]
int sample_multinomial(NumericVector p) {
  // Draws one category with probability proportional to the weights `p` and
  // returns its 1-based index. Uses a single uniform draw from arma::randu()
  // and an inverse-CDF search over the cumulative weights.
  //
  // The uniform draw is scaled by the total weight, so rounding in sum(p)
  // (e.g. 0.9999999999 instead of 1) cannot push the draw past the last
  // category. n + 1 is returned only if no cumulative weight compares >= the
  // draw (e.g. when the weights contain NaN).
  const int n = p.size();
  if (n == 0) {
    Rcpp::stop("sample_multinomial: 'p' must have at least one element");
  }

  NumericVector cumsum_p(n);
  cumsum_p[0] = p[0];
  for (int i = 1; i < n; ++i) {
    cumsum_p[i] = cumsum_p[i - 1] + p[i];
  }

  const double U = arma::randu() * cumsum_p[n - 1];
  for (int i = 0; i < n; ++i) {
    if (U <= cumsum_p[i]) {
      return i + 1; // Return the index (1-based)
    }
  }
  return n + 1; // Only reachable with non-comparable (NaN) weights
}


// [[Rcpp::export]]
IntegerVector simulate_state_sequence(NumericVector pp, int length_sequence, NumericMatrix A) {
  // Simulates a Markov chain of `length_sequence` states, labelled 1-based.
  // pp : initial-state probabilities.
  // A  : transition matrix; row i holds the probabilities of moving from
  //      state i + 1 to each state.
  // Returns an empty vector when length_sequence <= 0.
  if (length_sequence <= 0) {
    return IntegerVector(0);
  }

  IntegerVector simulated_state_sequence(length_sequence);
  simulated_state_sequence[0] = sample_multinomial(pp);

  for (int tttt = 1; tttt < length_sequence; ++tttt) {
    const int prev_state = simulated_state_sequence[tttt - 1] - 1;
    // sample_multinomial() can return its out-of-range sentinel n + 1; using
    // that as a row index would read past the end of A.
    if (prev_state < 0 || prev_state >= A.nrow()) {
      Rcpp::stop("simulate_state_sequence: sampled state is outside the rows of 'A'");
    }
    NumericVector state_probs = A.row(prev_state);
    simulated_state_sequence[tttt] = sample_multinomial(state_probs);
  }

  return simulated_state_sequence;
}



// [[Rcpp::export]]
double min_ewma_cpp(const arma::mat& y,
                    const std::vector<arma::rowvec>& mean_list,
                    const std::vector<arma::mat>& cov_list,
                    int TT, int p, double lambda, int ww_end) {
  // Multivariate EWMA statistic, minimised over states.
  //
  // For every state jj (mean mean_list[jj], covariance cov_list[jj]), rows
  // 0 .. ww_end - 1 of y are centred and smoothed:
  //   E_0 = lambda * (y_0 - mu),   E_i = lambda * (y_i - mu) + (1 - lambda) * E_{i-1},
  // and V_jj = E * (Sigma_jj * lambda / (2 - lambda))^{-1} * E' is formed from the
  // last smoothed vector. Returns min over jj of V_jj.
  //
  // TT and p are the row / column counts the caller expects; they are validated.
  // Only the running EWMA vector is kept, so no TT x p buffer is allocated per state.
  const int nstates = static_cast<int>(mean_list.size());
  if (nstates == 0) {
    Rcpp::stop("min_ewma_cpp: 'mean_list' must not be empty");
  }
  if (cov_list.size() < mean_list.size()) {
    Rcpp::stop("min_ewma_cpp: 'cov_list' needs one covariance matrix per state");
  }
  if (ww_end < 1 || ww_end > TT || ww_end > static_cast<int>(y.n_rows)) {
    Rcpp::stop("min_ewma_cpp: 'ww_end' must be between 1 and min(TT, nrow(y))");
  }
  if (p != static_cast<int>(y.n_cols)) {
    Rcpp::stop("min_ewma_cpp: 'p' must equal ncol(y)");
  }

  arma::vec Vn_ww(nstates, fill::zeros); // statistic for each state

  for (int jj = 0; jj < nstates; jj++) {
    const arma::rowvec& mu = mean_list[jj];

    // Exponentially weighted moving average of the centred observations.
    arma::rowvec En = lambda * (y.row(0) - mu);
    for (int ii = 1; ii < ww_end; ii++) {
      const arma::rowvec ycen = y.row(ii) - mu;
      En = lambda * ycen + (1 - lambda) * En;
    }

    const arma::mat cov_adjusted = cov_list[jj] * lambda / (2 - lambda);
    // Quadratic form E S^{-1} E' via a linear solve; this is cheaper and generally
    // more accurate than forming the explicit inverse.
    Vn_ww(jj) = arma::as_scalar(En * arma::solve(cov_adjusted, En.t()));
  }

  return Vn_ww.min();
}
