#' Compute community scores from a fitted MixMashNet model
#'
#' @description
#' Computes subject-level community scores. Community scores are obtained as
#' weighted sums of the variables belonging to each detected community, where
#' weights correspond to the standardized community loadings estimated via
#' \code{EGAnet::net.loads} and stored in the fitted \code{mixMN_fit} object.
#' Only within-community loadings are used: each variable's loadings on
#' communities other than its own (cross-loadings) are set to zero before
#' scoring.
#' Scores are computed using the dataset provided via the \code{data} argument.
#' If \code{data = NULL}, the original dataset used to fit the model
#' (\code{fit$model$data}) is used by default.
#' Optionally, percentile bootstrap quantile regions for the community
#' scores can be computed if bootstrap community loadings are available in
#' \code{fit$community_loadings$boot}.
#' Community scores are only available if community loadings were computed
#' in the fitted model. This requires that all variables in the community
#' subgraph are of MGM type Gaussian (\code{"g"}), Poisson (\code{"p"}), or
#' binary categorical (\code{"c"} with \code{level == 2}).
#'
#' @param fit A fitted object of class \code{mixMN_fit} or
#'   \code{multimixMN_fit}, as returned by \code{mixMN()} or
#'   \code{multimixMN()}.
#' @param data Optional data.frame (or matrix) with variables in columns. If
#'   \code{NULL}, uses \code{fit$model$data}. Errors if both are \code{NULL}.
#' @param layer Optional. If \code{fit} is a \code{multimixMN_fit}, the layer
#'   to score, given by name or by index. If \code{NULL}, scores are computed
#'   for all layers and returned as a named list.
#' @param scale Logical; if \code{TRUE} (default), z-standardize variables used
#'   for scoring, using the mean/SD computed from the dataset used for scoring.
#' @param quantile_level Optional number strictly between 0 and 1, e.g. 0.95
#'   or 0.99. If provided, percentile bootstrap quantile regions are computed
#'   for community scores (requires \code{fit$community_loadings$boot}).
#' @param return_quantile_region Logical; if \code{TRUE}, return quantile
#'   regions. If \code{quantile_level} is \code{NULL}, a level of 0.95 is
#'   used. Supplying \code{quantile_level} implies \code{TRUE}.
#' @param na_action Character. How to handle missing values in the scoring data:
#'   \code{"stop"} (default) stops if any missing value is present in the
#'   required variables; \code{"omit"} computes each community score from the
#'   variables observed for that subject. The weighted sum over the observed
#'   variables is rescaled by (sum of absolute weights of all the community's
#'   variables) / (sum of absolute weights of the observed ones). Subjects with
#'   complete data therefore receive the same score as with \code{"stop"}. A
#'   score is \code{NA} when none of the community's variables is observed.
#'
#' @return If \code{fit} is a \code{mixMN_fit} (or a \code{multimixMN_fit}
#' with \code{layer} specified), a list with class
#' \code{c("community_scores", "mixmashnet")} containing:
#' \describe{
#'   \item{\code{call}}{The matched call.}
#'   \item{\code{settings}}{List with \code{scale}, \code{quantile_level}, and \code{na_action}.}
#'   \item{\code{ids}}{Character vector of subject IDs (row names of \code{data};
#'     \code{"id_1"}, \code{"id_2"}, ... for a matrix without row names).}
#'   \item{\code{communities}}{Character vector of community score names.}
#'   \item{\code{scores}}{Numeric matrix of scores (n x K).}
#'   \item{\code{quantile_region}}{If requested, a list with \code{lower} and
#'     \code{upper} matrices (n x K) for percentile bootstrap quantile regions,
#'     plus the \code{quantile_level} used; otherwise \code{NULL}.}
#'   \item{\code{details}}{List containing \code{nodes_used}, \code{loadings_true}
#'     (the within-community loadings actually used for scoring),
#'     \code{loadings_boot_available}, and scaling parameters (\code{center}, \code{scale}).}
#' }
#' If \code{fit} is a \code{multimixMN_fit} and \code{layer = NULL}, returns a
#' named list of \code{community_scores} objects (one per layer). Layers for
#' which scores cannot be computed are dropped with a warning.
#'
#' @details
#' The function requires that \code{fit$community_loadings$true} exists and that
#' the input \code{data} contains all required variables in
#' \code{fit$community_loadings$nodes}. It errors otherwise.
#' Logical variables are coded 0/1. Binary factors are coded using
#' \code{fit$data_info$binary_recode_map} when it has an entry for the
#' variable; otherwise they are coded as an indicator of the second factor
#' level.
#'
#' @references
#'
#' Christensen, A. P., Golino, H., Abad, F. J., & Garrido, L. E. (2025).
#' Revised network loadings. \emph{Behavior Research Methods}, 57(4), 114.
#' \doi{10.3758/s13428-025-02640-3}
#'
#' @examples
#' data(bacteremia)
#'
#' vars <- c("WBC", "NEU", "HGB", "PLT", "CRP")
#' df <- bacteremia[, vars]
#'
#' fit <- mixMN(
#'   data = df,
#'   lambdaSel = "EBIC",
#'   reps = 0,
#'   seed_model = 42,
#'   compute_loadings = TRUE,
#'   progress = FALSE,
#'   save_data = TRUE
#' )
#'
#' # Compute community scores on the original data
#' scores <- community_scores(fit)
#' summary(scores)
#'
#' @export
community_scores <- function(
    fit,
    data = NULL,
    layer = NULL,
    scale = TRUE,
    quantile_level = NULL,
    return_quantile_region = FALSE,
    na_action = c("stop", "omit")
) {
  na_action <- match.arg(na_action)
  this_call <- match.call()

  # Supplying a quantile level implies that quantile regions are wanted;
  # requesting regions without a level falls back to 95%.
  if (!is.null(quantile_level)) {
    return_quantile_region <- TRUE
  }
  if (isTRUE(return_quantile_region) && is.null(quantile_level)) {
    quantile_level <- 0.95
  }

  # ---- validate fit ----
  if (!inherits(fit, c("mixMN_fit", "multimixMN_fit"))) {
    stop("`fit` must be a `mixMN_fit` or `multimixMN_fit` object.")
  }

  # ---- MULTI: delegate to the per-layer fits ----
  if (inherits(fit, "multimixMN_fit")) {

    # choose data: from the multilayer fit by default
    if (is.null(data)) {
      data <- fit$model$data
      if (is.null(data)) {
        stop("`fit` is a multimixMN_fit but `fit$model$data` is NULL. Refit with save_data = TRUE or pass `data=`.")
      }
    }

    layer_fits <- fit$layer_fits
    if (is.null(layer_fits) || length(layer_fits) == 0) {
      stop("`fit$layer_fits` is missing/empty: cannot compute per-layer scores.")
    }
    # Layers are addressed by position internally; fall back to "1", "2", ...
    # as labels if the list is unnamed.
    layer_names <- names(layer_fits)
    if (is.null(layer_names)) {
      layer_names <- as.character(seq_along(layer_fits))
    }

    # Score the i-th layer and record the user's call on the result instead
    # of this internal recursive call.
    score_layer <- function(i) {
      res <- community_scores(
        fit   = layer_fits[[i]],
        data  = data,
        scale = scale,
        quantile_level = quantile_level,
        return_quantile_region = return_quantile_region,
        na_action = na_action
      )
      res$call <- this_call
      res
    }

    # layer = NULL -> all layers; per-layer failures become warnings and the
    # failing layers are dropped from the result.
    if (is.null(layer)) {
      out_list <- vector("list", length(layer_fits))
      names(out_list) <- layer_names

      for (i in seq_along(layer_fits)) {
        out_list[i] <- list(tryCatch(
          score_layer(i),
          error = function(e) {
            warning(
              sprintf("Community scores not computed for layer '%s': %s",
                      layer_names[i], conditionMessage(e)),
              call. = FALSE
            )
            NULL
          }
        ))
      }

      return(out_list[!vapply(out_list, is.null, logical(1))])
    }

    # otherwise: a single layer, by whole-number index or by name
    if (is.numeric(layer) && length(layer) == 1L) {
      if (is.na(layer) || layer < 1 || layer > length(layer_fits) ||
          layer != floor(layer)) {
        stop("`layer` index is out of range.")
      }
      idx <- as.integer(layer)
    } else {
      layer <- as.character(layer)[1]
      idx <- match(layer, layer_names)
      if (is.na(idx)) {
        stop("`layer` not found in `fit$layer_fits`. Available layers: ",
             paste(layer_names, collapse = ", "))
      }
    }

    return(score_layer(idx))
  }

  # ---- loadings ----
  cl <- fit$community_loadings
  if (is.null(cl) || !isTRUE(cl$available)) {
    msg <- if (!is.null(cl$reason)) cl$reason else
      "Community scores are not available for this fit (community loadings were not computed)."
    if (!is.null(cl$non_scorable_nodes) && length(cl$non_scorable_nodes) > 0) {
      msg <- paste0(
        msg,
        " Non-scorable nodes: ",
        paste(utils::head(cl$non_scorable_nodes, 10), collapse = ", "),
        if (length(cl$non_scorable_nodes) > 10) " ..." else ""
      )
    }
    stop(msg)
  }

  L_true <- cl$true
  nodes  <- cl$nodes
  wc     <- cl$wc

  if (is.null(L_true) || is.null(nodes) || length(nodes) == 0) {
    stop("No community loadings found in `fit$community_loadings`. Did you run mixMN(compute_loadings = TRUE)?")
  }
  if (is.null(wc) || length(wc) != length(nodes)) {
    stop("`fit$community_loadings$wc` is missing or has wrong length: cannot zero cross-loadings.")
  }

  # Ensure dimnames
  if (is.null(rownames(L_true))) rownames(L_true) <- nodes
  if (is.null(colnames(L_true))) colnames(L_true) <- paste0("C", seq_len(ncol(L_true)))

  # Keep only each node's loading on its own community (hard membership `wc`,
  # positionally aligned with `nodes`); every cross-loading becomes 0.
  # Returns a numeric p x K matrix whose rows are ordered as `nodes`.
  .zero_cross_loadings <- function(Lmat, nodes, wc) {
    Lmat <- as.matrix(Lmat[nodes, , drop = FALSE])
    K <- ncol(Lmat)

    wc_int <- as.integer(wc)
    if (anyNA(wc_int)) stop("`wc` must be coercible to integers.")
    if (any(wc_int < 1L | wc_int > K)) {
      stop("`wc` contains community indices outside [1, K].")
    }

    Lhard <- matrix(0, nrow = nrow(Lmat), ncol = K, dimnames = dimnames(Lmat))
    own <- cbind(seq_len(nrow(Lmat)), wc_int)  # (row, own community) pairs
    Lhard[own] <- Lmat[own]
    Lhard
  }

  # apply hardening to true loadings (cross-loadings -> 0); validates `wc`
  L_true <- .zero_cross_loadings(L_true, nodes = nodes, wc = wc)
  wc_int <- as.integer(wc)

  # ---- choose data ----
  if (is.null(data)) {
    data <- fit$model$data
    if (is.null(data)) {
      stop("No `data` provided and `fit$model$data` is NULL. Refit with save_data = TRUE or pass `data=`.")
    }
  }

  # ---- coerce and ids ----
  # Matrices are kept as matrices (columns are extracted with `[, nm]` below);
  # anything else that is not a data.frame is coerced to one.
  if (!is.data.frame(data) && !is.matrix(data)) data <- as.data.frame(data)
  if (is.null(colnames(data))) stop("`data` must have column names.")
  if (is.null(rownames(data))) rownames(data) <- sprintf("id_%d", seq_len(nrow(data)))
  ids <- rownames(data)

  # ---- variable checks ----
  missing_vars <- setdiff(nodes, colnames(data))
  if (length(missing_vars) > 0) {
    stop("`data` is missing required variables: ", paste(missing_vars, collapse = ", "))
  }

  # ---- build numeric X (robust to logical / binary factors) ----
  bin_map <- fit$data_info$binary_recode_map
  if (is.null(bin_map)) bin_map <- list()

  # One column as a plain vector: matrices have no `[[` by column name.
  get_column <- function(nm) {
    if (is.matrix(data)) data[, nm] else data[[nm]]
  }

  # Convert one scoring variable to numeric.
  to_numeric <- function(v, nm) {
    # numeric/integer as-is; logical FALSE/TRUE -> 0/1
    if (is.numeric(v) || is.logical(v)) return(as.numeric(v))

    if (is.factor(v)) {
      lv <- levels(v)
      if (length(lv) == 2L) {
        m <- bin_map[[nm]]
        if (!is.null(m)) {
          return(as.numeric(m[as.character(v)]))
        }
        return(as.numeric(as.character(v) == lv[2L]))
      }
      stop("Cannot compute community scores: non-binary categorical variable '", nm, "'.")
    }

    stop("Unsupported variable class for scoring in '", nm, "': ", paste(class(v), collapse = "/"))
  }

  # Build X explicitly as n x p so that single-row (and zero-row) data keep
  # their matrix shape; sapply() would collapse a single row to a vector.
  X <- matrix(
    unlist(lapply(nodes, function(nm) to_numeric(get_column(nm), nm)),
           use.names = FALSE),
    nrow = nrow(data), ncol = length(nodes),
    dimnames = list(ids, nodes)
  )

  # ---- NA handling ----
  if (na_action == "stop" && anyNA(X)) {
    stop("Missing values detected in required variables. Use `na_action = \"omit\"` if you want row-wise omission.")
  }

  # ---- scaling on the scoring data (default) ----
  center_vec <- rep(0, ncol(X)); names(center_vec) <- colnames(X)
  scale_vec  <- rep(1, ncol(X)); names(scale_vec)  <- colnames(X)

  if (isTRUE(scale)) {
    # compute mean/sd on the dataset used for scoring; constant or
    # single-observation columns are only centered (sd treated as 1)
    center_vec <- apply(X, 2, function(v) mean(v, na.rm = TRUE))
    scale_vec  <- apply(X, 2, function(v) stats::sd(v, na.rm = TRUE))
    scale_vec[is.na(scale_vec) | scale_vec == 0] <- 1
    X <- sweep(X, 2, center_vec, "-")
    X <- sweep(X, 2, scale_vec, "/")
  }

  # ---- score computation ----
  # Scores = X (n x p) %*% L (p x K) -> (n x K).
  # With na_action = "omit" and missing values, each community uses only the
  # observed variables: the partial weighted sum is rescaled by
  # sum(|w|) / sum(|w_observed|), so complete rows get exactly the same score
  # as without missing data. A score is NA if none of the community's
  # variables is observed for that row.
  score_matrix <- function(Xmat, Lmat) {
    if (na_action != "omit" || !anyNA(Xmat)) {
      S <- Xmat %*% Lmat
      dimnames(S) <- list(rownames(Xmat), colnames(Lmat))
      return(S)
    }

    observed <- 1 * !is.na(Xmat)  # numeric 0/1 indicator matrix
    X0 <- Xmat
    X0[is.na(X0)] <- 0

    S <- matrix(NA_real_, nrow = nrow(Xmat), ncol = ncol(Lmat),
                dimnames = list(rownames(Xmat), colnames(Lmat)))

    for (k in seq_len(ncol(Lmat))) {
      w <- Lmat[, k]
      w[is.na(w)] <- 0
      abs_w <- abs(w)
      members <- wc_int == k

      partial <- drop(X0 %*% w)
      seen_w  <- drop(observed %*% abs_w)
      n_seen  <- drop(observed %*% as.numeric(members))

      s <- ifelse(seen_w > 0, partial * (sum(abs_w) / seen_w), partial)
      if (any(members)) s[n_seen == 0] <- NA_real_
      S[, k] <- s
    }
    S
  }

  scores <- score_matrix(X, L_true)

  # ---- quantile regions (optional) ----
  quantile_region_out <- NULL
  if (isTRUE(return_quantile_region)) {
    if (!is.numeric(quantile_level) || length(quantile_level) != 1L ||
        is.na(quantile_level) || quantile_level <= 0 || quantile_level >= 1) {
      stop("`quantile_level` must be a single number strictly between 0 and 1 (e.g., 0.95).")
    }

    L_boot <- cl$boot
    if (is.null(L_boot) || length(L_boot) == 0) {
      stop("Quantile region requested but `fit$community_loadings$boot` is NULL/empty. Run mixMN with boot_what including \"loadings\".")
    }

    alpha <- 1 - quantile_level
    probs <- c(alpha / 2, 1 - alpha / 2)

    reps <- length(L_boot)
    n <- nrow(X)
    K <- ncol(L_true)
    comm_names <- colnames(L_true)

    # reps x n x K bootstrap scores; NULL replicates are left as NA
    S_boot <- array(NA_real_, dim = c(reps, n, K),
                    dimnames = list(NULL, ids, comm_names))
    for (r in seq_len(reps)) {
      Lr <- L_boot[[r]]
      if (is.null(Lr)) next

      if (is.null(colnames(Lr))) colnames(Lr) <- comm_names
      if (is.null(rownames(Lr))) rownames(Lr) <- nodes

      # harden bootstrap loadings (cross-loadings -> 0), then score
      S_boot[r, , ] <- score_matrix(X, .zero_cross_loadings(Lr, nodes = nodes, wc = wc))
    }

    # Percentile quantiles over replicates for each subject/community cell
    lower <- matrix(NA_real_, nrow = n, ncol = K, dimnames = list(ids, comm_names))
    upper <- lower

    for (i in seq_len(n)) {
      for (k in seq_len(K)) {
        v <- S_boot[, i, k]
        if (!all(is.na(v))) {
          qs <- stats::quantile(v, probs = probs, na.rm = TRUE, names = FALSE)
          lower[i, k] <- qs[1]
          upper[i, k] <- qs[2]
        }
      }
    }

    quantile_region_out <- list(lower = lower, upper = upper, quantile_level = quantile_level)
  }

  out <- list(
    call = this_call,
    settings = list(
      scale = isTRUE(scale),
      quantile_level = quantile_level,
      na_action = na_action
    ),
    ids = ids,
    communities = colnames(L_true),
    scores = scores,
    quantile_region = quantile_region_out,
    details = list(
      nodes_used = nodes,
      loadings_true = L_true,
      loadings_boot_available = !is.null(cl$boot) && length(cl$boot) > 0,
      center = center_vec,
      scale = scale_vec
    )
  )
  class(out) <- c("community_scores", "mixmashnet")
  out
}

