---
title: "Assessment Metrics for Use in Training Loop"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{metricsDemo}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo=TRUE, comment="", collapse=TRUE, warning=FALSE, message=FALSE, fig.cap="")
```

```r
library(geodl)
library(terra)
library(yardstick)
```

## Terminology

During the training process, you may choose to monitor assessment metrics alongside the loss function. These metrics can be calculated separately for the training and validation data so that the two can be compared to check for overfitting.

The **geodl** package provides several classification assessment metrics that have been implemented using the *luz_metric()* function from the **luz** package. This function defines metrics that can be monitored, calculated, and aggregated over multiple mini-batches during the training process. The following metrics are provided:

* *luz_metric_overall_accuracy()*: overall accuracy
* *luz_metric_recall()*: class-aggregated, macro-averaged recall (producer's accuracy)
* *luz_metric_precision()*: class-aggregated, macro-averaged precision (user's accuracy)
* *luz_metric_f1score()*: class-aggregated, macro-averaged F1-score (harmonic mean of precision and recall)

Note that macro-averaging is used since micro-averaged recall, precision, and F1-score are all equivalent to each other and to overall accuracy. So, there is no reason to calculate these metrics alongside overall accuracy.

The *mode* argument can be set to either "binary" or "multiclass". If "binary", only the logit for the positive class prediction should be provided. If both the positive and negative (background) class probabilities are provided for a binary classification, the "multiclass" *mode* should be used. Note that **geodl** is designed to treat all predictions as multiclass.
The "binary" *mode* is only provided for use outside of the standard **geodl** workflow. For a binary classification as implemented with **geodl**, if you want to calculate precision, recall, and/or F1-score using only the positive case, you can set the weight for the background class to 0 using the *clsWghts* parameter.

The *biThresh* parameter sets the probability threshold used to differentiate positive and negative class predictions. The default is 0.5: any pixel whose positive case probability, as approximated by applying a sigmoid function to the positive case logit, is above the threshold will be coded to the positive case. All other pixels will be coded to the negative case. This parameter is not used for multiclass problems.

The *smooth* parameter is used to avoid divide-by-zero errors and to improve numeric stability.

The following two examples demonstrate the metrics. These metrics are also used within the example classification workflows provided in other articles.

## Example 1: Binary Classification

This first example demonstrates calculating the assessment metrics for a binary classification where the background class is coded as 0 and the positive case is coded as 1. First, the raster grids are read in as SpatRaster objects using the *rast()* function from **terra**, and the bands are renamed to aid interpretability. The refC object represents the reference class numeric codes, the predL object represents the predicted logits (before the data are rescaled with a sigmoid function), and the predC object represents the predicted class indices.

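
To make the relationship between predicted logits, the sigmoid function, and a probability threshold such as *biThresh* more concrete, here is a minimal base R sketch. The logit values are made up for illustration and are not taken from the example rasters, and the strict "greater than" comparison is an assumption rather than a statement about how **geodl** applies the threshold internally.

```r
# Hypothetical positive-class logits for six pixels (made-up values)
logits <- c(-3.2, -0.4, 0, 0.1, 1.7, 4.5)

# The sigmoid function rescales logits to approximate positive-class probabilities
probs <- 1 / (1 + exp(-logits))

# Assumed rule: a probability strictly above the threshold is coded to the
# positive class (1); everything else is coded to the background class (0)
biThresh <- 0.5
predClass <- as.integer(probs > biThresh)

print(round(probs, 3))
print(predClass)
```

Since a logit of 0 maps to a probability of exactly 0.5, a threshold of 0.5 on the probability scale corresponds to a threshold of 0 on the logit scale.
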
```r
refC <- terra::rast("C:/myFiles/data/metricCheck/binary_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/binary_logits.tif")
predC <- terra::rast("C:/myFiles/data/metricCheck/binary_prediction.tif")
names(refC) <- "reference"
names(predC) <- "prediction"
```

For comparison, we first create a confusion matrix by cross-tabulating the grids using the *crosstab()* function from **terra**, then calculate assessment metrics using the **yardstick** package from **tidymodels**.

```r
cm <- terra::crosstab(c(predC, refC))
```

```r
yardstick::f_meas(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric .estimator .estimate
1 f_meas  binary         0.951

yardstick::accuracy(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric  .estimator .estimate
1 accuracy binary         0.994

yardstick::recall(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric .estimator .estimate
1 recall  binary         0.973

yardstick::precision(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric   .estimator .estimate
1 precision binary         0.929
```

Before the assessment metrics can be calculated, the SpatRaster objects need to be converted to **torch** tensors with the correct dimensionality and data types. This is accomplished as follows:

1. SpatRaster objects are converted to R arrays.
2. The R arrays are converted to **torch** tensors with the target codes represented using a long integer data type and the logits represented with a 32-bit float data type.
3. The tensors are permuted so that the channel dimension is first, as is the standard within **torch**.
4. A mini-batch dimension is added at the first position, as is required for use inside of training loops.
5. Two copies of each tensor are concatenated to simulate a mini-batch with two samples.

```r
# 1. Convert the SpatRaster objects to R arrays (rows, columns, bands)
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)

# 2. Convert to torch tensors: long integers for class codes, 32-bit floats for logits
target <- torch::torch_tensor(refC, dtype=torch::torch_long())
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32())

# 3. Move the channel dimension to the first position
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))

# 4. Add a mini-batch dimension at the first position
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)

# 5. Concatenate two copies along the batch dimension to simulate a mini-batch of two samples
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
```

In the following blocks of code I calculate overall accuracy, F1-score, recall, and precision. This requires (1) instantiating the metric, (2) updating the metric using the predicted logits and class codes, and (3) computing the metric.

```r
metric <- luz_metric_overall_accuracy(nCls=1, smooth=1e-8, mode="binary", biThresh=0.5, zeroStart=TRUE, usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.9945027
```

```r
metric <- luz_metric_f1score(nCls=1, smooth=1e-8, mode="binary", biThresh=0.5, zeroStart=TRUE, usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.9521416
```

```r
metric <- luz_metric_recall(nCls=1, smooth=1e-8, mode="binary", biThresh=0.5, zeroStart=TRUE, usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.9782254
```

```r
metric <- luz_metric_precision(nCls=1, smooth=1e-8, mode="binary", biThresh=0.5, zeroStart=TRUE, usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.9274127
```

## Example 2: Multiclass Classification

A multiclass classification is simulated in this example. The process is very similar to the binary example other than how the metrics are configured. Now, the "multiclass" *mode* is used. The *zeroStart* parameter is set to TRUE, which indicates that the class codes start at 0 as opposed to 1. The *zeroStart* parameter is used within many of **geodl**'s functions.
Due to how one-hot encoding is implemented by the **torch** package, indexing starting at 0 can cause issues. By indicating that indexing starts at 0, the workflow will augment the codes so that an error is not generated by one-hot encoding.

```r
refC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/multiclass_logits.tif")
predC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_prediction.tif")
names(refC) <- "reference"
names(predC) <- "prediction"
```

```r
cm <- terra::crosstab(c(predC, refC))
```

```r
yardstick::f_meas(cm, estimator="macro")
# A tibble: 1 × 3
  .metric .estimator .estimate
1 f_meas  macro          0.822

yardstick::accuracy(cm, estimator="micro")
# A tibble: 1 × 3
  .metric  .estimator .estimate
1 accuracy multiclass     0.913

yardstick::recall(cm, estimator="macro")
# A tibble: 1 × 3
  .metric .estimator .estimate
1 recall  macro          0.827

yardstick::precision(cm, estimator="macro")
# A tibble: 1 × 3
  .metric   .estimator .estimate
1 precision macro          0.821
```

```r
# Same tensor preparation as in Example 1: arrays, tensors, channel-first, batch of two
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)
target <- torch::torch_tensor(refC, dtype=torch::torch_long())
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32())
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
```

```r
metric <- luz_metric_overall_accuracy(nCls=5, smooth=1e-8, mode="multiclass", zeroStart=TRUE, usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.9134674
```

```r
metric <- luz_metric_f1score(nCls=5, smooth=1e-8, mode="multiclass", zeroStart=TRUE, clsWghts=c(1,1,1,1,1), usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.8237987
```

```r
metric <- luz_metric_recall(nCls=5, smooth=1e-8, mode="multiclass", zeroStart=TRUE, clsWghts=c(1,1,1,1,1), usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.8265532
```

```r
metric <- luz_metric_precision(nCls=5, smooth=1e-8, mode="multiclass", zeroStart=TRUE, clsWghts=c(1,1,1,1,1), usedDS=FALSE)
```

```r
metric <- metric$new()
metric$update(pred, target)
metric$compute()
[1] 0.8210625
```

Again, you will see more demonstrations of these assessment metrics within the example training loops.
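
As a closing aside, the point made under Terminology about macro- versus micro-averaging can be checked by hand from a confusion matrix. The base R sketch below uses a hypothetical three-class matrix with made-up counts (not the example data) and the usual textbook definitions of the metrics. Whether every detail matches the internal computations in **geodl** or **yardstick** is an assumption.

```r
# Hypothetical 3-class confusion matrix (made-up counts):
# rows = predicted class, columns = reference class
cm <- matrix(c(50,  2,  3,
                5, 30,  5,
                5,  8, 12),
             nrow = 3, byrow = TRUE,
             dimnames = list(prediction = c("a", "b", "c"),
                             reference  = c("a", "b", "c")))

tp <- diag(cm)            # correct predictions per class
fp <- rowSums(cm) - tp    # predicted as the class, but the reference differs
fn <- colSums(cm) - tp    # reference is the class, but predicted otherwise

# Per-class metrics
precision <- tp / (tp + fp)
recall    <- tp / (tp + fn)
f1        <- 2 * precision * recall / (precision + recall)

# Macro-averaging: unweighted mean of the per-class values
macro <- c(precision = mean(precision), recall = mean(recall), f1 = mean(f1))

# Micro-averaging: pool the counts across classes before dividing
microPrecision <- sum(tp) / sum(tp + fp)
microRecall    <- sum(tp) / sum(tp + fn)
microF1        <- 2 * microPrecision * microRecall / (microPrecision + microRecall)

overallAccuracy <- sum(tp) / sum(cm)

print(round(macro, 4))
print(c(microPrecision, microRecall, microF1, overallAccuracy))
```

The three micro-averaged values and overall accuracy coincide because every false positive for one class is also a false negative for another, so the pooled denominators both equal the total number of pixels. The macro-averaged values differ from overall accuracy, which is why they add information when monitored alongside it.
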