# Owner(s): ["module: dynamo"]

import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

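# Shorthand dtype aliases. These appear to match the compact dtype spellings
# used in the annotation strings of the AOT graph dumps below (e.g. "f32[1001, 6]").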
f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )
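
        # cast_to_fp64 mutates the FX module in place, promoting dtype
        # arguments to float64, and returns fp64 copies of the example inputs;
        # the accuracy minifier uses this to build an fp64 reference run.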
        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )
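
    # The forward() below is pasted from an AOT autograd graph dump (hence the
    # unused primals and the string shape annotations); aot_graph_input_parser
    # exists to make such dumps runnable without the original inputs.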
    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device

        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None
            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None
            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )
            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )
            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None
            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )
            return _embedding_bag
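
        # aot_graph_input_parser reads the annotation strings on the signature
        # and fabricates a matching input for every parameter: a tensor of the
        # annotated shape/dtype on the requested device, plus (it seems) a
        # stand-in `self` carrying the _tensor_constant* attributes the body
        # reads, since the call below would fail without one.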
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
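
        # "s0" is pinned to 10 via sym_shapes; the unmapped "s1" falls back to
        # default_sym_shape, so primals_4/primals_5 come out sized 5.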
        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )

        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()