Custom loops with luz

library(torch)
library(luz)

Luz is a higher-level API for torch designed to be highly flexible: it provides a layered API that stays useful no matter how much control you need over your training loop.

In the getting started vignette we saw the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we describe how luz gives you fine-grained control over the training loop.

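Apart from callbacks, there are three more ways to use luz, depending on how much control you need:

- Multiple optimizers or losses: you optimize several losses, each with its own optimizer, but you still let luz call zero_grad(), backward() and step() for you. This is common in models like GANs, where competing networks are trained with different losses and optimizers.
- Fully flexible steps: you control how zero_grad(), backward() and step() are called for each batch, for example to change how gradients are computed.
- Completely flexible loops: your training loop can be anything you want, while luz still handles device placement of the dataloaders, optimizers and models (see vignette("accelerator")).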

Let’s consider a simplified version of the net that we implemented in the getting started vignette:

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  }
)

Using the highest-level luz API, we would fit it with:

fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
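
The fit() calls in this article use train_dl and test_dl from the getting started vignette. To try the examples on their own, here is a minimal sketch that builds stand-in dataloaders from synthetic data. The sizes (800 training and 200 test rows), the batch size of 32 and the helper make_dl() are assumptions chosen only to match the 100 input features and 10 classes implied by nn_linear(100, 50) and nn_linear(50, 10).

```r
library(torch)

torch_manual_seed(1)

# Hypothetical helper: random features and labels wrapped in a dataloader.
# The data is synthetic and only stands in for the getting started dataset.
make_dl <- function(n, shuffle) {
  x <- torch_randn(n, 100)
  # R torch uses 1-based class indices, so labels are drawn from 1..10
  y <- torch_randint(1, 11, size = n, dtype = torch_long())
  dataloader(tensor_dataset(x, y), batch_size = 32, shuffle = shuffle)
}

train_dl <- make_dl(800, shuffle = TRUE)
test_dl <- make_dl(200, shuffle = FALSE)
```

Each batch is a list whose first element is the input and whose second element is the target, which is the layout luz expects by default.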

Multiple optimizers

Suppose we want to do an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss() for both, but for the first layer we want to add L1 regularization on the weights.
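
Written out, the two objectives look like this. CE is the cross-entropy, f is the network, W_1 is the weight matrix of fc1, and θ_fc1, θ_fc2 are the parameters (weights and biases) of each layer. The L1 norm is taken over all entries of W_1, which is what torch_norm() with p = 1 and no dim argument computes. The Adam learning rates come from the set_optimizers() defaults in the code that follows.

```latex
\begin{aligned}
\mathcal{L}_{\text{fc1}} &= \mathrm{CE}\big(f(x), y\big) + \lVert W_1 \rVert_1,
  \qquad \lVert W_1 \rVert_1 = \sum_{i,j} \lvert (W_1)_{ij} \rvert \\
\mathcal{L}_{\text{fc2}} &= \mathrm{CE}\big(f(x), y\big) \\
\theta_{\text{fc1}} &\leftarrow \mathrm{Adam}_{\eta = 0.1}\big(\nabla_{\theta_{\text{fc1}}} \mathcal{L}_{\text{fc1}}\big),
  \qquad
\theta_{\text{fc2}} \leftarrow \mathrm{Adam}_{\eta = 0.01}\big(\nabla_{\theta_{\text{fc2}}} \mathcal{L}_{\text{fc2}}\big)
\end{aligned}
```

The penalty depends only on W_1, so both losses have the same gradient with respect to θ_fc2. Only the first layer is affected by the regularization.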

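To do this with luz, we implement two extra methods in the net module:

- set_optimizers(): returns a named list of optimizers, one for each group of parameters we want to train differently.
- loss(): computes the loss for the optimizer currently in use, which luz exposes as ctx$opt_name.

Let's look at the code: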

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  loss = function(input, target) {
    pred <- ctx$model(input)
  
    if (ctx$opt_name == "opt_fc1") 
      nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1)
    else if (ctx$opt_name == "opt_fc2")
      nnf_cross_entropy(pred, target)
  }
)

luz initializes the model's optimizers from the return value of the set_optimizers() method: a named list whose names identify each optimizer. Here, each optimizer gets a different set of model parameters and a different learning rate.

The loss() method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss() method can access the ctx object that will contain an opt_name field, describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. See help("ctx") for complete information about the context object.
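
To make "once for each optimizer" concrete, here is a plain-torch approximation of a single training batch. It is not luz's actual implementation. It loops over the named optimizers, applies the same logic as the loss() method for each name, and updates only that optimizer's parameters. loss_for() is a hypothetical helper that takes the optimizer name as an argument instead of reading ctx$opt_name. The random batch is also an assumption.

```r
library(torch)

torch_manual_seed(1)

# Same architecture as `net`, defined without luz so the sketch runs alone
two_layer <- nn_module(
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    self$fc2(nnf_relu(self$fc1(x)))
  }
)

model <- two_layer()
optimizers <- list(
  opt_fc1 = optim_adam(model$fc1$parameters, lr = 0.1),
  opt_fc2 = optim_adam(model$fc2$parameters, lr = 0.01)
)

# Hypothetical helper mirroring loss(): the optimizer name is an argument
loss_for <- function(opt_name, pred, target) {
  loss <- nnf_cross_entropy(pred, target)
  if (opt_name == "opt_fc1")
    loss <- loss + torch_norm(model$fc1$weight, p = 1)  # L1 penalty on fc1
  loss
}

# One synthetic batch: 32 rows, 100 features, 1-based labels in 1..10
input <- torch_randn(32, 100)
target <- torch_randint(1, 11, size = 32, dtype = torch_long())

losses <- list()
for (opt_name in names(optimizers)) {
  opt <- optimizers[[opt_name]]
  opt$zero_grad()                 # clears gradients of this optimizer's params only
  pred <- model(input)            # fresh forward pass for each optimizer
  loss <- loss_for(opt_name, pred, target)
  loss$backward()
  opt$step()                      # updates only this optimizer's layer
  losses[[opt_name]] <- loss$item()
}
```

Each optimizer holds only its own layer's parameters. As a result, opt_fc1$zero_grad() clears only fc1's gradients. The gradients that the first backward() leaves on fc2 are cleared by opt_fc2$zero_grad() before fc2's own backward pass.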

We can now set up and fit this module. Note that we no longer need to pass an optimizer or a loss function to setup().

fitted <- net %>% 
  setup(metrics = list(luz_metric_accuracy)) %>% 
  fit(train_dl, epochs = 10, valid_data = test_dl)

Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.

Fully flexible step

Instead of implementing the loss() method, we can implement the step() method. This lets us freely modify what happens for each batch during both training and validation. You are now responsible for zeroing the gradients, back-propagating the loss and stepping the optimizers.

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {
    
      pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(pred, ctx$target)
      
      if (opt_name == "opt_fc1") {
        # L1 regularization on the weights of the first layer
        loss <- loss + torch_norm(self$fc1$weight, p = 1)
      }
        
      if (ctx$training) {
        opt$zero_grad()
        loss$backward()
        opt$step()  
      }
      
      ctx$loss[[opt_name]] <- loss$detach()
    }
  }
)

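The important things to notice here are:

- step() is called for both training and validation, so the weights must only be updated when ctx$training is TRUE. See help("ctx") for complete information about the context object.
- ctx$optimizers is the named list of optimizers created by the set_optimizers() method.
- You have to record the losses yourself by storing them in the named list ctx$loss. By convention, each loss is stored under the name of the optimizer it belongs to.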
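
To run the step()-based model end to end, here is a self-contained sketch. It repeats the module so the block runs on its own and fits it on synthetic data. The data sizes, epochs = 2 and verbose = FALSE are assumptions chosen to keep the run short and quiet. As with the loss() version, setup() receives no loss or optimizer. Metrics are left out here because this step() records only ctx$loss.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Synthetic stand-in for train_dl/test_dl: 100 features, labels in 1..10
x <- torch_randn(256, 100)
y <- torch_randint(1, 11, size = 256, dtype = torch_long())
train_dl <- dataloader(tensor_dataset(x[1:192, ], y[1:192]), batch_size = 32, shuffle = TRUE)
test_dl <- dataloader(tensor_dataset(x[193:256, ], y[193:256]), batch_size = 32)

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    self$fc2(nnf_relu(self$fc1(x)))
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {
      pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(pred, ctx$target)
      if (opt_name == "opt_fc1")
        loss <- loss + torch_norm(self$fc1$weight, p = 1)  # L1 on fc1
      if (ctx$training) {       # only update weights while training
        opt$zero_grad()
        loss$backward()
        opt$step()
      }
      ctx$loss[[opt_name]] <- loss$detach()  # one entry per optimizer
    }
  }
)

# step() and set_optimizers() replace the loss and optimizer arguments
fitted <- net %>%
  setup() %>%
  fit(train_dl, epochs = 2, valid_data = test_dl, verbose = FALSE)
```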

Next steps

In this article you learned how to customize the training step with luz's layered API, either by implementing loss() and set_optimizers() or by implementing step() directly.

For even more flexibility, luz also lets you write a completely custom training loop, as described in the Accelerator vignette (vignette("accelerator")).

You should now be able to follow the examples in the 'intermediate' and 'advanced' categories of the examples gallery.