## ----setup, include=FALSE-----------------------------------------------------
# Vignette-wide knitr chunk defaults: figures are 7 x 4 inches, printed
# output is prefixed with "#>", and source and output share one block.
knitr::opts_chunk$set(
  fig.height = 4,
  fig.width = 7,
  comment = "#>",
  collapse = TRUE
)

## ----install, eval=FALSE------------------------------------------------------
# Install the development version from GitHub (not run when the vignette is
# built, since this chunk is eval=FALSE):
# install.packages("remotes")
# remotes::install_github("qc-zhao/VCMoE")

## -----------------------------------------------------------------------------
# Attach VCMoE; the unqualified simulate_vcmoe_*/vcmoe_*/plot_* calls below
# rely on it being on the search path.
library("VCMoE")

## ----simulate-----------------------------------------------------------------
# Draw a reproducible simulated data set (seed = 1) with 240 observations
# and k = 2 (presumably two mixture components; see
# ?simulate_vcmoe_gaussian). `sim$data` is what the model is fitted to below;
# `sim$truth` presumably holds the generating values for comparison.
sim <- simulate_vcmoe_gaussian(
  scenario = "well_separated",
  separation = 1.4,
  k = 2,
  n = 240,
  seed = 1
)

head(sim$data)
str(sim$truth, max.level = 1)

## ----fit----------------------------------------------------------------------
# Fit a k = 2 model to the simulated data. In the formula, `z1` (left of
# `|`) enters the expert coefficients (it indexes coef(fit, "expert") below);
# `x1` (right of `|`) is presumably the gating covariate -- confirm in
# ?vcmoe_fit. The column named by `u` is the index variable; the fit is
# evaluated at five points on [0.1, 0.9] with bandwidth 0.30.
fit <- vcmoe_fit(
  y ~ z1 | x1,
  k = 2,
  family = "gaussian",
  u = "u",
  data = sim$data,
  u_grid = seq(0.1, 0.9, length.out = 5),
  bandwidth = 0.30,
  control = list(maxit = 80, n_starts = 2, seed = 2)
)

fit

## ----coefficients-------------------------------------------------------------
# Expert coefficients come back as a 3-D array whose third dimension is
# indexed by covariate name (e.g. "z1"). The first two dimensions presumably
# index the `u_grid` points and the experts -- check dimnames(expert_coef).
expert_coef <- coef(fit, "expert")
dim(expert_coef)
expert_coef[, , "z1"]

## ----predictions--------------------------------------------------------------
# Posterior component-membership probabilities, one row per observation.
# rowSums() on the first six rows is a quick sanity check; presumably every
# row sums to 1.
posterior <- predict(fit, type = "posterior")
head(posterior)
rowSums(head(posterior))

# Fitted mean response for each observation.
fitted_mean <- predict(fit, type = "mean")
head(fitted_mean)

## ----diagnostics--------------------------------------------------------------
# Fit diagnostics, presumably one row per `u_grid` point (there is a `u`
# column). Only a subset of columns is shown: convergence and ambiguity
# flags, posterior entropy, and effective sample size.
diagnostics <- vcmoe_diagnostics(fit)
diagnostics[, c("u", "converged", "ambiguous", "posterior_entropy", "effective_n")]

## ----coefficient-plot---------------------------------------------------------
# Plot the expert coefficients (presumably as curves over `u`).
plot_coefficients(fit, "expert")

## ----posterior-plot-----------------------------------------------------------
# Plot the posterior component-membership probabilities.
plot_posterior(fit)

## ----bandwidth, eval=FALSE----------------------------------------------------
# Not run when the vignette is built (eval=FALSE); uncomment to try it.
# Compares the candidate bandwidths in `bandwidth_grid`; `folds = 3` suggests
# K-fold cross-validation (confirm in ?vcmoe_select_bandwidth). The chosen
# value is read from `selection$best_bandwidth`.
# selection <- vcmoe_select_bandwidth(
#   y ~ z1 | x1,
#   data = sim$data,
#   u = "u",
#   family = "gaussian",
#   k = 2,
#   bandwidth_grid = c(0.24, 0.30, 0.36),
#   folds = 3,
#   u_grid = seq(0.1, 0.9, length.out = 5),
#   control = list(maxit = 80, n_starts = 2, seed = 3),
#   seed = 4
# )
# 
# selection
# selection$best_bandwidth

