#' Fit a GAM-Cox Model with Tensor-Product Spline Surface
#'
#' @description
#' Fits a penalized Cox proportional hazards model with a tensor-product
#' spline surface for the latent biomarker summaries using \code{mgcv::gam}
#' with \code{family = cox.ph()}.
#'
#' @param data Data frame from \code{.build_transition_data}, containing
#'   \code{time_in_state}, \code{status}, covariate columns, and
#'   \code{eta_*} columns.
#' @param covariates Character vector of covariate names. Names that are not
#'   columns of \code{data} are silently dropped.
#' @param k_marginal Integer vector of marginal basis dimensions. Default \code{c(5,5)}.
#'   A length-one value is used for both margins.
#' @param k_additive Integer basis dimension for additive smooth of third
#'   biomarker. Default \code{6}.
#' @param bs Spline basis type. Default \code{"tp"}.
#' @param method Smoothing method. Default \code{"REML"}.
#' @param k_single Integer basis dimension of the univariate smooth used when
#'   \code{data} has exactly one \code{eta_*} column. Default \code{8}.
#'
#' @return A \code{mgcv::gam} object, or \code{NULL} on failure. \code{NULL}
#'   is returned silently when \code{data} has no \code{eta_*} columns, and
#'   with a warning when required columns are missing or the fit errors.
#'
#' @details
#' The linear predictor is built from the \code{eta_*} columns in the order
#' they appear in \code{data}:
#' \itemize{
#'   \item one column: \code{s(eta_1, k = k_single)}
#'   \item two columns: \code{te(eta_1, eta_2, k = k_marginal)}
#'   \item three or more: the tensor product of the first two plus
#'     \code{s(eta_3, k = k_additive)}; any further columns are not used
#' }
#'
#' \code{mgcv::cox.ph()} takes the event time as the response and the event
#' indicator through \code{weights} (1 = event, 0 = censored). A two-column
#' response is read by \code{cox.ph()} as (time, stratum), so a
#' \code{survival::Surv} object must not be used on the left-hand side.
#' \code{survival::Surv} is used here only to validate the time/status pair
#' and to normalise the status coding before it is passed as weights.
#'
#' @export
fit_gam_cox <- function(data, covariates = c("age_baseline", "sex"),
                        k_marginal = c(5, 5), k_additive = 6,
                        bs = "tp", method = "REML", k_single = 8) {

  eta_cols <- grep("^eta_", names(data), value = TRUE)
  if (length(eta_cols) == 0) return(NULL)

  ## Honour the "NULL on failure" contract for malformed input as well
  absent <- setdiff(c("time_in_state", "status"), names(data))
  if (length(absent) > 0) {
    warning("GAM-Cox fitting failed: missing column(s): ",
            paste(absent, collapse = ", "), call. = FALSE)
    return(NULL)
  }

  ## Parametric terms: only covariates actually present in `data`
  cov_terms <- intersect(covariates, names(data))

  ## Smooth terms for the latent biomarker summaries
  if (length(eta_cols) >= 2) {
    k_te <- rep_len(k_marginal, 2)
    smooth_terms <- paste0("te(", eta_cols[1], ", ", eta_cols[2],
                           ", k=c(", k_te[1], ",", k_te[2],
                           "), bs='", bs, "')")
    if (length(eta_cols) >= 3) {
      smooth_terms <- c(
        smooth_terms,
        paste0("s(", eta_cols[3], ", k=", k_additive, ", bs='", bs, "')")
      )
    }
  } else {
    smooth_terms <- paste0("s(", eta_cols[1], ", k=", k_single,
                           ", bs='", bs, "')")
  }

  ## Response is the event time; censoring enters through `weights` below
  formula_str <- paste(
    "time_in_state ~",
    paste(c(cov_terms, smooth_terms), collapse = " + ")
  )

  ## The formula must carry this frame as its environment so that the
  ## `weights` expression (a local variable) can be resolved by model.frame()
  fit_env <- environment()

  tryCatch(
    {
      surv_obj <- survival::Surv(data$time_in_state, data$status)
      ## Event indicator, 1 = event / 0 = censored, as expected by cox.ph()
      .gam_cox_event <- as.numeric(unclass(surv_obj)[, "status"])
      mgcv::gam(stats::as.formula(formula_str, env = fit_env),
                family = mgcv::cox.ph(), data = data,
                weights = .gam_cox_event, method = method)
    },
    error = function(e) {
      warning("GAM-Cox fitting failed: ", conditionMessage(e), call. = FALSE)
      NULL
    }
  )
}


