File: script_model.py

package info
pytorch-vision 0.24.1-1
  • area: main
  • in suites: experimental
  • size: 21,844 kB
  • sloc: python: 70,433; cpp: 11,502; ansic: 2,588; java: 550; sh: 317; xml: 79; objc: 56; makefile: 33
file content (10 lines) | stat: -rw-r--r-- 326 bytes
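import torch
from torchvision import models

# Build each model with randomly initialized weights (no pretrained download),
# compile it with TorchScript, and save the serialized module as "<name>.pt".
for model, name in (
    (models.resnet18(weights=None), "resnet18"),
    (models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None), "fasterrcnn_resnet50_fpn"),
):
    # Eval mode: fixes dropout/batch-norm behavior; detection models return predictions, not losses.
    model.eval()
    # torch.jit.script compiles the module's Python code (control flow included);
    # unlike torch.jit.trace, it needs no example inputs.
    scripted_model = torch.jit.script(model)
    scripted_model.save(f"{name}.pt")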
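A save/load round trip is one way to sanity-check a scripted model before relying on the saved file. The sketch below is not part of the upstream script, and the helper name `script_roundtrip_matches` is hypothetical. It scripts a module with `torch.jit.script`, serializes it to an in-memory buffer with `torch.jit.save`, reloads it with `torch.jit.load`, and compares the reloaded output with eager mode. It assumes the model's forward takes one tensor and returns one tensor. That holds for resnet18 but not for fasterrcnn_resnet50_fpn, whose scripted forward takes a list of image tensors and returns a (losses, detections) tuple.

```python
import io

import torch


def script_roundtrip_matches(model: torch.nn.Module, example: torch.Tensor,
                             atol: float = 1e-5) -> bool:
    """Hypothetical helper: script `model`, save it to an in-memory buffer,
    reload it, and check the reloaded module reproduces the eager output.
    Assumes forward(Tensor) -> Tensor."""
    model.eval()
    scripted = torch.jit.script(model)

    # Serialize to a BytesIO instead of a file path, then rewind for reading.
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)

    # Compare eager and reloaded outputs without tracking gradients.
    with torch.no_grad():
        return torch.allclose(model(example), reloaded(example), atol=atol)


if __name__ == "__main__":
    from torchvision import models

    torch.manual_seed(0)
    net = models.resnet18(weights=None)
    print(script_roundtrip_matches(net, torch.randn(1, 3, 224, 224)))
```

Buffering in memory keeps the check free of filesystem side effects. To verify the files the script above actually writes, pass a path such as "resnet18.pt" to `torch.jit.load` instead.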