# ============================================================================ #
# ggplot2 autoplot methods for brsmm
# ============================================================================ #

#' ggplot2 autoplot for brsmm models
#'
#' @description
#' Produces ggplot2 diagnostics tailored to mixed beta interval models.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @param type Plot type; one of
#'   \code{"calibration"}, \code{"score_dist"}, \code{"ranef_qq"},
#'   \code{"residuals_by_group"}, \code{"ranef_caterpillar"},
#'   \code{"ranef_density"}, or \code{"ranef_pairs"}. See Details.
#' @param bins Number of bins used in calibration plots (a single integer
#'   \code{>= 3}).
#' @param scores Optional integer vector of scores for \code{"score_dist"}.
#'   Defaults to all scores from \code{0} to \code{ncuts}.
#' @param residual_type Residual type passed to \code{\link{residuals.brsmm}}
#'   for \code{type = "residuals_by_group"}.
#' @param max_groups Maximum number of groups displayed in
#'   \code{"residuals_by_group"} (a single integer \code{>= 2}).
#' @param ... Currently ignored.
#'
#' @details
#' \describe{
#'   \item{\code{"calibration"}}{Observations are binned by quantiles of the
#'     fitted mean; bin-average fitted means are plotted against bin-average
#'     observed responses with a dashed 45-degree reference line.}
#'   \item{\code{"score_dist"}}{Dodged bars of observed versus expected
#'     counts for each score.}
#'   \item{\code{"ranef_qq"}}{Normal Q-Q plot of the estimated random-effect
#'     modes, one facet per random-effect term.}
#'   \item{\code{"residuals_by_group"}}{Boxplots of residuals for the
#'     \code{max_groups} largest groups.}
#'   \item{\code{"ranef_caterpillar"}}{Random-effect modes sorted within each
#'     term, one facet per term.}
#'   \item{\code{"ranef_density"}}{Overlaid kernel densities of the
#'     random-effect modes, one curve per term.}
#'   \item{\code{"ranef_pairs"}}{Scatter plot of the first two random-effect
#'     terms with a least-squares line; requires at least two terms.}
#' }
#'
#' @return A \code{ggplot2} object.
#'
#' @examples
#' \donttest{
#' dat <- data.frame(
#'   y = c(
#'     0, 5, 20, 50, 75, 90, 100, 30, 60, 45,
#'     10, 40, 55, 70, 85, 25, 35, 65, 80, 15
#'   ),
#'   x1 = rep(c(1, 2), 10),
#'   id = factor(rep(1:4, each = 5))
#' )
#' prep <- brs_prep(dat, ncuts = 100)
#' fit_mm <- brsmm(y ~ x1, random = ~ 1 | id, data = prep)
#' ggplot2::autoplot(fit_mm, type = "calibration", bins = 4)
#' ggplot2::autoplot(fit_mm, type = "score_dist")
#' ggplot2::autoplot(fit_mm, type = "ranef_qq")
#' ggplot2::autoplot(fit_mm, type = "ranef_caterpillar")
#' ggplot2::autoplot(fit_mm, type = "ranef_density")
#' }
#'
#' @method autoplot brsmm
#' @importFrom rlang .data
#' @importFrom stats aggregate quantile
#' @export autoplot.brsmm
autoplot.brsmm <- function(object,
                           type = c(
                             "calibration",
                             "score_dist",
                             "ranef_qq",
                             "residuals_by_group",
                             "ranef_caterpillar",
                             "ranef_density",
                             "ranef_pairs"
                           ),
                           bins = 10L,
                           scores = NULL,
                           residual_type = c("response", "pearson"),
                           max_groups = 25L,
                           ...) {
  .check_class_mm(object)
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for autoplot.brsmm().", call. = FALSE)
  }

  type <- match.arg(type)
  residual_type <- match.arg(residual_type)
  bins <- as.integer(bins)
  max_groups <- as.integer(max_groups)

  # Require scalars explicitly: for length-0 or length > 1 input, `||` would
  # otherwise raise an opaque error (R >= 4.3) instead of the message below.
  if (length(bins) != 1L || is.na(bins) || bins < 3L) {
    stop("'bins' must be an integer >= 3.", call. = FALSE)
  }
  if (length(max_groups) != 1L || is.na(max_groups) || max_groups < 2L) {
    stop("'max_groups' must be an integer >= 2.", call. = FALSE)
  }

  switch(type,
    calibration = .brsmm_autoplot_calibration(object, bins = bins),
    score_dist = .brsmm_autoplot_score_dist(object, scores = scores),
    ranef_qq = .brsmm_autoplot_ranef_qq(object),
    residuals_by_group = .brsmm_autoplot_resid_group(
      object,
      residual_type = residual_type,
      max_groups = max_groups
    ),
    ranef_caterpillar = .brsmm_autoplot_ranef_caterpillar(object),
    ranef_density = .brsmm_autoplot_ranef_density(object),
    ranef_pairs = .brsmm_autoplot_ranef_pairs(object)
  )
}

