"""Build a traced, mobile-optimized MobileNetV2 TorchScript module with a bundled example input."""
import torch
import torchvision
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
class MobileNetV2Module:
    """Factory for a traced, mobile-optimized MobileNetV2 module.

    The class keeps no state of its own; all work happens in ``getModule``.
    """

    def __init__(self):
        super().__init__()

    def getModule(self):
        """Build and return the mobile-optimized MobileNetV2 module.

        Steps, in order:
          1. Load ``torchvision.models.mobilenet_v2`` with ``pretrained=True``
             and switch it to eval mode.
          2. Trace it with ``torch.jit.trace`` using an all-zeros example
             tensor of shape ``(1, 3, 224, 224)``.
          3. Pass the traced module through ``optimize_for_mobile``.
          4. Attach the example tensor as the single bundled input.
          5. Run one forward pass on the optimized module.

        Returns:
            The object returned by ``optimize_for_mobile``, after
            ``augment_model_with_bundled_inputs`` has been called on it.
        """
        # NOTE(review): newer torchvision releases reportedly deprecate
        # `pretrained=` in favour of `weights=` -- kept as-is so the call
        # works with the torchvision version this file was written for.
        model = torchvision.models.mobilenet_v2(pretrained=True)
        # Eval mode must be set before tracing so the trace records
        # inference-time behaviour of the layers.
        model.eval()

        example = torch.zeros(1, 3, 224, 224)
        traced_script_module = torch.jit.trace(model, example)
        optimized_module = optimize_for_mobile(traced_script_module)

        # Called for its effect on `optimized_module`; the return value is
        # not used. One bundled input, given as a 1-tuple of forward() args.
        augment_model_with_bundled_inputs(
            optimized_module,
            [
                (example,),
            ],
        )

        # Result is discarded: presumably a smoke test so that a broken
        # optimized module fails here rather than later -- confirm intent.
        optimized_module(example)
        return optimized_module
|