#' Variance Function Panel Regression
#'
#' @description
#' Implements an iterative mean-variance panel regression estimator that allows
#' both the mean and variance of the dependent variable to be functions of
#' covariates. Based on Mooi-Reci and Liao (2025).
#'
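#' @param formula A formula whose first variable names the dependent variable
#'   (e.g. \code{ln_wage ~ 1}); only that variable is read from the formula,
#'   regressors are supplied through \code{mean_vars} and \code{var_vars}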
#' @param data A data frame containing the variables
#' @param group A character string naming the grouping variable
#' @param panel_id A character string naming the panel identifier
#' @param mean_vars A character vector of variable names for the mean equation
#' @param var_vars A character vector of variable names for the variance equation
#' @param weights Optional character string naming the weight variable
#' @param subset Optional logical vector for subsetting
#' @param converge Convergence tolerance (default: 1e-6)
#' @param max_iter Maximum iterations (default: 100)
#' @param verbose Logical; print iteration history? (default: TRUE)
#'
#' @return An object of class "xtvfreg" containing:
#' \item{results}{List of results for each group}
#' \item{groups}{Vector of group values}
#' \item{call}{The matched call}
#' \item{convergence}{Convergence information for each group}
#' \item{variance_decomp}{Variance decomposition for each group}
#' \item{depvar}{Name of dependent variable}
#' \item{panel_id}{Name of panel identifier}
#' \item{group_var}{Name of grouping variable}
#' \item{mean_vars}{Variables in mean equation}
#' \item{var_vars}{Variables in variance equation}
#'
#' @examples
#' # Example using nlswork subset data
#' data(nlswork_subset)
#' 
#' # Prepare the data
#' # Keep only observations with complete wage data and white/black race
#' analysis_data <- subset(nlswork_subset, 
#'                         !is.na(ln_wage) & 
#'                         !is.na(tenure) & 
#'                         race %in% 1:2)
#' 
#' # Create race grouping variable
#' analysis_data$race_group <- factor(analysis_data$race,
#'                                    levels = 1:2,
#'                                    labels = c("white", "black"))
#' 
#' # Create within and between components for tenure
#' analysis_data$m_tenure <- ave(analysis_data$tenure,
#'                               analysis_data$idcode,
#'                               FUN = function(x) mean(x, na.rm = TRUE))
#' analysis_data$d_tenure <- analysis_data$tenure - analysis_data$m_tenure
#' 
#' # Estimate varying effects model
#' result <- xtvfreg(
#'   formula = ln_wage ~ 1,
#'   data = analysis_data,
#'   group = "race_group",
#'   panel_id = "idcode",
#'   mean_vars = c("m_tenure", "d_tenure", "age"),
#'   var_vars = c("m_tenure"),
#'   verbose = FALSE
#' )
#' 
#' # View a summary of results
#' summary(result)
#' 
#' # Extract coefficients for white group if needed
#' coef(result, equation = "mean", group = "white")
#'
#' @references
#' Mooi-Reci, I., and Liao, T. F. (2025). Unemployment: a hidden source of wage
#' inequality? European Sociological Review, 41(3), 382-401.
#' \doi{10.1093/esr/jcae052}
#'
#' @importFrom stats as.formula glm gaussian Gamma predict residuals vcov coef var ave
#' @export
xtvfreg <- function(formula,
                    data,
                    group,
                    panel_id,
                    mean_vars,
                    var_vars,
                    weights = NULL,
                    subset = NULL,
                    converge = 1e-6,
                    max_iter = 100,
                    verbose = TRUE) {

  # Store call
  cl <- match.call()

  # Only the first variable of the formula is used: it names the outcome
  depvar <- all.vars(formula)[1]

  # Fail early with a readable message instead of an obscure downstream error.
  # mean_vars is not checked because its entries are pasted into a formula and
  # are not required to be plain column names.
  required_cols <- c(depvar, group, panel_id, var_vars, weights)
  missing_cols <- setdiff(required_cols, names(data))
  if (length(missing_cols) > 0) {
    stop(
      "Variable(s) not found in data: ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }

  # Apply subset if provided
  if (!is.null(subset)) {
    if (is.logical(subset)) {
      # An NA in a logical index would insert an all-NA row; treat it as FALSE
      subset[is.na(subset)] <- FALSE
    }
    data <- data[subset, , drop = FALSE]
  }

  # Convert panel_id and group to factors if they aren't already
  if (!is.factor(data[[panel_id]])) {
    data[[panel_id]] <- factor(data[[panel_id]])
  }
  if (!is.factor(data[[group]])) {
    data[[group]] <- factor(data[[group]])
  }

  # Only levels that actually occur: an unused level (e.g. left over after
  # subsetting) would hand an empty data frame to the estimator
  groups <- levels(droplevels(data[[group]]))
  n_groups <- length(groups)

  if (verbose) {
    cat("\nVarying Fixed Effects Panel Regression\n")
    cat("Dependent variable:", depvar, "\n")
    cat("Panel ID:", panel_id, "\n")
    cat("Group variable:", group, "\n")
    cat("Number of groups:", n_groups, "\n")
    cat("Convergence criterion:", converge, "\n")
    cat("Maximum iterations:", max_iter, "\n")
    cat(strrep("-", 78), "\n")
  }

  # Initialize results storage (one named entry per group)
  results <- list()
  convergence_info <- list()
  variance_decomp <- list()

  # Loop over groups
  for (i in seq_along(groups)) {
    g <- groups[i]

    if (verbose) {
      cat(sprintf("\nGroup %s (%d of %d)\n", g, i, n_groups))
      cat(strrep("-", 78), "\n")
    }

    # which() discards NA comparisons, so rows with a missing group value are
    # left out instead of turning into all-NA rows in every group
    group_data <- data[which(data[[group]] == g), , drop = FALSE]

    # Keep only the panel levels present in this group so the within
    # transformation does not iterate over panels belonging to other groups
    group_data[[panel_id]] <- droplevels(group_data[[panel_id]])

    n_obs <- nrow(group_data)

    if (verbose) {
      cat("Observations:", n_obs, "\n")
    }

    # Estimate for this group
    group_result <- estimate_group(
      data = group_data,
      depvar = depvar,
      panel_id = panel_id,
      mean_vars = mean_vars,
      var_vars = var_vars,
      weights = weights,
      converge = converge,
      max_iter = max_iter,
      verbose = verbose
    )

    # Calculate variance decomposition
    var_decomp <- calculate_variance_decomposition(
      data = group_data,
      depvar = depvar,
      mean_model = group_result$mean_model,
      s2 = group_result$s2,
      verbose = verbose
    )

    # Store results under the group label
    results[[g]] <- group_result
    convergence_info[[g]] <- list(
      n_iter = group_result$n_iter,
      converged = group_result$converged,
      ll_final = group_result$ll_final
    )
    variance_decomp[[g]] <- var_decomp
  }

  if (verbose) {
    cat("\n\nSummary\n")
    cat(strrep("-", 78), "\n")
    cat("Total groups estimated:", n_groups, "\n")
  }

  # Create return object
  out <- list(
    results = results,
    groups = groups,
    call = cl,
    convergence = convergence_info,
    variance_decomp = variance_decomp,
    depvar = depvar,
    panel_id = panel_id,
    group_var = group,
    mean_vars = mean_vars,
    var_vars = var_vars
  )

  class(out) <- "xtvfreg"
  out
}


#' Calculate variance decomposition
#'
#' Relates the sample variance of the dependent variable to the variance of
#' the mean-model fitted values and to the average predicted variance, and
#' optionally prints the resulting table.
#'
#' @param data Data frame containing the column named by `depvar`.
#' @param depvar Name of the dependent variable column.
#' @param mean_model Fitted model; fitted values are taken from
#'   `predict(mean_model, type = "response")`.
#' @param s2 Numeric vector of predicted variances.
#' @param verbose Logical; print the decomposition table?
#'
#' @return A list with `var_total`, `var_fitted`, `var_heterosced` and the
#'   shares `prop_mean`, `prop_var`, `prop_unexplained` (each relative to
#'   `var_total`).
#' @keywords internal
calculate_variance_decomposition <- function(data, depvar, mean_model, s2, verbose = TRUE) {

  # The three variance components; missing values are ignored throughout
  total <- var(data[[depvar]], na.rm = TRUE)
  from_mean <- var(predict(mean_model, type = "response"), na.rm = TRUE)
  from_variance <- mean(s2, na.rm = TRUE)

  # Shares of the total variance; the remainder is labelled unexplained
  share_mean <- from_mean / total
  share_variance <- from_variance / total
  share_rest <- 1 - share_mean - share_variance

  if (verbose) {
    cat("\nVariance Decomposition:\n")
    cat(strrep("-", 78), "\n")
    cat(sprintf("Total variance of %s: %9.6f\n", depvar, total))

    # One entry per printed row: format string, level, share
    table_rows <- list(
      list(
        "  Variance explained by mean model: %9.6f (%5.1f%%)\n",
        from_mean,
        share_mean
      ),
      list(
        "  Variance explained by variance model: %9.6f (%5.1f%%)\n",
        from_variance,
        share_variance
      ),
      list(
        "  Unexplained variance: %9.6f (%5.1f%%)\n",
        total - from_mean - from_variance,
        share_rest
      )
    )
    for (row in table_rows) {
      cat(sprintf(row[[1]], row[[2]], row[[3]] * 100))
    }
  }

  list(
    var_total = total,
    var_fitted = from_mean,
    var_heterosced = from_variance,
    prop_mean = share_mean,
    prop_var = share_variance,
    prop_unexplained = share_rest
  )
}

#' Estimate model for a single group
#'
#' Alternates between a weighted Gaussian GLM for the mean equation and a
#' Gamma (log-link) GLM for the squared within-panel residuals until the
#' absolute change in the log-likelihood criterion is at most `converge` or
#' `max_iter` iterations have been run.
#'
#' @param data Data frame with the observations of one group.
#' @param depvar Name of the dependent variable.
#' @param panel_id Name of the panel identifier column; residuals are demeaned
#'   within its values.
#' @param mean_vars Character vector of terms for the mean equation.
#' @param var_vars Character vector of column names for the variance equation.
#' @param weights Optional name of a weight column; `NULL` means unit weights.
#' @param converge Convergence tolerance on the absolute log-likelihood change.
#' @param max_iter Maximum number of iterations.
#' @param verbose Logical; print the iteration history?
#'
#' @return A list with `mean_model`, `var_model`, `n_iter`, `converged`,
#'   `ll_init`, `ll_final`, `s2` (predicted variances) and `r2` (squared
#'   within residuals). If no iteration is run (e.g. `max_iter = 0`) the
#'   initial fits are returned and `ll_final` equals `ll_init`.
#' @keywords internal
estimate_group <- function(data,
                           depvar,
                           panel_id,
                           mean_vars,
                           var_vars,
                           weights = NULL,
                           converge = 1e-6,
                           max_iter = 100,
                           verbose = TRUE) {

  # Convert panel_id to factor if it's character
  if (is.character(data[[panel_id]])) {
    data[[panel_id]] <- as.factor(data[[panel_id]])
  }

  # A misspelled weight column would otherwise silently become NULL
  if (!is.null(weights) && is.null(data[[weights]])) {
    stop("Weight variable '", weights, "' not found in data", call. = FALSE)
  }

  # Create formulas
  mean_formula <- as.formula(paste(depvar, "~", paste(mean_vars, collapse = " + ")))
  var_formula <- as.formula(paste("r2 ~", paste(var_vars, collapse = " + ")))

  # glm() drops rows with missing model variables, which would leave the
  # residuals shorter than `data` and break the within transformation and the
  # variance-equation data frame. Restrict to complete rows up front so every
  # vector below has one entry per retained row.
  model_cols <- intersect(
    unique(c(all.vars(mean_formula), var_vars, weights)),
    names(data)
  )
  is_complete <- rowSums(is.na(data[, model_cols, drop = FALSE])) == 0
  if (!all(is_complete)) {
    if (verbose) {
      cat(sprintf(
        "Dropped %d observation(s) with missing values\n",
        sum(!is_complete)
      ))
    }
    data <- data[is_complete, , drop = FALSE]
  }
  if (nrow(data) == 0) {
    stop("No complete observations available for estimation", call. = FALSE)
  }

  # Prepare weight variable (unit weights when none is named)
  wgt <- if (is.null(weights)) rep(1, nrow(data)) else data[[weights]]

  # Squared within-panel residuals of a fitted mean model, floored at 1e-10
  # because the Gamma family requires strictly positive responses
  squared_within_resid <- function(model) {
    res <- residuals(model, type = "response")
    res_within <- res - ave(res, data[[panel_id]], FUN = mean)
    pmax(res_within^2, 1e-10)
  }

  # Log-likelihood criterion used to monitor convergence
  loglik <- function(r2, s2) {
    sum(-0.5 * (log(s2) + (r2 / s2)))
  }

  # NOTE: the glm() calls stay in this function body because their `weights`
  # argument is looked up in the formula environment, i.e. this frame.

  # Initial estimation - Mean equation
  mean_model <- glm(
    mean_formula,
    data = data,
    family = gaussian(link = "identity"),
    weights = wgt
  )
  r2 <- squared_within_resid(mean_model)

  # Initial variance equation
  var_data <- data.frame(r2 = r2, data[, var_vars, drop = FALSE])
  var_model <- glm(
    var_formula,
    data = var_data,
    family = Gamma(link = "log"),
    weights = wgt
  )

  # Predicted variance
  s2 <- predict(var_model, type = "response")

  # Initial log-likelihood; also the final value if the loop never runs
  loglik_init <- loglik(r2, s2)
  ll_old <- loglik_init
  ll_new <- loglik_init

  if (verbose) {
    cat("\nInitial Estimation:\n")
    cat(sprintf("  Initial log-likelihood = %10.4f\n", loglik_init))
    cat("\nIteration History:\n")
    cat(strrep("-", 50), "\n")
    cat(sprintf("%-4s %15s %12s %12s\n", "Iter", "Log-likelihood", "Change", "Criterion"))
    cat(strrep("-", 50), "\n")
  }

  # Iterative estimation
  iter <- 0
  ll_change <- Inf
  converged <- FALSE

  while (ll_change > converge && iter < max_iter) {
    iter <- iter + 1

    # Inverse-variance weights, scaled by the user weights (1 when absent)
    combined_wgt <- wgt / s2

    # Re-estimate mean equation with weights
    mean_model <- glm(
      mean_formula,
      data = data,
      family = gaussian(link = "identity"),
      weights = combined_wgt
    )

    # New squared within residuals
    r2 <- squared_within_resid(mean_model)

    # Re-estimate variance equation
    var_data$r2 <- r2
    var_model <- glm(
      var_formula,
      data = var_data,
      family = Gamma(link = "log"),
      weights = wgt
    )

    # Update predicted variance
    s2 <- predict(var_model, type = "response")

    # Calculate log-likelihood and its change
    ll_new <- loglik(r2, s2)
    ll_change <- abs(ll_new - ll_old)

    # A missing change would make the loop condition fail with an opaque
    # "missing value where TRUE/FALSE needed" error
    if (is.na(ll_change)) {
      stop(
        "Log-likelihood change is not a number at iteration ", iter,
        call. = FALSE
      )
    }

    if (verbose) {
      cat(sprintf("%4d %15.4f %12.6f %12.6f\n", iter, ll_new, ll_change, converge))
    }

    ll_old <- ll_new

    if (ll_change <= converge) {
      converged <- TRUE
    }
  }

  if (verbose) {
    cat(strrep("-", 50), "\n")
    if (converged) {
      cat(sprintf("\nConverged in %d iterations\n", iter))
    } else {
      cat(sprintf("\nWarning: Maximum iterations (%d) reached without convergence\n", max_iter))
    }
  }

  # Return results
  list(
    mean_model = mean_model,
    var_model = var_model,
    n_iter = iter,
    converged = converged,
    ll_init = loglik_init,
    ll_final = ll_new,
    s2 = s2,
    r2 = r2
  )
}