#' Calibration plot for a brsmm fit
#'
#' Bins observations by their fitted mean (\code{object$fitted_mu}) and
#' compares, per bin, the average fitted mean with the average observed
#' response \code{object$Y[, "yt"]}. Point size shows the bin count; the
#' dashed line is the 45-degree identity.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @param bins Target number of quantile bins. Fewer bins result when
#'   fitted values are tied (duplicate quantiles are collapsed).
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_calibration <- function(object, bins = 10L) {
  df <- data.frame(
    observed = as.numeric(object$Y[, "yt"]),
    predicted = as.numeric(object$fitted_mu)
  )
  if (!any(is.finite(df$predicted))) {
    stop("No finite fitted values available for the calibration plot.", call. = FALSE)
  }

  probs <- seq(0, 1, length.out = bins + 1L)
  breaks <- unique(stats::quantile(df$predicted, probs = probs, na.rm = TRUE))
  if (length(breaks) < 3L) {
    # Quantiles collapsed (heavily tied predictions): fall back to equal-width
    # bins over the observed range. na.rm = TRUE mirrors the quantile() call.
    rng <- range(df$predicted, na.rm = TRUE)
    if (rng[1L] == rng[2L]) {
      # All predictions identical: cut() needs distinct breaks, so use a
      # single narrow bin around the common value.
      pad <- max(abs(rng[1L]) * 1e-8, 1e-8)
      breaks <- rng[1L] + c(-pad, pad)
    } else {
      breaks <- seq(rng[1L], rng[2L], length.out = bins + 1L)
    }
  }
  df$bin <- cut(df$predicted, breaks = breaks, include.lowest = TRUE, ordered_result = TRUE)

  cal <- stats::aggregate(df[, c("predicted", "observed")], by = list(bin = df$bin), FUN = mean)
  cal$n <- as.integer(table(df$bin)[as.character(cal$bin)])

  # `size` is mapped on the points only; inheriting it into geom_line would
  # be reinterpreted as a (deprecated) variable line width in ggplot2 >= 3.4.
  ggplot2::ggplot(cal, ggplot2::aes(x = .data$predicted, y = .data$observed)) +
    ggplot2::geom_point(ggplot2::aes(size = .data$n), color = "#1b9e77", alpha = 0.9) +
    ggplot2::geom_line(color = "#1b9e77", alpha = 0.6) +
    ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "gray35") +
    ggplot2::labs(
      title = "Calibration Plot (brsmm)",
      x = "Mean predicted response (bin average)",
      y = "Mean observed response (bin average)",
      size = "Bin n"
    ) +
    ggplot2::theme_minimal()
}

