A Simple functorch Example

Author

Rohan Sikand

Published

November 5, 2022



In recent years, there has been a small movement of people trying to go from stateful (Python OOP, class-based modules) to stateless (pure functions) neural network code. The standard PyTorch nn.Module is indeed an OOP-based class, but more recent libraries such as JAX introduce the ability to feasibly create stateless (just functions!) machine learning models. Now, in PyTorch version 1.13, we have functorch in-tree (in the main package). Why stateless? Read this blog post for the differences, but some of the reasons I like stateless code are:

This might sound like an advertisement for JAX (which might be coming up in a future blog post!), but it is really meant to set the stage for functorch. functorch is a library that lets you accomplish nearly all of the above, but in PyTorch! The basic idea is to purify stateful PyTorch modules into stateless functions like this (source):

import torch
from functorch import make_functional

model = torch.nn.Linear(3, 3)                 # a standard, stateful nn.Module
func_model, params = make_functional(model)   # stateless forward function + tuple of parameters
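
Calling the purified model is then a pure function of the parameters and the input. As a minimal sketch (the input shape here is my own illustration, not part of the original snippet):

x = torch.randn(4, 3)         # a batch of 4 inputs, illustrative only
out = func_model(params, x)   # stateless forward pass: params are passed in explicitly
print(out.shape)              # torch.Size([4, 3])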

As functorch is relatively new, there aren't many examples out there showing how to use the library. So the goal of the rest of this post is to provide a simple example: an image classifier built with functorch and PyTorch, with the weights updated by hand-rolled SGD (no torch.optim!).

Here is the code:

import torch
import torchplate
from torchplate import utils
import functorch
from functorch import grad_and_value
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import cloudpickle as cp
from urllib.request import urlopen



class Net(nn.Module):
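    """Simple MLP classifier for flattened 3x32x32 images with 3 output classes."""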
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3*32*32, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 16)
        self.fc4 = nn.Linear(16, 3)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x



class OptExp:
    def __init__(self): 
        self.model_module = Net()
        self.criterion = nn.CrossEntropyLoss()
        dataset = cp.load(urlopen("https://stanford.edu/~rsikand/assets/datasets/mini_cifar.pkl"))  # load the mini CIFAR dataset
        self.trainloader, self.testloader = torchplate.utils.get_xy_loaders(dataset)
        self.model, self.params = functorch.make_functional(self.model_module)  # purify the module: stateless forward fn + tuple of params
    

    def predict(self, x):
        """returns logits"""
        assert self.model is not None
        assert self.params is not None
        logits = self.model(self.params, x)
        return logits 


    @staticmethod
    def sgd_step(params, gradients, lr):
        """one gradient step for updating the weights"""
        updated_params = []
        for param, gradient in zip(params, gradients):
            update = param - (lr * gradient)
            updated_params.append(update)
        
        return tuple(updated_params)
    

    @staticmethod
    def stateless_loss(params, model, criterion, batch):
        """
        Need to perform forward pass and loss calculation in one function
        since we need gradients w.r.t params (must be args[0]). The first
        value we return also needs to be the scalar loss value.  
        """
        x, y = batch
        logits = model(params, x)
        loss_val = criterion(logits, y)
        return loss_val, logits
    


    @staticmethod
    def train_step(params, model, criterion, batch, lr):
        """Combine this all into one function for modularity"""
        # grad_and_value differentiates w.r.t. args[0] (params) and also returns the loss value;
        # has_aux=True means stateless_loss can return extra outputs (the logits) alongside the scalar loss
        grad_and_loss_fn = grad_and_value(OptExp.stateless_loss, has_aux=True)
        grads, aux_outputs = grad_and_loss_fn(params, model, criterion, batch)  # gradients w.r.t. params
        loss_val, logits = aux_outputs
        params = OptExp.sgd_step(params, grads, lr) 
        return params, loss_val, logits

    
    def train(self, num_epochs=10, lr=0.01):
        print('Beginning training!')
        epoch_num = 0
        for epoch in range(num_epochs):
            running_loss = 0.0
            epoch_num += 1
            tqdm_loader = tqdm(self.trainloader)
            for batch in tqdm_loader:
                tqdm_loader.set_description(f"Epoch {epoch_num}")

                # update params with one step 
                self.params, loss_val, logits = OptExp.train_step(self.params, self.model, self.criterion, batch, lr)

                running_loss += loss_val.item()

            # print average loss for the epoch
            epoch_avg_loss = running_loss / len(self.trainloader)
            print(f"Training Loss (epoch {epoch_num}): {epoch_avg_loss}")


        print('Finished training!')


exp = OptExp()
exp.train(num_epochs=50, lr=0.01)
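
The test loader created in __init__ isn't exercised above. As a rough sketch (my own addition, not from the original post), evaluation could reuse predict like this:

def evaluate(exp):
    # accuracy over the held-out loader (sketch; assumes batches of (x, y) classification pairs)
    correct, total = 0, 0
    for x, y in exp.testloader:
        preds = exp.predict(x).argmax(dim=-1)
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total

print("test accuracy:", evaluate(exp))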