#' Estimate the Joint Distribution of Two Binary Variables from Marginal Summaries
#'
#' Performs maximum likelihood estimation (MLE) of the joint probability
#' that both binary variables equal 1, using only marginal summary counts
#' from multiple studies.
#'
#' @details
#' The marginal probabilities are estimated by pooling the counts over all
#' studies. For each study the likelihood of the observed margins is the sum of
#' multinomial probabilities over every admissible value of the unobserved
#' count of observations with both variables equal to 1. The joint probability
#' is found by a coarse grid search followed by bounded optimisation
#' (\code{L-BFGS-B}) on the interior of its feasible range
#' \code{(max(0, p1 + p2 - 1), min(p1, p2))}. The variance is the inverse of
#' the observed information at the estimate. Confidence intervals are at the
#' 95\% level.
#'
#' All likelihood computations are carried out on the log scale, so large
#' studies do not underflow.
#'
#' @param ni Numeric vector. Sample sizes for each dataset.
#' @param xi Numeric vector. Count of observations where variable 1 equals 1.
#'        Must satisfy \code{0 <= xi <= ni}.
#' @param yi Numeric vector. Count of observations where variable 2 equals 1.
#'        Must satisfy \code{0 <= yi <= ni}.
#' @param ci_method Character string. Method for confidence interval computation.
#'        Options are \code{"none"} (default), \code{"normal"}, or \code{"lr"} (likelihood ratio).
#'
#' @return A named list with point estimates, variance, standard error, and confidence interval (if requested).
#' \describe{
#'   \item{p1_hat}{Estimated marginal probability for variable 1.}
#'   \item{p2_hat}{Estimated marginal probability for variable 2.}
#'   \item{p11_hat}{Estimated joint probability.}
#'   \item{var_hat}{Estimated variance of \code{p11_hat}.}
#'   \item{sd_hat}{Standard error of \code{p11_hat}; \code{NA} when
#'     \code{var_hat} is negative.}
#'   \item{ci}{Named vector (\code{lower}, \code{upper}) with the 95\%
#'     confidence interval for \code{p11_hat}, or \code{NULL} if not requested
#'     or not computable.}
#' }
#'
#' @examples
#' data(bin_example)
#' cor_bin(bin_example$ni, bin_example$xi, bin_example$yi, ci_method = "lr")
#' @export
#' @importFrom stats optim optimize qnorm qchisq
cor_bin <- function(ni, xi, yi, ci_method = c("none", "normal", "lr")) {
  ci_method <- match.arg(ci_method)

  # ---- Input validation ----
  if (!(is.numeric(ni) && is.numeric(xi) && is.numeric(yi))) {
    stop("All inputs must be numeric vectors.")
  }
  if (length(ni) != length(xi) || length(xi) != length(yi)) {
    stop("Input vectors ni, xi, yi must be of equal length.")
  }
  if (length(ni) == 0L) {
    stop("Input vectors must contain at least one study.")
  }
  if (anyNA(c(ni, xi, yi))) {
    stop("Input vectors must not contain NA values.")
  }
  if (any(is.infinite(c(ni, xi, yi)))) {
    stop("Input vectors must contain only finite values.")
  }
  if (any(xi < 0 | yi < 0 | xi > ni | yi > ni)) {
    stop("Counts must satisfy 0 <= xi <= ni and 0 <= yi <= ni for every study.")
  }
  if (any(ni %% 1 != 0 | xi %% 1 != 0 | yi %% 1 != 0)) {
    warning("Input counts should be integers; non-integer values will be treated as-is.")
  }

  # ---- Pooled marginal estimates ----
  sum_ni <- sum(ni)
  if (sum_ni <= 0) {
    stop("Total sample size must be positive.")
  }
  p1 <- sum(xi) / sum_ni
  p2 <- sum(yi) / sum_ni
  if (p1 <= 0 || p1 >= 1 || p2 <= 0 || p2 >= 1) {
    # With a degenerate margin the feasible range of p11 collapses to a point.
    stop("Pooled marginal proportions must lie strictly between 0 and 1.")
  }

  # Feasible (open) range of p11 given the margins, and the closed search
  # interval inside it. The margin is 1e-3 as long as the range is wide
  # enough, and shrinks proportionally for very narrow ranges so that the
  # search interval never becomes empty.
  p11_min <- max(0, p1 + p2 - 1)
  p11_max <- min(p1, p2)
  margin <- min(1e-3, (p11_max - p11_min) / 4)
  lower <- p11_min + margin
  upper <- p11_max - margin

  # ---- Per-study quantities that do not depend on p11 ----
  # z enumerates the admissible values of the unobserved (1, 1) cell count;
  # n10, n01, n00 are the implied remaining cell counts.
  studies <- Map(function(n, x, y) {
    z <- max(0, x + y - n):min(x, y)
    n10 <- x - z
    n01 <- y - z
    n00 <- n - x - y + z
    list(
      z = z, n10 = n10, n01 = n01, n00 = n00,
      log_coef = lfactorial(n) - lfactorial(z) - lfactorial(n10) -
        lfactorial(n01) - lfactorial(n00)
    )
  }, ni, xi, yi)

  # TRUE when all four cell probabilities implied by p11 are positive.
  is_feasible <- function(p11) {
    p11 > 0 && p1 - p11 > 0 && p2 - p11 > 0 && 1 - p1 - p2 + p11 > 0
  }

  # Log multinomial probability of every admissible table of one study.
  log_terms <- function(s, p11) {
    s$log_coef + s$z * log(p11) + s$n10 * log(p1 - p11) +
      s$n01 * log(p2 - p11) + s$n00 * log(1 - p1 - p2 + p11)
  }

  # Numerically stable log(sum(exp(v))).
  log_sum_exp <- function(v) {
    m <- max(v)
    if (!is.finite(m)) {
      return(m)
    }
    m + log(sum(exp(v - m)))
  }

  # Total log-likelihood of the observed margins; -Inf outside the feasible range.
  logL <- function(p11) {
    if (!is_feasible(p11)) {
      return(-Inf)
    }
    sum(vapply(studies, function(s) log_sum_exp(log_terms(s, p11)), numeric(1)))
  }

  # Inverse of the observed information, i.e. the estimated variance of p11.
  # For each study the second derivative of log(sum_z P_z) is
  #   E_w[d1^2 - d2] - (E_w[d1])^2,
  # where w are the normalised table probabilities, d1 is the first derivative
  # of log P_z and d2 is minus its second derivative.
  fishinfo <- function(p11) {
    q10 <- p1 - p11
    q01 <- p2 - p11
    q00 <- 1 - p1 - p2 + p11
    info <- 0
    for (s in studies) {
      lt <- log_terms(s, p11)
      w <- exp(lt - max(lt))
      w <- w / sum(w)
      d1 <- s$z / p11 - s$n10 / q10 - s$n01 / q01 + s$n00 / q00
      d2 <- s$z / p11^2 + s$n10 / q10^2 + s$n01 / q01^2 + s$n00 / q00^2
      info <- info + sum(w * (d1^2 - d2)) - sum(w * d1)^2
    }
    -1 / info
  }

  # ---- Optimise p11: grid search for a start value, then L-BFGS-B ----
  grid_pts <- unique(c(seq(lower, upper, by = 0.01), upper))
  grid <- grid_pts[which.max(vapply(grid_pts, logL, numeric(1)))]
  opt <- optim(par = grid, fn = logL, method = "L-BFGS-B",
               lower = lower, upper = upper,
               control = list(fnscale = -1))
  p11_hat <- opt$par

  if (logL(grid) > logL(p11_hat)) {
    warning("Numerical optimization result worse than initial grid search value.")
    p11_hat <- grid
  }

  # ---- Standard error and CI ----
  var_hat <- fishinfo(p11_hat)
  if (var_hat < 0) {
    warning("Estimated variance is negative. The result may be unreliable due to data sparsity or numerical instability.")
    sd_hat <- NA
  } else {
    sd_hat <- sqrt(var_hat)
  }

  ci <- NULL
  if (ci_method == "normal" && !is.na(sd_hat)) {
    z <- qnorm(0.975)
    ci <- c(lower = p11_hat - z * sd_hat, upper = p11_hat + z * sd_hat)
  }
  if (ci_method == "lr") {
    # Find, on each side of the estimate, the point where the log-likelihood
    # drops by qchisq(0.95, 1) / 2. The search may go closer to the edge of
    # the feasible range than the optimiser was allowed to.
    ci_margin <- margin / 100
    bound <- -0.5 * qchisq(0.95, 1) + logL(p11_hat)
    f <- function(p) abs(logL(p) - bound)
    lwr <- optimize(f, interval = c(p11_min + ci_margin, p11_hat))$minimum
    upr <- optimize(f, interval = c(p11_hat, p11_max - ci_margin))$minimum
    ci <- c(lower = lwr, upper = upr)
  }

  list(
    p1_hat = p1,
    p2_hat = p2,
    p11_hat = p11_hat,
    var_hat = var_hat,
    sd_hat = sd_hat,
    ci = ci
  )
}

