Extending lolR with Arbitrary Classification Algorithms

Eric Bridgeford

2020-06-25

Writing New Classification Algorithms

The lolR package makes it easy for users to write their own classification algorithms for cross-validation.

Writing a Classification Method

For example, consider the nearestCentroid classifier built into the package:

#' Nearest Centroid Classifier Training
#'
#' A function that trains a classifier based on the nearest centroid.
#' @param X \code{[n, d]} the data with \code{n} samples in \code{d} dimensions.
#' @param Y \code{[n]} the labels of the \code{n} samples.
#' @param ... optional args.
#' @return A list of class \code{nearestCentroid}, with the following attributes:
#' \item{centroids}{\code{[K, d]} the centroids of each class with \code{K}  classes in \code{d} dimensions.}
#' \item{ylabs}{\code{[K]} the ylabels for each of the \code{K} unique classes, ordered.}
#' \item{priors}{\code{[K]} the priors for each of the \code{K} classes.}
#' @author Eric Bridgeford
#'
#' @examples
#' library(lolR)
#' data <- lol.sims.rtrunk(n=200, d=30)  # 200 examples of 30 dimensions
#' X <- data$X; Y <- data$Y
#' model <- lol.classify.nearestCentroid(X, Y)
#' @export
lol.classify.nearestCentroid <- function(X, Y, ...) {
  # class data
  classdat <- lol.utils.info(X, Y)
  priors <- classdat$priors; centroids <- t(classdat$centroids)
  K <- classdat$K; ylabs <- classdat$ylabs
  model <-  list(centroids=centroids, ylabs=ylabs, priors=priors)
  return(structure(model, class="nearestCentroid"))
}

As the above segment shows, the function lol.classify.nearestCentroid returns a list of parameters for the nearestCentroid model, tagged with the class nearestCentroid. To use much of the lolR functionality, researchers can write their own classification method following the spec below:

Inputs:
keyworded arguments for:
- X: a [n, d] data matrix with n samples in d dimensions.
- Y: a [n] vector of class labels for each sample.
- <additional-arguments>: additional arguments and hyperparameters your algorithm requires.
Outputs:
a list containing the following:
- <param1>: the first parameter of your model required for prediction.
- <param2>: the second parameter of your model required for prediction.
- ...: additional outputs you might need.

For example, my classifier takes the following arguments:

Inputs:
keyworded arguments for:
- X: a [n, d] data matrix with n samples in d dimensions.
- Y: a [n] vector of class labels for each sample.
Outputs:
a list containing the following:
- centroids: [K, d] the centroids for each of the K classes.
- ylabs: [K] the label names associated with each of the K classes.
- priors: [K] the priors for each of the K classes.

Note that the inputs MUST be named X, Y.

Your classifier will produce results as follows:

# given X, Y your data matrix and class labels as above
model <- lol.classify.nearestCentroid(X, Y)
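To illustrate the <additional-arguments> part of the spec, here is a minimal sketch of a training function that takes one extra hyperparameter. The function name my.classify.knn, the class name myKnn, and the argument n.neighbors are hypothetical and are not part of lolR; the sketch assumes only base R.

```r
# Hypothetical k-nearest-neighbours trainer following the spec above.
# Not part of lolR; uses base R only.
my.classify.knn <- function(X, Y, n.neighbors=3, ...) {
  X <- as.matrix(X)
  if (nrow(X) != length(Y)) {
    stop("X must have one row per label in Y.")
  }
  # a lazy learner: the "parameters" required for prediction are the
  # training data, the labels, and the hyperparameter
  model <- list(X=X, Y=Y, ylabs=sort(unique(Y)),
                n.neighbors=min(n.neighbors, nrow(X)))
  # the class attribute is what lets stats::predict dispatch on the model
  return(structure(model, class="myKnn"))
}

# toy data: two well-separated classes in 2 dimensions
X <- rbind(c(0, 0), c(0, 1), c(1, 0), c(5, 5), c(5, 6), c(6, 5))
Y <- c(1, 1, 1, 2, 2, 2)
model <- my.classify.knn(X, Y, n.neighbors=3)
```

The returned list carries everything a later prediction step would need, and the inputs are named X and Y as the spec requires.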

Writing a Prediction Method

To use lol.xval.eval, your classification technique must be compatible with the S3 generic stats::predict; that is, you must define a predict method for the class of the object your classifier returns. Below is the prediction method for the nearestCentroid classifier shown above:

#' Nearest Centroid Classifier Prediction
#'
#' A function that predicts the class of points based on the nearest centroid
#' @param object An object of class \code{nearestCentroid}, with the following attributes:
#' \itemize{
#' \item{centroids}{\code{[K, d]} the centroids of each class with \code{K} classes in \code{d} dimensions.}
#' \item{ylabs}{\code{[K]} the ylabels for each of the \code{K} unique classes, ordered.}
#' \item{priors}{\code{[K]} the priors for each of the \code{K} classes.}
#' }
#' @param X \code{[n, d]} the data to classify with \code{n} samples in \code{d} dimensions.
#' @param ... optional args.
#' @return Yhat \code{[n]} the predicted class of each of the \code{n} data points in \code{X}.
#' @author Eric Bridgeford
#'
#' @examples
#' library(lolR)
#' data <- lol.sims.rtrunk(n=200, d=30)  # 200 examples of 30 dimensions
#' X <- data$X; Y <- data$Y
#' model <- lol.classify.nearestCentroid(X, Y)
#' Yh <- predict(model, X)
#' @export
predict.nearestCentroid <- function(object, X, ...) {
  K <- length(object$ylabs); n <-  dim(X)[1]
  dists <- array(0, dim=c(n, K))
  for (i in 1:n) {
    for (j in 1:K) {
      dists[i, j] <- sqrt(sum((X[i,] - object$centroids[j,])^2))
    }
  }
  Yass <- apply(dists, c(1), which.min)
  Yhat <- object$ylabs[Yass]
  return(Yhat)
}

As we can see, predict.nearestCentroid takes two arguments: object, the trained model, and X, a data matrix of points to classify. To be compatible with lol.xval.eval, your method should obey the following spec:

Inputs:
- object: a list containing the parameters required by your model for prediction. This is required by stats::predict.
- X: a [n, d] data matrix with n samples in d dimensions to predict.
Outputs:
A list containing the following (At least one of <your-output#> should be labels for the n samples):
- <your-output1>: the first output of your classification technique.
- <your-output2>: the second output of your classification technique.
- ... additional outputs you may want.
OR
- <your-prediction-labels>: [n] the prediction labels for each of the n samples.

For example, my prediction method follows this API:

Inputs:
- object: a list containing the parameters required by your model for prediction. This is required by stats::predict.
- X: a [n, d] data matrix with n samples in d dimensions to predict.
Outputs:
- Yhat: [n] the prediction labels for each of the n samples.

At least one of the outputs of your prediction method should contain the prediction labels. In my above example, I simply return the labels themselves, but you may want to return a list where one of the outputs is the prediction labels.
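A prediction method that returns a list rather than bare labels might look like the following minimal sketch. The class name myKnn, the model fields X, Y, and n.neighbors, and the output names class and votes are all hypothetical and not part of lolR; the model object is built by hand here only so that the example runs on its own in base R.

```r
# Hypothetical model object; in practice your training function would return it.
model <- structure(
  list(X=rbind(c(0, 0), c(0, 1), c(1, 0), c(5, 5), c(5, 6), c(6, 5)),
       Y=c("a", "a", "a", "b", "b", "b"),
       n.neighbors=3),
  class="myKnn")

# S3 method: stats::predict dispatches here for objects of class "myKnn"
predict.myKnn <- function(object, X, ...) {
  X <- as.matrix(X)
  ylabs <- sort(unique(object$Y))
  # votes[i, j]: how many of the nearest neighbours of sample i have label ylabs[j]
  votes <- t(apply(X, 1, function(x) {
    dists <- sqrt(colSums((t(object$X) - x)^2))
    nearest <- object$Y[order(dists)[seq_len(object$n.neighbors)]]
    tabulate(match(nearest, ylabs), nbins=length(ylabs))
  }))
  # a list-valued prediction; the element "class" holds the labels
  return(list(class=ylabs[apply(votes, 1, which.max)], votes=votes))
}

Xnew <- rbind(c(0.5, 0.5), c(5.5, 5.5))
pred <- predict(model, Xnew)
pred$class
```

With a list-valued prediction like this, the name of the element holding the labels (here the string "class") is what you would later supply so that the labels can be extracted.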

Using Your Classification Technique for Embedding Evaluation

If your algorithm follows the above spec, you can use it with lol.xval.eval to estimate the classification accuracy of your embedding algorithm. With the algorithm and its prediction method sourced, set up the classifier arguments as follows:

classifier = <your-classifier>
# if your classifier requires additional hyperparameters, a keyworded list
# containing the additional arguments to train your classifier
classifier.opts = list(<additional-arguments>)
# if your classifier requires no additional hyperparameters
classifier.opts = NaN
# if your classifier prediction returns a list containing the class labels
classifier.return = <return-labels-argname-string>
# if your classifier prediction returns only the class labels
classifier.return = NaN

For example, my algorithm can be set up as follows:

classifier = lol.classify.nearestCentroid
classifier.opts = NaN  # my classifier takes no additional arguments, so NaN
classifier.return = NaN  # my classifier returns only the prediction labels, so NaN
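To make the roles of classifier.opts and classifier.return concrete, the following is a rough sketch of how an evaluation routine could consume them. It is not the lolR implementation: fit.and.score, toy.classify.centroid, and predict.toyCentroid are hypothetical names, and the nearest-centroid stand-in is written in base R only so that the example runs without lolR installed.

```r
# Stand-in nearest-centroid classifier (hypothetical, base R only)
toy.classify.centroid <- function(X, Y, ...) {
  ylabs <- sort(unique(Y))
  centroids <- t(sapply(ylabs, function(y) colMeans(X[Y == y, , drop=FALSE])))
  return(structure(list(centroids=centroids, ylabs=ylabs), class="toyCentroid"))
}

predict.toyCentroid <- function(object, X, ...) {
  # squared distance from every sample (row) to every centroid (column)
  dists <- apply(object$centroids, 1, function(ctr) rowSums(sweep(X, 2, ctr)^2))
  dists <- matrix(dists, nrow=nrow(X))
  return(object$ylabs[apply(dists, 1, which.min)])
}

# Hypothetical helper: train on one split, score on another
fit.and.score <- function(X.train, Y.train, X.test, Y.test, classifier,
                          classifier.opts=NaN, classifier.return=NaN) {
  # extra hyperparameters are passed along only if a list was supplied
  opts <- if (is.list(classifier.opts)) classifier.opts else list()
  model <- do.call(classifier, c(list(X=X.train, Y=Y.train), opts))
  pred <- predict(model, X.test)
  # extract the labels from a list-valued prediction if a name was supplied
  Yhat <- if (is.character(classifier.return)) pred[[classifier.return]] else pred
  return(mean(Yhat != Y.test))  # misclassification rate
}

X <- rbind(c(0, 0), c(0, 1), c(1, 0), c(1, 1), c(5, 5), c(5, 6), c(6, 5), c(6, 6))
Y <- c(1, 1, 1, 1, 2, 2, 2, 2)
test <- c(1, 5)  # hold out one sample per class
err <- fit.and.score(X[-test, ], Y[-test], X[test, ], Y[test],
                     classifier=toy.classify.centroid,
                     classifier.opts=NaN, classifier.return=NaN)
```

The sketch also shows why the training inputs must be named X and Y: the data are passed to the classifier by name, alongside whatever classifier.opts contains.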

The algorithm can then be incorporated as a classification technique to evaluate prediction accuracy after performing an embedding:

# given data X, Y as above
xval.out <- lol.xval.eval(X, Y, alg=<your-algorithm>, alg.opts=<your-algorithm-opts>,
                          alg.return = <your-algorithm-embedding-matrix>,
                          classifier=classifier, classifier.opts=classifier.opts,
                          classifier.return=classifier.return, k=<k>)

See the tutorial vignette extend_embedding for details on how to specify alg, alg.opts, and alg.return for a custom embedding algorithm.

Note, for instance, that the randomForest package provides the randomForest classifier, which is compatible with this spec:

library(randomForest)
classifier = randomForest
# use the randomForest classifier, also computing the proximity (similarity) matrix
classifier.opts = list(proximity=TRUE)
classifier.return = NaN  # predict.randomForest returns only the prediction labels
xval.out <- lol.xval.eval(X, Y, alg=<your-algorithm>, alg.opts=<your-algorithm-opts>,
                          alg.return = <your-algorithm-embedding-matrix>,
                          classifier=classifier, classifier.opts=classifier.opts,
                          classifier.return=classifier.return, k=<k>)

You should now be able to use your own classification technique, or any external classification technique that implements a method for the S3 generic stats::predict, with the lolR package.