#' Observed versus expected score distribution for a brsmm fit
#'
#' Observed counts come from \code{.brs_observed_scores()} applied to
#' \code{object$Y[, "y"]}. Expected counts are the column sums of the
#' matrix returned by \code{.brs_score_prob_matrix()}, presumably one row
#' per observation and one column per requested score (confirm against that
#' helper). Both are drawn as dodged bars.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @param scores Optional integer-valued vector of scores in
#'   \code{[0, ncuts]}; defaults to \code{0:ncuts}. Duplicates are removed
#'   and the scores are sorted.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_score_dist <- function(object, scores = NULL) {
  K <- as.integer(object$ncuts)
  if (is.null(scores)) {
    scores <- 0:K
  }
  # Validate before sort(): sort() silently drops NA, so a check afterwards
  # would let missing or non-numeric entries through. Non-integer values are
  # rejected rather than silently truncated by as.integer().
  num <- suppressWarnings(as.numeric(scores))
  if (length(num) == 0L || any(!is.finite(num)) || any(num != round(num)) ||
      any(num < 0) || any(num > K)) {
    stop("'scores' must be integers in [0, ncuts].", call. = FALSE)
  }
  scores <- sort(unique(as.integer(num)))

  obs_scores <- .brs_observed_scores(object$Y[, "y"], K = K)
  obs_counts <- as.numeric(table(factor(obs_scores, levels = scores)))

  probs <- .brs_score_prob_matrix(
    mu = object$fitted_mu,
    phi = object$fitted_phi,
    repar = object$repar,
    ncuts = K,
    lim = object$lim,
    scores = scores
  )
  exp_counts <- colSums(probs)

  df <- rbind(
    data.frame(score = scores, count = obs_counts, source = "Observed"),
    data.frame(score = scores, count = exp_counts, source = "Expected")
  )

  ggplot2::ggplot(df, ggplot2::aes(x = .data$score, y = .data$count, fill = .data$source)) +
    ggplot2::geom_col(position = "dodge", alpha = 0.85) +
    ggplot2::scale_fill_manual(values = c(Observed = "#1b9e77", Expected = "#7570b3")) +
    ggplot2::labs(
      title = "Observed vs Expected Score Distribution (brsmm)",
      x = "Score",
      y = "Count",
      fill = ""
    ) +
    ggplot2::theme_minimal()
}

#' Normal Q-Q plot of estimated random-effect modes
#'
#' One facet per random-effect term (column of \code{.brsmm_ranef_matrix()}).
#' Sorted modes are plotted against \code{qnorm(ppoints(n))}. Each facet gets
#' one dashed reference line with slope \code{sd(x)} and intercept
#' \code{mean(x) - sd(x) * mean(q)}.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_ranef_qq <- function(object) {
  re <- .brsmm_ranef_matrix(object)
  terms <- colnames(re)

  qq_df <- do.call(rbind, lapply(terms, function(term) {
    x <- sort(as.numeric(re[, term]))
    data.frame(
      theoretical = stats::qnorm(stats::ppoints(length(x))),
      sample = x,
      term = term
    )
  }))

  # One row per term, so each facet draws exactly one reference line.
  ref_lines <- do.call(rbind, lapply(terms, function(term) {
    x <- as.numeric(re[, term])
    q <- stats::qnorm(stats::ppoints(length(x)))
    data.frame(
      term = term,
      intercept = mean(x) - stats::sd(x) * mean(q),
      slope = stats::sd(x)
    )
  }))
  # sd() is NA for a term with a single group; skip such lines rather than
  # letting ggplot2 warn about removed rows.
  ref_lines <- ref_lines[
    is.finite(ref_lines$intercept) & is.finite(ref_lines$slope), ,
    drop = FALSE
  ]

  p <- ggplot2::ggplot(qq_df, ggplot2::aes(x = .data$theoretical, y = .data$sample)) +
    ggplot2::geom_point(color = "#1b9e77", alpha = 0.85, size = 1.1)
  if (nrow(ref_lines) > 0L) {
    p <- p + ggplot2::geom_abline(
      data = ref_lines,
      mapping = ggplot2::aes(intercept = .data$intercept, slope = .data$slope),
      linetype = "dashed",
      color = "gray35"
    )
  }
  p +
    ggplot2::facet_wrap(~term, scales = "free_y") +
    ggplot2::labs(
      title = "Random-Effects Q-Q Plot",
      x = "Theoretical Normal Quantiles",
      y = "Estimated random-effect mode"
    ) +
    ggplot2::theme_minimal()
}

