snippets

🔌 Toolbox of short, reusable pieces of code and knowledge.

View the Project on GitHub rosikand/snippets

PyTorch training loop (and plot)

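A simple PyTorch training loop, adapted from the PyTorch docs.

The model is assumed to be defined as net, and trainloader, criterion, and optimizer are assumed to be set up already.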
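To try the loop out, a minimal setup of net, trainloader, criterion, and optimizer could look like the sketch below. Everything in it is a hypothetical placeholder, not part of the original snippet: the random toy dataset, the two-layer model, the batch size, and the SGD settings. Swap in your own model and data.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # make the random toy data reproducible

# Hypothetical toy dataset: 256 samples with 10 features each, 3 classes.
features = torch.randn(256, 10)
targets = torch.randint(0, 3, (256,))
trainloader = DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=True)

# Hypothetical small classifier standing in for your real model.
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
```

Each batch from this loader unpacks as [inputs, labels], which is the shape the loop expects.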

EPOCHS = 50  # number of passes over the training set
loss_vals = []  # training loss recorded once per epoch, for plotting

net.train()  # put layers such as dropout/batchnorm into training mode

for epoch in range(EPOCHS):  # loop over the dataset EPOCHS times
    running_loss = 0.0

    # each batch is a list of [inputs, labels]; add .to(device) to both if training on GPU
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate the batch loss as a plain Python float (detached from the graph)
        running_loss += loss.item()

    # running_loss is the sum of the per-batch losses for this epoch
    print(f"Loss for epoch {epoch + 1}: {running_loss}")
    loss_vals.append(running_loss)

print('Finished Training')
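The same loop can be packaged as a reusable function. This is a sketch, not the original author's code: the name train, its parameters, and the choice to record the mean per-batch loss (instead of the per-epoch sum, so values stay comparable if the number of batches changes) are assumptions. It also moves each batch to device, which is what the .to(device) comments in the loop refer to.

```python
from __future__ import annotations

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(net: nn.Module, trainloader: DataLoader, criterion: nn.Module,
          optimizer: torch.optim.Optimizer, epochs: int,
          device: torch.device | str = "cpu") -> list[float]:
    """Train net for `epochs` passes; return the mean batch loss of each epoch."""
    net.to(device)  # moves parameters in place, so the optimizer still tracks them
    net.train()     # training-mode behaviour for dropout/batchnorm
    loss_vals = []

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0

        for inputs, labels in trainloader:
            # move the batch to the same device as the model
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()              # clear gradients from the last step
            outputs = net(inputs)              # forward pass
            loss = criterion(outputs, labels)
            loss.backward()                    # backpropagate
            optimizer.step()                   # update parameters

            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / max(num_batches, 1)  # mean loss per batch
        print(f"Loss for epoch {epoch + 1}: {epoch_loss:.4f}")
        loss_vals.append(epoch_loss)

    print("Finished Training")
    return loss_vals
```

Usage, with the names from the loop: loss_vals = train(net, trainloader, criterion, optimizer, epochs=50, device="cuda" if torch.cuda.is_available() else "cpu").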

Plotting loss values

We can plot the loss recorded at each epoch from the loss_vals list like so:

# graph training loss
import matplotlib.pyplot as plt

plt.plot(range(1, len(loss_vals) + 1), loss_vals)  # x-axis: epoch number, starting at 1
plt.title('Training Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
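On a machine without a display (a remote training server, for instance), plt.show() has nothing to draw on. One option, sketched here as an assumption rather than part of the original snippet, is to select matplotlib's non-interactive Agg backend and write the figure to a file; plot_losses and the default file name training_loss.png are hypothetical.

```python
from __future__ import annotations

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files, needs no display
import matplotlib.pyplot as plt


def plot_losses(loss_vals: list[float], path: str = "training_loss.png") -> str:
    """Plot one loss value per epoch and save the figure to `path`."""
    epochs = range(1, len(loss_vals) + 1)  # label epochs starting at 1
    fig, ax = plt.subplots()
    ax.plot(epochs, loss_vals, marker="o")
    ax.set_title("Training Loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    fig.savefig(path)
    plt.close(fig)  # release the figure so repeated calls don't accumulate memory
    return path
```

Calling plot_losses(loss_vals) after training writes the plot next to your script.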

It will plot something that looks like this:

[Figure: example Training Loss plot]