---
title: "Getting started with rkaf"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Getting started with rkaf}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  message = FALSE,
  warning = FALSE
)

rkaf_torch_available <- tryCatch(
  {
    torch::torch_manual_seed(123)
    invisible(torch::torch_tensor(0))
    TRUE
  },
  error = function(e) FALSE
)

knitr::opts_chunk$set(eval = rkaf_torch_available)
```

```{r torch-unavailable, echo = FALSE, eval = !rkaf_torch_available}
knitr::asis_output(
  "> Torch/Lantern is not available in this build environment, so the code chunks are shown but not executed."
)
```

# Overview

`rkaf` provides Kolmogorov-Arnold Fourier Networks for R users through the `torch` backend. The package supports:

- regression
- binary classification
- multiclass classification
- formula and matrix interfaces
- mini-batch training
- validation splits
- early stopping
- automatic standardization
- best-model restoration
- KAF-specific diagnostics

This vignette gives a quick tour of the main workflow.

```{r setup}
library(rkaf)
set.seed(123)
torch::torch_manual_seed(123)
```

# Regression with the matrix interface

We first fit a KAF model to a synthetic one-dimensional function with both low-frequency and high-frequency structure.
```{r regression-data}
x <- as.matrix(seq(-1, 1, length.out = 128))
y <- sin(8 * pi * x) + 0.35 * cos(3 * pi * x) + 0.15 * x^2
```

```{r regression-fit}
fit <- kaf_fit(
  x = x,
  y = y,
  hidden = c(256, 256),
  num_grids = 32,
  use_layernorm = FALSE,
  epochs = 1000,
  lr = 1e-3,
  standardize_x = FALSE,
  standardize_y = TRUE,
  fourier_init_scale = 5e-2,
  restore_best = TRUE,
  verbose = FALSE,
  seed = 123
)
fit
```

```{r regression-predict}
pred <- predict(fit, x)
head(data.frame(
  observed = round(as.numeric(y), 3),
  predicted = round(pred, 3)
))
```

```{r regression-plot, fig.width=7, fig.height=4}
plot(
  x, y,
  type = "l",
  lwd = 2,
  xlab = "x",
  ylab = "f(x)",
  main = "KAF regression fit"
)
lines(x, pred, lwd = 2, lty = 2)
legend(
  "topright",
  legend = c("Observed", "Predicted"),
  lty = c(1, 2),
  lwd = 2,
  bty = "n"
)
```

# Regression with the formula interface

For tabular data, `rkaf` also supports a formula interface.

```{r formula-regression}
fit_mtcars <- kaf_fit_formula(
  mpg ~ wt + hp + cyl,
  data = mtcars,
  hidden = c(32, 32),
  num_grids = 16,
  epochs = 200,
  verbose = FALSE,
  seed = 123
)
fit_mtcars
```

```{r formula-predict}
mtcars_pred <- predict(fit_mtcars, mtcars)
head(data.frame(
  observed = mtcars$mpg,
  predicted = round(mtcars_pred, 2)
))
```

# Binary classification

If the response is a factor with two classes, `rkaf` automatically treats the problem as binary classification.
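As with most logistic-style classifiers, the predicted class probability is the logistic transform of a raw logit (the vignette later shows these logits via `type = "link"`). A quick base-R illustration of that mapping:

```{r logistic-link}
# The logistic (sigmoid) function maps a raw logit to a
# probability for the positive class; plogis() is base R.
plogis(0)     # logit 0    -> probability 0.5
plogis(2.2)   # logit 2.2  -> ~0.90
plogis(-2.2)  # logit -2.2 -> ~0.10
```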
```{r binary-data}
df <- mtcars
df$high_mpg <- factor(
  ifelse(df$mpg > median(df$mpg), "yes", "no"),
  levels = c("no", "yes")
)
```

```{r binary-fit}
fit_binary <- kaf_fit_formula(
  high_mpg ~ wt + hp + cyl,
  data = df,
  hidden = c(32, 32),
  num_grids = 16,
  epochs = 200,
  verbose = FALSE,
  seed = 123
)
fit_binary
```

Predicted probabilities and classes:

```{r binary-prob}
prob_binary <- predict(fit_binary, df, type = "prob")
class_binary <- predict(fit_binary, df, type = "class")
head(data.frame(
  observed = df$high_mpg,
  prob_yes = round(prob_binary, 3),
  predicted = class_binary
))
```

Confusion matrix:

```{r binary-confusion-matrix}
table(
  observed = df$high_mpg,
  predicted = class_binary
)
```

Raw logits:

```{r binary-link}
head(predict(fit_binary, df, type = "link"))
```

# Multiclass classification

If the response is a factor with more than two classes, `rkaf` fits a multiclass classifier.

```{r multiclass-fit}
fit_iris <- kaf_fit_formula(
  Species ~ .,
  data = iris,
  hidden = c(32, 32),
  num_grids = 16,
  epochs = 300,
  verbose = FALSE,
  seed = 123
)
fit_iris
```

Confusion matrix:

```{r multiclass-confusion-matrix}
class_iris <- predict(fit_iris, iris, type = "class")
table(
  observed = iris$Species,
  predicted = class_iris
)
```

Class probabilities:

```{r multiclass-prob}
head(round(predict(fit_iris, iris, type = "prob"), 3))
```

# Validation and early stopping

`kaf_fit()` supports validation splits, mini-batches, and early stopping.

```{r validation-fit, eval=FALSE}
fit_val <- kaf_fit(
  x = x,
  y = y,
  hidden = c(64, 64),
  num_grids = 16,
  use_layernorm = FALSE,
  epochs = 300,
  lr = 5e-4,
  batch_size = 64,
  validation_split = 0.2,
  patience = 100,
  restore_best = TRUE,
  verbose = FALSE,
  seed = 123
)
plot(fit_val)
```

The fitted object stores both `train_loss_history` and `validation_loss_history`, so users can inspect training and validation behavior directly.

# KAF diagnostics

The KAF architecture contains a base/GELU branch and a Fourier branch.
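Conceptually, each learned activation combines these two branches. The following plain-R sketch is purely illustrative (it is not the package's exact parameterization): a GELU base branch plus a truncated Fourier series, each weighted by a scale.

```{r kaf-sketch}
# Illustrative scalar sketch of a KAF-style activation:
# base_scale * GELU(x) + fourier_scale * sum_k (a_k sin(kx) + b_k cos(kx))
gelu <- function(z) z * pnorm(z)  # exact GELU: z * Phi(z)
kaf_unit <- function(x, base_scale, fourier_scale, a, b) {
  k <- seq_along(a)
  fourier <- vapply(
    x,
    function(xi) sum(a * sin(k * xi) + b * cos(k * xi)),
    numeric(1)
  )
  base_scale * gelu(x) + fourier_scale * fourier
}
kaf_unit(
  seq(-1, 1, length.out = 5),
  base_scale = 1, fourier_scale = 0.1,
  a = rep(0.5, 4), b = rep(0.5, 4)
)
```

The `extract_kaf_scales()` output below corresponds to the two scale parameters in this sketch, and `extract_fourier_params()` to the sine/cosine coefficients.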
The package exposes helper functions to inspect the learned branch scales and Fourier parameters.

```{r diagnostics-scales}
scales <- extract_kaf_scales(fit)
head(scales)
```

```{r diagnostics-fourier}
fourier_params <- extract_fourier_params(fit, layer = 1)
head(fourier_params)
```

# Low-level torch interface

Advanced users can use the low-level torch modules directly.

```{r low-level}
model <- nn_kaf(
  layers = c(4, 16, 16, 1),
  num_grids = 8
)
x_tensor <- torch::torch_randn(10, 4)
y_tensor <- model(x_tensor)
y_tensor$shape
```

# Summary

The standard workflow is:

```{r workflow, eval=FALSE}
fit <- kaf_fit_formula(
  y ~ .,
  data = df,
  hidden = c(64, 64),
  num_grids = 16,
  validation_split = 0.2,
  patience = 30
)
predict(fit, newdata)
plot(fit)
extract_kaf_scales(fit)
```

For classification, use:

```{r classification-workflow, eval=FALSE}
predict(fit, newdata, type = "prob")
predict(fit, newdata, type = "class")
```
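To estimate out-of-sample performance, the same workflow can be wrapped in a simple train/test split. The sketch below reuses the `iris` example with the arguments shown earlier in this vignette:

```{r holdout-accuracy, eval=FALSE}
# Held-out accuracy for a classification fit (sketch):
# train on 80% of rows, score the remaining 20%.
idx <- sample(nrow(iris), size = 0.8 * nrow(iris))
fit_holdout <- kaf_fit_formula(
  Species ~ .,
  data = iris[idx, ],
  hidden = c(32, 32),
  num_grids = 16,
  verbose = FALSE,
  seed = 123
)
pred_holdout <- predict(fit_holdout, iris[-idx, ], type = "class")
mean(pred_holdout == iris$Species[-idx])  # held-out accuracy
```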