#' Boxplots of residuals for the largest groups
#'
#' Ranks groups in \code{object$group} by number of observations, keeps at
#' most \code{max_groups} of them, and draws one residual boxplot per kept
#' group (outlier points hidden) with a dashed zero line.
#'
#' @param object A fitted model whose \code{residuals()} method accepts
#'   \code{type}, with group labels stored in \code{object$group}.
#' @param residual_type \code{"response"} or \code{"pearson"}.
#' @param max_groups Maximum number of groups to display.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_resid_group <- function(object,
                                        residual_type = c("response", "pearson"),
                                        max_groups = 25L) {
  residual_type <- match.arg(residual_type)
  resid_values <- residuals(object, type = residual_type)
  group_labels <- object$group

  # Largest groups first, truncated to at most `max_groups`.
  group_sizes <- sort(table(group_labels), decreasing = TRUE)
  n_shown <- min(max_groups, length(group_sizes))
  shown <- names(group_sizes)[seq_len(n_shown)]

  # Observations outside the shown groups become NA factor values; drop them.
  plot_df <- data.frame(
    residual = resid_values,
    group = factor(as.character(group_labels), levels = shown)
  )
  plot_df <- plot_df[!is.na(plot_df$group), , drop = FALSE]

  plot_title <- paste0(
    "Residuals by Group (top ", length(shown), " groups, ",
    residual_type, ")"
  )
  rotated_x <- ggplot2::element_text(angle = 60, hjust = 1)

  base <- ggplot2::ggplot(
    plot_df,
    ggplot2::aes(x = .data$group, y = .data$residual, fill = .data$group)
  )
  base +
    ggplot2::geom_boxplot(alpha = 0.70, outlier.shape = NA) +
    ggplot2::geom_hline(yintercept = 0, linetype = "dashed", color = "gray35") +
    ggplot2::labs(title = plot_title, x = "Group", y = "Residual") +
    ggplot2::theme_minimal() +
    ggplot2::theme(legend.position = "none", axis.text.x = rotated_x)
}

#' Random-effect modes as a named matrix
#'
#' Returns \code{object$random$mode_b} as a matrix with one row per group and
#' one column per random-effect term. A plain vector becomes a one-column
#' matrix whose row names are the vector's names. Missing column names are
#' taken from \code{object$random$terms} when it has enough entries,
#' otherwise \code{"(Intercept)"} (single column) or \code{"term1"},
#' \code{"term2"}, ... Missing row names default to \code{"1"}, \code{"2"},
#' ... so downstream plots always have group labels.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @return A numeric matrix with non-NULL row and column names.
#' @keywords internal
#' @noRd
.brsmm_ranef_matrix <- function(object) {
  re <- object$random$mode_b
  terms <- object$random$terms

  if (is.matrix(re)) {
    out <- re
  } else {
    out <- matrix(as.numeric(re), ncol = 1L)
    rownames(out) <- names(re)
  }

  if (is.null(colnames(out))) {
    k <- ncol(out)
    colnames(out) <- if (length(terms) >= k) {
      as.character(terms[seq_len(k)])
    } else if (k == 1L) {
      "(Intercept)"
    } else {
      paste0("term", seq_len(k))
    }
  }
  if (is.null(rownames(out))) {
    rownames(out) <- as.character(seq_len(nrow(out)))
  }
  out
}

