# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

#' @title Create hierarchical clusters of selected metrics using a Person query
#'
#' @description
#' `r lifecycle::badge('questioning')`
#'
#' Apply hierarchical clustering to selected metrics. Person averages are
#' computed prior to clustering. The hierarchical clustering uses cosine
#' distance and the ward.D method of agglomeration.
#'
#' @author Ainize Cidoncha <ainize.cidoncha@@microsoft.com>
#'
#' @param data A data frame containing `PersonId` and selected metrics for
#'   clustering.
#' @param metrics Character vector containing names of metrics to use for
#'   clustering. See examples section.
#' @param k Single positive number specifying the `k` number of clusters to
#'   cut the tree into. Defaults to 4.
#' @param return String specifying what to return. This must be one of the
#'   following strings:
#'   - `"plot"`
#'   - `"data"`
#'   - `"table"`
#'   - `"hclust"`
#'
#' See `Value` for more information.
#'
#' @return
#' A different output is returned depending on the value passed to the `return`
#' argument:
#'   - `"plot"`: 'ggplot' object. A heatmap plot comparing the key metric averages
#'   of the clusters as per `keymetrics_scan()`.
#'   - `"data"`: data frame. The rows of `data` with a `cluster` column appended
#'   (each `PersonId` carries a single cluster label).
#'   - `"table"`: data frame. Summary table for identified clusters: the number
#'   (`n`) and proportion (`prop`) of distinct persons per cluster, followed by
#'   the mean of every numeric column (missing values ignored, rounded to one
#'   decimal place).
#'   - `"hclust"`: 'hclust' object. Hierarchical model generated by the function.
#'
#' @import dplyr
#' @import tidyselect
#' @import ggplot2
#' @importFrom proxy dist
#' @importFrom stats hclust
#' @importFrom stats rect.hclust
#' @importFrom stats cutree
#' @importFrom tidyr replace_na
#'
#' @family Clustering
#'
#' @examples
#' \donttest{
#' # Return plot
#' personas_hclust(sq_data,
#'                 metrics = c("Collaboration_hours", "Workweek_span"),
#'                 k = 4)
#'
#' # Return summary table
#'
#' personas_hclust(sq_data,
#'                 metrics = c("Collaboration_hours", "Workweek_span"),
#'                 k = 4,
#'                 return = "table")
#'
#' # Return data with clusters appended
#' personas_hclust(sq_data,
#'                 metrics = c("Collaboration_hours", "Workweek_span"),
#'                 k = 4,
#'                 return = "data")
#' }
#'
#' @export
personas_hclust <- function(data,
                            metrics,
                            k = 4,
                            return = "plot") {

  ## Validate arguments before any (potentially expensive) distance
  ## computation, so that a bad request fails fast with a clear message.
  valid_returns <- c("plot", "data", "table", "hclust")
  if (length(return) != 1L || !(return %in% valid_returns)) {
    stop("Invalid input for `return`.", call. = FALSE)
  }

  if (length(metrics) == 0L) {
    stop("`metrics` must contain at least one metric name.", call. = FALSE)
  }

  ## `cutree()` returns a matrix (one column per value) when `k` has more
  ## than one element, which cannot be bound as a single `cluster` column.
  if (!is.numeric(k) || length(k) != 1L || is.na(k) || k < 1) {
    stop("`k` must be a single positive number.", call. = FALSE)
  }

  ## Person-level averages of the selected metrics: clustering operates on
  ## one row per PersonId. `all_of()` errors clearly on unknown columns.
  data_cluster <-
    data %>%
    select(PersonId, all_of(metrics)) %>%
    group_by(PersonId) %>%
    summarise(across(all_of(metrics), ~ mean(.x, na.rm = TRUE)),
              .groups = "drop")

  ## Cosine distance between person metric vectors, agglomerated with
  ## Ward's criterion (`ward.D`).
  h_clust <-
    data_cluster %>%
    select(all_of(metrics)) %>%
    proxy::dist(method = "cosine") %>%
    stats::hclust(method = "ward.D")

  ## One cluster label per row of `data_cluster`, in the same row order.
  cuts <- stats::cutree(h_clust, k = k)

  ## Attach the labels to persons, then back onto every row of `data`.
  ## `cbind()` on the tibble yields a plain data frame, which `left_join()`
  ## preserves for the returned value.
  data_final <-
    data_cluster %>%
    select(PersonId) %>%
    cbind("cluster" = cuts) %>%
    left_join(data, by = "PersonId")

  ## `return` has been validated above, so every case is covered here.
  switch(
    return,
    data = data_final,
    table = {
      ## Number and share of distinct persons per cluster (consistent with
      ## the unique-person counts shown in the plot caption).
      count_tb <-
        data_final %>%
        group_by(cluster) %>%
        summarise(n = n_distinct(PersonId), .groups = "drop") %>%
        mutate(prop = n / sum(n))

      ## Mean of every numeric column per cluster; the grouping column
      ## itself is excluded by `across()`.
      sums_tb <-
        data_final %>%
        group_by(cluster) %>%
        summarise(
          across(where(is.numeric), ~ round(mean(.x, na.rm = TRUE), 1)),
          .groups = "drop"
        )

      left_join(count_tb, sums_tb, by = "cluster")
    },
    plot = {
      ## Caption string of unique person counts, e.g.
      ## "cluster 1 = 10; cluster 2 = 7"
      count_str <-
        data_final %>%
        hrvar_count(hrvar = "cluster", return = "table") %>%
        arrange(cluster) %>%
        mutate(print_str = paste0("cluster ", cluster, " = ", n)) %>%
        pull(print_str) %>%
        paste(collapse = "; ")

      ## Use keymetrics_scan() to visualize clusters
      data_final %>%
        mutate(cluster = factor(cluster)) %>%
        keymetrics_scan(hrvar = "cluster") +
        labs(title = "Key metrics by personas clusters",
             caption = paste(count_str, "\n",
                             extract_date_range(data, return = "text")))
    },
    hclust = h_clust
  )
}