#' Effective Degrees of Freedom Diagnostics
#'
#' @description
#' Extracts EDF, deviance explained, and complexity diagnostics for each
#' transition-specific association surface. EDF near 1 indicates linearity;
#' EDF > 3 indicates substantial nonlinearity/interaction.
#'
#' @param object A \code{"jmSurface"} object.
#'
#' @return Data frame with one row per element of \code{object$gam_fits}
#'   (zero rows if there are none) and columns:
#'   \item{transition}{Transition name}
#'   \item{edf}{Effective degrees of freedom of the surface smooth (the first
#'     smooth term of the fit), rounded to 2 decimals}
#'   \item{deviance_explained}{Proportion of deviance explained}
#'   \item{n_obs}{Number of observations}
#'   \item{n_events}{Number of events}
#'   \item{complexity}{Character label: "Linear", "Moderate", "Nonlinear",
#'     or "Unknown" when no EDF is available}
#'   \item{p_value}{Approximate p-value for the smooth term}
#'
#' @details
#' The EDF is computed as
#' \deqn{\mathrm{EDF}_{rs} = \mathrm{tr}\{(B'B + \lambda_{rs} S_{rs})^{-1} B'B\}}
#' and represents the realized complexity of the association surface after
#' REML-based penalization. The \code{complexity} label is assigned as:
#' \itemize{
#'   \item \code{EDF <= 1.5}: "Linear"; surface effectively linear, standard
#'     parametric JM suffices
#'   \item \code{1.5 < EDF <= 3}: "Moderate" nonlinearity
#'   \item \code{EDF > 3}: "Nonlinear"; substantial nonlinearity and/or
#'     interaction effects
#' }
#' Transitions whose fit is \code{NULL} (fitting failed) or has no smooth
#' term are reported with \code{NA} diagnostics and complexity "Unknown"
#' rather than raising an error.
#'
#' @export
edf_diagnostics <- function(object) {

  if (!inherits(object, "jmSurface"))
    stop("object must be of class 'jmSurface'")

  transitions <- names(object$gam_fits)

  ## Typed zero-row result, returned when there is nothing to summarise
  if (length(transitions) == 0) {
    return(data.frame(
      transition = character(),
      edf = numeric(),
      deviance_explained = numeric(),
      n_obs = integer(),
      n_events = integer(),
      complexity = character(),
      p_value = numeric(),
      stringsAsFactors = FALSE
    ))
  }

  ## Build one row per transition, then bind once (no rbind() inside a loop)
  rows <- lapply(transitions, function(tr) {
    gf <- object$gam_fits[[tr]]
    ed <- object$eta_data[[tr]]

    edf_val <- NA_real_
    p_val <- NA_real_
    dev_expl <- NA_real_

    ## A NULL fit marks a failed model; keep the row but leave diagnostics NA
    if (!is.null(gf)) {
      sm <- .safe_summary_gam(gf)
      dev_expl <- sm$dev.expl
      ## s.table can be absent or empty when the model has no smooth terms
      if (!is.null(sm$s.table) && nrow(sm$s.table) > 0) {
        edf_val <- sm$s.table[1, "edf"]
        p_val <- sm$s.table[1, "p-value"]
      }
    }

    complexity <- if (is.na(edf_val)) {
      "Unknown"
    } else if (edf_val <= 1.5) {
      "Linear"
    } else if (edf_val <= 3) {
      "Moderate"
    } else {
      "Nonlinear"
    }

    data.frame(
      transition = tr,
      edf = round(edf_val, 2),
      deviance_explained = round(dev_expl, 4),
      n_obs = if (is.null(ed)) NA_integer_ else nrow(ed),
      n_events = if (is.null(ed)) NA_integer_ else sum(ed$status),
      complexity = complexity,
      p_value = signif(p_val, 3),
      stringsAsFactors = FALSE
    )
  })

  do.call(rbind, rows)
}


## ── Internal: safely call summary.gam for cox.ph models ──
## summary.gam() can fail with "Invalid operation on a survival time"
## when the response stored in the fit is a Surv object, because arithmetic
## on Surv objects is not allowed. This helper calls summary() on a copy of
## the fit whose $y is the plain numeric event-time vector (one value per
## observation). `gam_obj` is a local copy, so the caller's object is never
## modified and nothing has to be restored afterwards.
.safe_summary_gam <- function(gam_obj) {
  y <- gam_obj$y
  if (inherits(y, "Surv")) {
    y_mat <- unclass(y)
    ## as.numeric() on the Surv matrix itself would concatenate the time and
    ## status columns into a length-2n vector; keep only the time column.
    ## NOTE(review): for a counting-process Surv (start, stop, status) there
    ## is no "time" column and the first column (start) is used — confirm
    ## whether such responses can reach this helper.
    time_col <- if ("time" %in% colnames(y_mat)) "time" else 1L
    gam_obj$y <- as.numeric(y_mat[, time_col])
  }
  summary(gam_obj)
}
