1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
# Owner(s): ["module: onnx"]
"""Unit tests for the onnx dynamo exporter."""
from __future__ import annotations
import torchvision
import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils
@common_utils.instantiate_parametrized_tests
class DynamoExporterTest(common_utils.TestCase):
def test_insert_contiguous_between_transpose_and_view(self):
class Model(torch.nn.Module):
def forward(self, query, key, value):
res = torch.nn.functional.scaled_dot_product_attention(
query, key, value
)
rest = res.transpose(0, 1)
return rest.view(8, 32, 128 * 64)
model = Model()
query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
ep = torch.export.export(model, (query, key, value), strict=False)
self.assertNotIn("call_method", str(ep.graph))
onnx_program = torch.onnx.export(
model, (query, key, value), dynamo=True, fallback=False
)
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
def test_constant_complex(self):
class MulModule(torch.nn.Module):
def forward(self, x):
y = 2 + 3j
return torch.ops.aten.mul(x, y)
# Example usage with complex inputs
x = torch.tensor(
[[1.0 + 2.0j, 3.0 + 4.0j], [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64
)
onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
onnx_testing.assert_onnx_program(onnx_program)
def test_pow_does_not_trigger_type_promotion(self):
class Model(torch.nn.Module):
def forward(self, x):
return x**2.0
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
onnx_program = torch.onnx.export(Model(), (x,), dynamo=True)
onnx_testing.assert_onnx_program(onnx_program)
self.assertNotIn("Cast", [node.op_type for node in onnx_program.model.graph])
def test_onnx_export_control_flow(self):
class CondModel(torch.nn.Module):
def forward(self, x):
def true_fn(x):
return x + 1.0
def false_fn(x):
return x - 42.0
y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
return y
onnx_program = torch.onnx.export(
CondModel(),
(torch.tensor([1, 2]),),
dynamo=True,
fallback=False,
)
onnx_model = onnx_program.model
self.assertIn("If", [node.op_type for node in onnx_model.graph])
onnx_testing.assert_onnx_program(onnx_program)
# Test different branches
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
def test_onnx_export_nested_control_flow_and_nested_weights(self):
class Submodule(torch.nn.Module):
def __init__(self):
super().__init__()
# Nested weight
self.weight = torch.nn.Parameter(torch.tensor([100.0]))
def forward(self, x):
def true_fn(x):
return x * self.weight
def false_fn(x):
return x / self.weight
y = torch.cond(x.sum() <= 0, true_fn, false_fn, [x])
return y
class CondModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = Submodule()
self.weight = torch.nn.Parameter(torch.tensor([42.0]))
def forward(self, x):
def true_fn(x):
return self.submodule(x - self.weight)
def false_fn(x):
return x - self.weight
y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
return y
onnx_program = torch.onnx.export(
CondModel(),
(torch.tensor([1, 2]),),
dynamo=True,
fallback=False,
)
onnx_testing.assert_onnx_program(onnx_program)
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
def test_onnx_export_torchvision_ops(self):
class VisionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, *x):
out = torchvision.ops.nms(x[0], x[1], x[2])
return out
args = (
torch.tensor([[0, 0, 1, 1], [0.5, 0.5, 1, 1]], dtype=torch.float),
torch.tensor([0.1, 0.2]),
0,
)
onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
onnx_testing.assert_onnx_program(onnx_program)
# TODO(justinchuby): Test multi-output HOPs
if __name__ == "__main__":
common_utils.run_tests()
|