1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
|
from torchvision import models
import torch
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
class MobileNetV2Module:
    """Factory for a traced, mobile-optimized MobileNetV2 script module."""

    def getModule(self):
        """Return MobileNetV2 traced, optimized for mobile, with bundled inputs.

        The network is loaded with the ``IMAGENET1K_V1`` weights, put into
        eval mode, and traced against an all-zeros ``(1, 3, 224, 224)``
        tensor. The same tensor is attached to the optimized module as its
        single bundled input.
        """
        sample_input = torch.zeros(1, 3, 224, 224)

        # ``eval()`` returns the module itself, so it can be chained directly.
        net = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.IMAGENET1K_V1
        ).eval()

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample_input))

        bundled_inputs = [(sample_input,)]
        augment_model_with_bundled_inputs(mobile_module, bundled_inputs)

        # Run the optimized module once on the sample; the output is discarded
        # (presumably a smoke test that the module is executable).
        mobile_module(sample_input)
        return mobile_module
class MobileNetV2VulkanModule:
    """Factory for a traced MobileNetV2 optimized for the Vulkan backend."""

    def getModule(self):
        """Return MobileNetV2 traced and optimized with ``backend="vulkan"``.

        The network is loaded with the ``IMAGENET1K_V1`` weights, put into
        eval mode, and traced against an all-zeros ``(1, 3, 224, 224)``
        tensor. The same tensor is attached to the optimized module as its
        single bundled input.
        """
        sample_input = torch.zeros(1, 3, 224, 224)

        # ``eval()`` returns the module itself, so it can be chained directly.
        net = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.IMAGENET1K_V1
        ).eval()

        traced = torch.jit.trace(net, sample_input)
        vulkan_module = optimize_for_mobile(traced, backend="vulkan")

        bundled_inputs = [(sample_input,)]
        augment_model_with_bundled_inputs(vulkan_module, bundled_inputs)

        # Run the optimized module once on the sample; the output is discarded.
        # NOTE(review): this call may require a Vulkan-enabled build to
        # succeed -- confirm against the environment that runs this script.
        vulkan_module(sample_input)
        return vulkan_module
class Resnet18Module:
    """Factory for a traced, mobile-optimized ResNet-18 script module."""

    def getModule(self):
        """Return ResNet-18 traced, optimized for mobile, with bundled inputs.

        The network is loaded with the ``IMAGENET1K_V1`` weights, put into
        eval mode, and traced against an all-zeros ``(1, 3, 224, 224)``
        tensor. The same tensor is attached to the optimized module as its
        single bundled input.
        """
        sample_input = torch.zeros(1, 3, 224, 224)

        # ``eval()`` returns the module itself, so it can be chained directly.
        net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample_input))

        bundled_inputs = [(sample_input,)]
        augment_model_with_bundled_inputs(mobile_module, bundled_inputs)

        # Run the optimized module once on the sample; the output is discarded
        # (presumably a smoke test that the module is executable).
        mobile_module(sample_input)
        return mobile_module
|