#' Caterpillar plot of random-effect modes
#'
#' For each random-effect term, the group modes are sorted and plotted
#' against their within-term rank as lollipops from zero. The x position is
#' the rank within the term rather than a shared group factor. That way every
#' facet is sorted by its own term, even though the facets share the x scale.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_ranef_caterpillar <- function(object) {
  re <- .brsmm_ranef_matrix(object)
  groups <- rownames(re)
  if (is.null(groups)) {
    groups <- as.character(seq_len(nrow(re)))
  }

  parts <- lapply(colnames(re), function(term) {
    v <- as.numeric(re[, term])
    ord <- order(v)
    data.frame(
      group = groups[ord],
      rank = seq_along(ord),
      mode = v[ord],
      term = term
    )
  })
  df <- do.call(rbind, parts)

  ggplot2::ggplot(df, ggplot2::aes(x = .data$rank, y = .data$mode)) +
    ggplot2::geom_hline(yintercept = 0, linetype = "dashed", color = "gray35") +
    ggplot2::geom_segment(ggplot2::aes(xend = .data$rank, y = 0, yend = .data$mode),
      color = "#7570b3", alpha = 0.5
    ) +
    ggplot2::geom_point(color = "#1b9e77", size = 1.6) +
    ggplot2::facet_wrap(~term, scales = "free_y") +
    ggplot2::labs(
      title = "Caterpillar Plot of Random Effects",
      x = "Group (ordered by mode)",
      y = "Random-effect mode"
    ) +
    ggplot2::theme_minimal() +
    ggplot2::theme(axis.text.x = ggplot2::element_blank())
}

#' Density plot of random-effect modes
#'
#' Overlays one kernel density curve per random-effect term (colour and fill
#' by term), with a dashed vertical line at zero.
#'
#' @param object A fitted \code{"brsmm"} object.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_ranef_density <- function(object) {
  modes <- .brsmm_ranef_matrix(object)

  # Stack the matrix column by column; each term label is repeated once per
  # row so it lines up with the column-major values.
  long_df <- data.frame(
    value = as.double(modes),
    term = rep(colnames(modes), each = nrow(modes))
  )

  density_layers <- list(
    ggplot2::geom_density(alpha = 0.20, linewidth = 0.9),
    ggplot2::geom_vline(xintercept = 0, linetype = "dashed", color = "gray35"),
    ggplot2::labs(
      title = "Density of Random Effects",
      x = "Random-effect mode",
      y = "Density",
      color = "Term",
      fill = "Term"
    ),
    ggplot2::theme_minimal()
  )

  ggplot2::ggplot(
    long_df,
    ggplot2::aes(x = .data$value, color = .data$term, fill = .data$term)
  ) + density_layers
}

#' Pair plot of the first two random-effect terms
#'
#' Scatter plot of the modes of the first random-effect term (x) against the
#' second (y), with dashed zero axes and a least-squares line.
#'
#' @param object A fitted \code{"brsmm"} object with at least two
#'   random-effect terms.
#' @return A \code{ggplot} object.
#' @keywords internal
#' @noRd
.brsmm_autoplot_ranef_pairs <- function(object) {
  re <- .brsmm_ranef_matrix(object)
  if (ncol(re) < 2L) {
    stop("ranef_pairs requires at least two random-effect terms.", call. = FALSE)
  }
  # Only the first two terms are plotted; any further columns are ignored.
  df <- data.frame(
    x = re[, 1L],
    y = re[, 2L]
  )
  xlab <- colnames(re)[1L]
  ylab <- colnames(re)[2L]
  ggplot2::ggplot(df, ggplot2::aes(x = .data$x, y = .data$y)) +
    ggplot2::geom_hline(yintercept = 0, linetype = "dashed", color = "gray35") +
    ggplot2::geom_vline(xintercept = 0, linetype = "dashed", color = "gray35") +
    ggplot2::geom_point(color = "#1b9e77", alpha = 0.8) +
    # An explicit formula avoids geom_smooth()'s "using formula" message.
    ggplot2::geom_smooth(
      method = "lm", formula = y ~ x, se = FALSE,
      color = "#d95f02", linewidth = 0.8
    ) +
    ggplot2::labs(
      title = "Random-Effects Pair Plot",
      x = xlab,
      y = ylab
    ) +
    ggplot2::theme_minimal()
}
