snippets

🔌 Toolbox of short, reusable pieces of code and knowledge.

View the Project on GitHub rosikand/snippets

Save (and load) PyTorch model

Save a trained model (defined as net) weights:

torch.save(net.state_dict(), "intended_path.pth")  # use .pth extension 

Load it back in:

Instantiate another instance of the model’s class and then load back in. What you are really loading in is the model’s weights.

net = ModelClass()
net.load_state_dict(torch.load("saved_model_path.pth"))

Note that the above saves the weights (separate from the model class). If you’d like to save the entire model class (with the weights encapsulated), PyTorch can do this too (but it is not recommended). See here for more.