#' @export
print.community_scores <- function(x, ...) {
  # Console display of a `community_scores` object: number of subjects and
  # communities, the scoring settings, and the community labels.
  # Returns `x` invisibly so the method composes in pipelines.
  cat("MixMashNet community scores\n")
  cat(strrep("=", 30), "\n\n", sep = "")

  cat("Subjects:    ", length(x$ids), "\n", sep = "")
  cat("Communities: ", length(x$communities), "\n", sep = "")

  cat("\nSettings\n")
  cat("  Scaling:       ", x$settings$scale, "\n", sep = "")

  qr <- x$quantile_region
  qr_label <- if (is.null(qr)) {
    "not computed"
  } else {
    paste0(round(100 * qr$quantile_level), "%")
  }
  cat("  Quantile region: ", qr_label, "\n", sep = "")

  cat("\nCommunity names:\n  ",
      paste(x$communities, collapse = ", "),
      "\n", sep = "")

  invisible(x)
}

#' @export
summary.community_scores <- function(object, ...) {
  # Summarise a `community_scores` object. The result reports:
  # - subject and community counts;
  # - the settings used for scoring;
  # - whether quantile regions were computed (and at which level);
  # - per-community mean, SD, min and max of the scores across subjects.
  # Returns an object of class "summary.community_scores".
  x <- object
  scores <- x$scores

  n <- nrow(scores)
  K <- ncol(scores)

  has_qr <- !is.null(x$quantile_region)
  ql <- if (has_qr) x$quantile_region$quantile_level else NA_real_

  # Apply `stat` to the observed (non-NA) scores of each community. Some
  # columns have no observed scores (all-NA under na_action = "omit", or zero
  # subjects). Those yield NA rather than NaN, or Inf with a warning, which is
  # what mean()/min()/max() would return.
  col_stat <- function(stat) {
    vapply(seq_len(K), function(k) {
      v <- scores[, k]
      v <- v[!is.na(v)]
      if (length(v) == 0L) NA_real_ else as.numeric(stat(v))
    }, numeric(1))
  }

  out <- list(
    subjects = n,
    communities = K,
    community_names = x$communities,
    settings = x$settings,
    quantile_region = list(
      computed = has_qr,
      quantile_level = ql
    ),
    score_summary = data.frame(
      community = x$communities,
      mean = col_stat(mean),
      sd   = col_stat(stats::sd),
      min  = col_stat(min),
      max  = col_stat(max),
      row.names = NULL
    )
  )

  class(out) <- "summary.community_scores"
  out
}


