---
title: "Unified Focal Loss Framework"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{unifiedFocalLossDemo}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo=TRUE, comment="#>", collapse=TRUE, warning=FALSE, message=FALSE, fig.cap="")
```

```r
library(geodl)
library(torch)
```

```r
# Assumes a CUDA-capable GPU; use torch_device("cpu") otherwise
device <- torch_device("cuda")
```

## Background

**geodl** provides a modified version of the **unified focal loss** proposed in the following study:

Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified focal loss: Generalising dice and cross entropy-based losses to handle class imbalanced medical image segmentation. *Computerized Medical Imaging and Graphics*, 95, p.102026.

The table below describes the loss parameterization. Our implementation differs from the originally proposed one because we do not implement symmetric and asymmetric forms. Instead, we allow the user to define class weights for both the distribution- and region-based loss components.

![Modified unified focal loss parameterization](figRend/unifiedFocalLossTable.PNG){width=80%}

The *lambda* parameter controls the relative weighting between the distribution- and region-based losses. When *lambda* = 1, only the distribution-based loss is used. When *lambda* = 0, only the region-based loss is used. Values between 0 and 1 yield a loss metric that incorporates both components. A *lambda* of 0.5 yields equal weighting, values larger than 0.5 put more weight on the distribution-based loss, and values lower than 0.5 put more weight on the region-based loss.

The *gamma* parameter controls the weight applied to difficult-to-predict samples, defined as samples or classes with low predicted rescaled logits for their correct class. For the distribution-based loss, focal corrections are implemented sample-by-sample.
For the region-based loss, corrections are implemented class-by-class during macro-averaging. *gamma* must be larger than 0 and less than or equal to 1. When *gamma* = 1, no focal correction is applied. Lower values result in a larger focal correction.

The *delta* parameter controls the relative weights of false negative (FN) and false positive (FP) errors and should be between 0 and 1. A *delta* of 0.5 places equal weight on FN and FP errors. Values larger than 0.5 place more weight on FN errors than on FP errors, while values smaller than 0.5 place more weight on FP errors than on FN errors.

The *clsWghtsDist* parameter controls the relative weights of classes in the distribution-based loss and is applied sample-by-sample. The *clsWghtsReg* parameter controls the relative weights of classes in the region-based loss and is applied to each class when calculating the macro average. By default, all classes are weighted equally. To use class weights, you must provide a vector of weights equal in length to the number of classes being differentiated.

Lastly, the *useLogCosH* parameter determines whether a log cosh transformation is applied to the region-based loss. If it is set to TRUE, the transformation is applied.

Using different parameterizations, users can define a variety of loss metrics, including cross entropy (CE) loss, weighted CE loss, focal CE loss, focal weighted CE loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.

## Examples

We will now demonstrate how different loss metrics can be obtained using different parameterizations. We first load example data using the *rast()* function from **terra**. The data represent class reference numeric codes (refC) and predicted class logits (predL).


```r
refC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/multiclass_logits.tif")
```

The SpatRaster objects are then converted to **torch** tensors with the correct shape and data type. We simulate a mini-batch of two samples by concatenating two copies of the tensors.

```r
# Convert the rasters to arrays shaped [rows, cols, layers]
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)

target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)

# Move the layer dimension first
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))

# Add a batch dimension
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)

# Duplicate along the batch dimension to simulate two samples
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
```

### Example 1: Dice Loss

The **Dice loss** is obtained by setting the *lambda* parameter to 0, the *gamma* parameter to 1, and the *delta* parameter to 0.5. This results in only the region-based loss being considered, no focal correction being applied, and equal weighting between FN and FP errors.

```r
myDiceLoss <- defineUnifiedFocalLoss(nCls=5,
                                     lambda=0, #Only use region-based loss
                                     gamma=1,
                                     delta=0.5, #Equal weights for FP and FN
                                     smooth=1e-8,
                                     zeroStart=TRUE,
                                     clsWghtsDist=1,
                                     clsWghtsReg=1,
                                     useLogCosH=FALSE,
                                     device=device)

myDiceLoss(pred=pred, target=target)
#> torch_tensor
#> 0.17861
#> [ CUDAFloatType{} ]
```

### Example 2: Tversky Loss

The **Tversky loss** can be obtained using the same settings as those used for the **Dice loss**, except that the *delta* parameter must be set to a value other than 0.5 so that different weights are applied to FN and FP errors. In the example, we use a *delta* of 0.6, which places more weight on FN errors relative to FP errors. Setting *gamma* to a value lower than 1 results in a **focal Tversky loss**.

Note that we regenerate the tensors so that the computational graphs are re-initialized.


```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myTverskyLoss <- defineUnifiedFocalLoss(nCls=5,
                                        lambda=0, #Only use region-based loss
                                        gamma=1,
                                        delta=0.6, #FN weighted higher than FP
                                        smooth=1e-8,
                                        zeroStart=TRUE,
                                        clsWghtsDist=1,
                                        clsWghtsReg=1,
                                        useLogCosH=FALSE,
                                        device=device)

myTverskyLoss(pred=pred, target=target)
#> torch_tensor
#> 0.177986
#> [ CUDAFloatType{} ]
```

### Example 3: Cross Entropy (CE) Loss

The **cross entropy** (**CE**) loss is obtained by setting *lambda* to 1, so that only the distribution-based loss is considered, and setting *gamma* to 1, so that no focal correction is applied. Setting *gamma* to a value lower than 1 results in a **focal CE loss**.

```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myCELoss <- defineUnifiedFocalLoss(nCls=5,
                                   lambda=1, #Only use distribution-based loss
                                   gamma=1,
                                   delta=0.5,
                                   smooth=1e-8,
                                   zeroStart=TRUE,
                                   clsWghtsDist=1,
                                   clsWghtsReg=1,
                                   useLogCosH=FALSE,
                                   device=device)

myCELoss(pred=pred, target=target)
#> torch_tensor
#> 1.58689
#> [ CUDAFloatType{} ]
```

### Example 4: Combo-Loss

A **combo-loss** can be obtained by setting *lambda* to a value between 0 and 1. In the example, we use 0.5, which results in equal weights being applied to the distribution- and region-based losses.
We also apply a focal correction using a *gamma* of 0.8 and weight FN errors higher than FP errors by using a *delta* of 0.6. The result is a combination of the **focal CE** and **focal Tversky** losses.

```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myComboLoss <- defineUnifiedFocalLoss(nCls=5,
                                      lambda=0.5, #Use both distribution- and region-based losses
                                      gamma=0.8, #Apply a focal adjustment
                                      delta=0.6, #Weight FN higher than FP
                                      smooth=1e-8,
                                      zeroStart=TRUE,
                                      clsWghtsDist=1,
                                      clsWghtsReg=1,
                                      useLogCosH=FALSE,
                                      device=device)

myComboLoss(pred=pred, target=target)
#> torch_tensor
#> 0.9143
#> [ CUDAFloatType{1} ]
```
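
The tensor preparation repeated before each example above could be wrapped in a small helper. The sketch below is one possible way to do this. The function name `makeBatch()` and the synthetic arrays are hypothetical and are not part of **geodl**. It assumes the input arrays are shaped [rows, cols, layers], as returned by `terra::as.array()`, and it defaults to the CPU so that it can run without a GPU.

```r
library(torch)

# Hypothetical helper (not part of geodl): converts reference and logit arrays
# shaped [rows, cols, layers] into a mini-batch of two identical samples.
makeBatch <- function(refArr, predArr, device = torch_device("cpu")) {
  target <- torch_tensor(refArr, dtype = torch_long(), device = device)
  pred <- torch_tensor(predArr, dtype = torch_float32(), device = device)

  # Move the layer dimension first: [layers, rows, cols]
  target <- target$permute(c(3, 1, 2))
  pred <- pred$permute(c(3, 1, 2))

  # Add a batch dimension: [1, layers, rows, cols]
  target <- target$unsqueeze(1)
  pred <- pred$unsqueeze(1)

  # Duplicate along the batch dimension to simulate two samples
  list(
    target = torch_cat(list(target, target), dim = 1),
    pred = torch_cat(list(pred, pred), dim = 1)
  )
}

# Synthetic stand-ins for the raster arrays: 4 x 4 cells, 5 classes
refArr <- array(sample(0:4, 16, replace = TRUE), dim = c(4, 4, 1))
predArr <- array(rnorm(4 * 4 * 5), dim = c(4, 4, 5))

batch <- makeBatch(refArr, predArr)
batch$target$shape # batch, layers, rows, cols
batch$pred$shape
```

With a helper like this, each example could call `batch <- makeBatch(refC, predL, device=device)` and pass `batch$pred` and `batch$target` to the loss function, which still creates fresh tensors for every loss evaluation.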