#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

#include "Rcpp.h"

/// k-d tree for exact k-nearest-neighbour queries under squared Euclidean
/// distance.
///
/// @tparam INDEX          integer type used for point indices and counts
/// @tparam REAL           floating-point coordinate type
/// @tparam MAX_LEAF_SIZE  a range of at most this many points is stored as a
///                        leaf and scanned linearly
///
/// The tree does not copy coordinates: it keeps pointers into the buffer
/// handed to the constructor, so that buffer must outlive the tree.
///
/// Ownership of the nodes is expressed with std::unique_ptr, which makes the
/// class move-only (copying used to double-delete the tree) and releases a
/// partially built tree if an allocation fails during construction.
///
/// search() keeps the state of the running query in data members, so it must
/// not be called concurrently on the same object.
template <typename INDEX, typename REAL, INDEX MAX_LEAF_SIZE = 10>
class KDTree {
  // With a leaf capacity of 0 a one-point range could never become a leaf and
  // grow() would recurse forever.
  static_assert(MAX_LEAF_SIZE >= 1, "MAX_LEAF_SIZE must be at least 1");

private:
  struct node {
    // Leaf: points index[first], ..., index[last - 1].
    // Inner node: splits coordinate `var` at value `val`.
    INDEX first{}, last{}, var{};
    REAL val{};
    // Both children are set for an inner node, both are null for a leaf.
    std::unique_ptr<node> left, right;
  };

  std::unique_ptr<node> root;
  INDEX p;                    // number of coordinates per point
  INDEX k{}, found{};         // requested / currently held neighbour count
  INDEX *nn = nullptr;        // caller buffer: neighbour indices
  REAL worst{};               // current pruning radius (squared distance)
  REAL *target = nullptr;     // current query point
  REAL *dd = nullptr;         // caller buffer: squared distances, ascending
  std::vector<INDEX> index;   // permutation of 0..n-1 arranged by grow()
  std::vector<REAL *> obs;    // obs[i] points at the p coordinates of point i

public:
  /// Builds the tree.
  ///
  /// @param p_     number of coordinates per point
  /// @param n_     number of points
  /// @param data_  n_ * p_ values; point i occupies data_[i*p_ .. i*p_+p_-1].
  ///               The buffer is referenced, not copied.
  KDTree(INDEX p_, INDEX n_, REAL *data_) : p(p_) {
    index.resize(n_);
    obs.resize(n_);
    for (INDEX i = 0; i < n_; i++, data_ += p) {
      index[i] = i;
      obs[i] = data_;
    }
    root = grow(0, n_, 0);
  }

  /// Finds the k_ points closest to target_.
  ///
  /// @param target_  query point (p coordinates)
  /// @param k_       number of neighbours requested
  /// @param nn_      output, room for k_ entries: 0-based point indices,
  ///                 nearest first
  /// @param dd_      output, room for k_ entries: matching squared distances
  /// @return number of neighbours written. This is k_ unless the tree holds
  ///         fewer than k_ points; in that case only the leading entries of
  ///         nn_ / dd_ are meaningful and dd_[k_ - 1] holds the sentinel
  ///         std::numeric_limits<REAL>::max(). For k_ <= 0 nothing is written
  ///         and 0 is returned.
  INDEX search(REAL *target_, INDEX k_, INDEX *nn_, REAL *dd_) {
    found = 0;
    if (k_ <= 0)
      return 0; // dd_[k_ - 1] below would be out of bounds
    target = target_;
    k = k_;
    nn = nn_;
    dd = dd_;
    // While fewer than k neighbours are held, the last slot keeps this
    // sentinel so that every candidate is accepted.
    worst = dd[k - 1] = std::numeric_limits<REAL>::max();
    descend(root.get());
    return found;
  }

private:
  /// Recursively builds the subtree for index[lower .. upper-1], splitting on
  /// coordinate `var` at the median and cycling through the coordinates.
  std::unique_ptr<node> grow(INDEX lower, INDEX upper, INDEX var) {
    std::unique_ptr<node> n(new node);
    // p <= 0 leaves no coordinate to split on (and `% p` would divide by
    // zero), so such a tree is a single leaf.
    if (p <= 0 || upper - lower <= MAX_LEAF_SIZE) {
      n->first = lower;
      n->last = upper;
    } else {
      // Written this way so the sum cannot overflow INDEX.
      const INDEX mid = lower + (upper - lower) / 2;
      std::nth_element(
          index.begin() + lower, index.begin() + mid, index.begin() + upper,
          [this, var](INDEX i, INDEX j) { return obs[i][var] < obs[j][var]; });
      n->var = var;
      n->val = obs[index[mid]][var];
      const INDEX next = (var + 1) % p;
      n->left = grow(lower, mid, next);
      n->right = grow(mid, upper, next);
    }
    return n;
  }

  /// Visits the subtree under `n`, descending first into the side of the
  /// split that contains the query and into the other side only if the
  /// splitting plane is closer than the current worst neighbour.
  void descend(const node *n) {
    if (n->left) {
      const REAL dv = target[n->var] - n->val;
      const node *closer = dv < 0 ? n->left.get() : n->right.get();
      const node *other = dv < 0 ? n->right.get() : n->left.get();
      descend(closer);
      if (dv * dv < worst)
        descend(other);
    } else {
      for (INDEX l = n->first; l < n->last; l++) {
        const INDEX i = index[l];
        const REAL *y = obs[i];
        REAL d = 0;
        for (INDEX j = 0; j < p; j++) {
          const REAL diff = target[j] - y[j];
          d += diff * diff;
        }
        if (d < worst)
          add(i, d);
      }
    }
  }

