1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
# Owner(s): ["module: onnx"]
import io
import onnx
import torch
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
from torch.onnx._internal._exporter_legacy import ResolvedExportOptions
from torch.testing._internal import common_utils
class SampleModel(torch.nn.Module):
    """Single-input module returning ``x + 1`` together with its ReLU."""

    def forward(self, x):
        shifted = x + 1
        # Both the shifted tensor and its activation are returned as outputs.
        return shifted, shifted.relu()
class SampleModelTwoInputs(torch.nn.Module):
    """Two-input module returning ``x + b`` together with its ReLU."""

    def forward(self, x, b):
        total = x + b
        # Expose both the raw sum and its rectified version.
        return total, total.relu()
class SampleModelForDynamicShapes(torch.nn.Module):
    """Two-input module applying ReLU to ``x`` and sigmoid to ``b`` independently."""

    def forward(self, x, b):
        activated = x.relu()
        squashed = b.sigmoid()
        return activated, squashed
class TestExportOptionsAPI(common_utils.TestCase):
    """Checks how ``ResolvedExportOptions`` resolves the ``dynamic_shapes`` flag."""

    def test_dynamic_shapes_default(self):
        # With no explicit setting, the resolved flag must be falsy.
        resolved = ResolvedExportOptions(ExportOptions())
        self.assertFalse(resolved.dynamic_shapes)

    def test_dynamic_shapes_explicit(self):
        # Each entry: (value passed to ExportOptions, whether the resolved
        # flag is expected to be truthy). Order matches the original checks.
        cases = ((None, False), (True, True), (False, False))
        for requested, expect_enabled in cases:
            resolved = ResolvedExportOptions(
                ExportOptions(dynamic_shapes=requested)
            )
            check = self.assertTrue if expect_enabled else self.assertFalse
            check(resolved.dynamic_shapes)
class TestDynamoExportAPI(common_utils.TestCase):
    """End-to-end checks for ``dynamo_export`` and ``ONNXProgram.save``."""

    def test_default_export(self):
        # Exporting with default options yields an ONNXProgram wrapping a ModelProto.
        output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
        self.assertIsInstance(output, ONNXProgram)
        self.assertIsInstance(output.model_proto, onnx.ModelProto)

    def test_export_with_options(self):
        # Passing explicit ExportOptions (dynamic shapes on) still yields an ONNXProgram.
        self.assertIsInstance(
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    dynamic_shapes=True,
                ),
            ),
            ONNXProgram,
        )

    def test_save_to_file_default_serializer(self):
        with common_utils.TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
            model = onnx.load(path)
            # onnx.load parses zero bytes into an empty ModelProto without
            # error, so also require that the saved graph has content.
            self.assertGreater(len(model.graph.node), 0)

    def test_save_to_existing_buffer_default_serializer(self):
        buffer = io.BytesIO()
        dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
        # Writing to a BytesIO advances its position to the end. Without
        # rewinding, onnx.load would read zero bytes and parse an empty
        # ModelProto, so this test could never detect a broken save.
        buffer.seek(0)
        model = onnx.load(buffer)
        self.assertGreater(len(model.graph.node), 0)

    def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
        self,
    ):
        # With warnings_as_errors enabled, a diagnostic warning raised during
        # export is expected to surface as OnnxExporterError.
        with self.assertRaises(torch.onnx.OnnxExporterError):
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    diagnostic_options=torch.onnx.DiagnosticOptions(
                        warnings_as_errors=True
                    )
                ),
            )
if __name__ == "__main__":
    # Run this module's test cases through PyTorch's shared test runner.
    common_utils.run_tests()
|