File: pytorch_load_save.py

package info (click to toggle)
bandit 1.7.10-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 5,796 kB
  • sloc: python: 19,688; makefile: 23; sh: 14
file content (21 lines) | stat: -rw-r--r-- 655 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torchvision.models as models

# --- Step 1: persist the parameters of an ImageNet-pretrained ResNet-18 ---
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=...`; kept as-is so this example stays unchanged.
model = models.resnet18(pretrained=True)
state = model.state_dict()
torch.save(state, 'model_weights.pth')

# --- Step 2: rebuild an untrained ResNet-18 and restore weights from disk ---
# This is the insecure pattern the example exists to demonstrate: torch.load
# is called without `weights_only=True`.  On PyTorch versions where that flag
# defaults to False, torch.load performs full pickle deserialization, which
# can execute arbitrary code if the file is attacker-controlled.
# Intentionally left unchanged.
loaded_model = models.resnet18()
state = torch.load('model_weights.pth')
loaded_model.load_state_dict(state)

# Write the restored parameters back out (overwrites the same file).
state = loaded_model.state_dict()
torch.save(state, 'model_weights.pth')

# --- Step 3: same insecure load, with tensors remapped onto the CPU ---
# `map_location='cpu'` relocates any stored tensors to CPU memory.  It does
# not make the unpickling any safer.
another_model = models.resnet18()
state = torch.load('model_weights.pth', map_location='cpu')
another_model.load_state_dict(state)

# Write the parameters back out once more (again overwriting the file).
state = another_model.state_dict()
torch.save(state, 'model_weights.pth')