#' @export
print.summary.community_scores <- function(x, ...) {
  # Console display for the output of summary.community_scores(). It prints:
  # - subject and community counts;
  # - the scoring settings (scaling, NA handling);
  # - quantile-region status;
  # - the per-community statistics table, rounded to 4 significant digits.
  # Returns `x` invisibly.
  cat("Summary of MixMashNet community scores\n")
  cat(strrep("=", 40), "\n\n", sep = "")

  cat("Subjects:    ", x$subjects, "\n", sep = "")
  cat("Communities: ", x$communities, "\n", sep = "")

  cat("\nSettings\n")
  cat("  Scaling:   ", x$settings$scale, "\n", sep = "")
  cat("  NA action: ", x$settings$na_action, "\n", sep = "")

  qr <- x$quantile_region
  if (!isTRUE(qr$computed)) {
    cat("  Quantile region: not computed\n")
  } else {
    cat("  Quantile region: ", round(100 * qr$quantile_level), "%\n", sep = "")
  }

  cat("\nPer-community statistics (across subjects)\n")
  stats_tbl <- x$score_summary
  for (col in c("mean", "sd", "min", "max")) {
    stats_tbl[[col]] <- signif(stats_tbl[[col]], 4)
  }
  print(stats_tbl, row.names = FALSE)

  invisible(x)
}
