import functools
import os
import shutil
import sys
from io import BytesIO

import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
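
# Invoked twice around a lite-interpreter test run (the script name and YAML
# path below are illustrative):
#
#   python tests_setup.py setup build/ops.yaml
#   ... run the tests against the saved .ptl models ...
#   python tests_setup.py shutdown build/ops.yaml
#
# "setup" scripts each model defined below, saves it next to this script as
# <ClassName>.ptl, and appends every operator the models need to the given
# ops YAML (after backing it up); "shutdown" deletes the saved models and
# restores the YAML from its backup.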
_OPERATORS = set()  # root operators required by all saved models
_FILENAMES = []  # paths of the saved .ptl files, removed at shutdown
_MODELS = []  # the model classes that have been saved

def save_model(cls):
    """Script a model, save it for the lite interpreter, and record its ops."""

    @functools.wraps(cls)
    def wrapper_save():
        _MODELS.append(cls)
        model = cls()
        scripted = torch.jit.script(model)
        # Round-trip through an in-memory buffer so the operator list can be
        # read back from the lite-interpreter format.
        buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)
        ops = _export_operator_list(mobile_module)
        _OPERATORS.update(ops)
        # Also persist the model to disk for the tests to load later.
        path = f"./{cls.__name__}.ptl"
        _FILENAMES.append(path)
        scripted._save_for_lite_interpreter(path)

    return wrapper_save
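
# Note: @save_model replaces each class below with the zero-argument
# wrapper_save closure, so the "constructor" calls in the tests list under
# __main__ are what actually script and save the models. Each class
# exercises one operator-schema flavor; see the schema comments above
# several of them.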

@save_model
class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
    def forward(self, x: int):
        a = torch.ones(
            size=[3, x],
            dtype=torch.int64,
            layout=torch.strided,
            device="cpu",
            pin_memory=False,
        )
        return a


@save_model
class ModelWithTensorOptional(torch.nn.Module):
    def forward(self, index):
        a = torch.zeros(2, 2)
        a[0][1] = 1
        a[1][0] = 2
        a[1][1] = 3
        return a[index]


# gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
@save_model
class ModelWithScalarList(torch.nn.Module):
    def forward(self, a: int):
        values = torch.tensor([4.0, 1.0, 1.0, 16.0])
        if a == 0:
            return torch.gradient(
                values, spacing=torch.scalar_tensor(2.0, dtype=torch.float64)
            )
        elif a == 1:
            return torch.gradient(values, spacing=[torch.tensor(1.0).item()])


# upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
@save_model
class ModelWithFloatList(torch.nn.Upsample):
    def __init__(self):
        super().__init__(
            scale_factor=(2.0,),
            mode="linear",
            align_corners=False,
            recompute_scale_factor=True,
        )


# index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
@save_model
class ModelWithListOfOptionalTensors(torch.nn.Module):
    def forward(self, index):
        values = torch.tensor([[4.0, 1.0, 1.0, 16.0]])
        return values[torch.tensor(0), index]


# conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1,
#        int groups=1) -> Tensor
@save_model
class ModelWithArrayOfInt(torch.nn.Conv2d):
    def __init__(self):
        super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))


# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None,
#           MemoryFormat? memory_format=None) -> Tensor
@save_model
class ModelWithTensors(torch.nn.Module):
    def forward(self, a):
        b = torch.ones_like(a)
        return a + b


@save_model
class ModelWithStringOptional(torch.nn.Module):
    def forward(self, b):
        a = torch.tensor(3, dtype=torch.int64)
        out = torch.empty(size=[1], dtype=torch.float)
        torch.div(b, a, out=out)
        return [torch.div(b, a, rounding_mode="trunc"), out]


@save_model
class ModelWithMultipleOps(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ops = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )

    def forward(self, x):
        x[1] = -2
        return self.ops(x)

if __name__ == "__main__":
    command = sys.argv[1]
    ops_yaml = sys.argv[2]
    backup = ops_yaml + ".bak"
    if command == "setup":
        # Instantiating the wrapped classes runs wrapper_save, which fills
        # _OPERATORS and _FILENAMES as a side effect.
        tests = [
            ModelWithDTypeDeviceLayoutPinMemory(),
            ModelWithTensorOptional(),
            ModelWithScalarList(),
            ModelWithFloatList(),
            ModelWithListOfOptionalTensors(),
            ModelWithArrayOfInt(),
            ModelWithTensors(),
            ModelWithStringOptional(),
            ModelWithMultipleOps(),
        ]
        # Back up the ops YAML, then append every collected operator to it,
        # one "- <operator name>" entry per line.
        shutil.copyfile(ops_yaml, backup)
        with open(ops_yaml, "a") as f:
            for op in _OPERATORS:
                f.write(f"- {op}\n")
    elif command == "shutdown":
        # Remove the saved .ptl files (their paths were recorded in
        # _FILENAMES) and restore the original ops YAML.
        for file in _FILENAMES:
            if os.path.isfile(file):
                os.remove(file)
        shutil.move(backup, ops_yaml)