  /// Inserts point `i` at squared distance `d` into the sorted result arrays
  /// (insertion sort step), dropping the current worst once k are held.
  void add(INDEX i, REAL d) {
    if (found < k)
      found++;
    INDEX r = found - 1;
    for (; r > 0 && d < dd[r - 1]; r--) {
      dd[r] = dd[r - 1];
      nn[r] = nn[r - 1];
    }
    dd[r] = d;
    nn[r] = i;
    worst = dd[k - 1];
  }
};

// Returns the 1-based column indices of the k columns of `x` that are closest
// (Euclidean distance) to the origin, nearest first.
//
// x: p-by-n matrix whose columns are the observations.
// k: number of neighbours, 1 <= k <= n.
//
// Raises an R error for an empty `x` or a `k` outside that range; such input
// previously wrote outside the result buffers (k <= 0) or returned unfilled
// slots as index 1 (k > n).
// [[Rcpp::export]]
Rcpp::IntegerVector knnsearch0(Rcpp::NumericMatrix x, int k) {
  const int p = x.nrow(), n = x.ncol();
  if (p < 1 || n < 1)
    Rcpp::stop("x must have at least one row and one column");
  if (k < 1 || k > n)
    Rcpp::stop("k must be between 1 and the number of columns of x");

  Rcpp::IntegerVector nn(k);
  // Rcpp zero-fills a vector created from a length, so `origin` is the
  // p-dimensional zero vector used as the query point.
  Rcpp::NumericVector origin(p), dd(k);
  KDTree<int, double> kd(p, n, x.begin());
  kd.search(origin.begin(), k, nn.begin(), dd.begin());
  return nn + 1; // KDTree indices are 0-based, R indices are 1-based
}

// k-nearest-neighbour regression with distance-based weights.
//
// x:    p-by-n matrix, one training observation per column.
// y:    q-by-n matrix, the response belonging to each column of x.
// xnew: p-by-m matrix, one query point per column.
// k:    number of neighbours; k <= 0 selects ceil(sqrt(n)).
//
// Returns a q-by-m matrix whose column i is the weighted mean of the
// responses of the k training points nearest to column i of xnew. With d the
// squared distance to a neighbour and dmax the squared distance to the k-th
// neighbour (floored at machine epsilon), the weight is
// (1 - (d/dmax)^(3/2))^3.
//
// When every weight is zero -- all k neighbours at the same distance, which
// is always the case for k == 1 -- the unweighted mean of the k responses is
// returned; the previous code divided 0 by 0 there and produced NaN.
//
// Raises an R error if the dimensions of x, y and xnew do not match, if x is
// empty, or if k exceeds n.
// [[Rcpp::export]]
Rcpp::NumericMatrix knnreg(Rcpp::NumericMatrix x, Rcpp::NumericMatrix y,
                           Rcpp::NumericMatrix xnew, int k = 0) {
  const int p = x.nrow(), n = x.ncol(), q = y.nrow(), m = xnew.ncol();
  if (p < 1 || n < 1)
    Rcpp::stop("x must have at least one row and one column");
  if (y.ncol() != n)
    Rcpp::stop("x and y must have the same number of columns");
  if (xnew.nrow() != p)
    Rcpp::stop("x and xnew must have the same number of rows");
  if (k <= 0)
    k = static_cast<int>(std::ceil(std::sqrt(static_cast<double>(n))));
  if (k > n)
    Rcpp::stop("k must not exceed the number of columns of x");

  KDTree<int, double> kd(p, n, x.begin());
  Rcpp::IntegerVector nnn(k);
  Rcpp::NumericVector ddd(k);
  Rcpp::NumericMatrix yhat(q, m); // zero-filled by Rcpp
  int *nn = nnn.begin();
  const double *ybase = y.begin();
  double *xi = xnew.begin(), *yhi = yhat.begin(), *d = ddd.begin();

  for (int i = 0; i < m; i++, xi += p, yhi += q) {
    kd.search(xi, k, nn, d);
    // d is sorted ascending, so d[k - 1] is the largest squared distance.
    const double dmax =
        std::max(std::numeric_limits<double>::epsilon(), d[k - 1]);

    // Running weighted mean: after each step yhi is the weighted mean of the
    // responses seen so far.
    double sw = 0.0;
    for (int r = 0; r < k; r++) {
      const double u = d[r] / dmax;
      const double w = std::pow(1 - u * std::sqrt(u), 3);
      if (!(w > 0.0))
        continue; // contributes nothing; also avoids 0/0 while sw is 0
      sw += w;
      const double step = w / sw;
      // Offset computed in ptrdiff_t: nn[r] * q may not fit in int.
      const double *yc = ybase + static_cast<std::ptrdiff_t>(nn[r]) * q;
      for (int j = 0; j < q; j++)
        yhi[j] += step * (yc[j] - yhi[j]);
    }

    if (!(sw > 0.0)) {
      // No neighbour received a positive weight, so yhi is still all zero:
      // fall back to the plain average of the k neighbours.
      for (int r = 0; r < k; r++) {
        const double *yc = ybase + static_cast<std::ptrdiff_t>(nn[r]) * q;
        for (int j = 0; j < q; j++)
          yhi[j] += (yc[j] - yhi[j]) / (r + 1);
      }
    }
  }
  return yhat;
}
