1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
# Make the helper files in test/ importable: resolve this file's real path,
# step up two directory levels, and append that directory to sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory holding this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # its parent (expected: test/)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
if __name__ == "__main__":
    # This module only defines test cases; executing it as a script is
    # rejected with a message naming the supported driver invocation.
    usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(usage)
class TestTensorMethods(JitTestCase):
def test_getitem(self):
def tensor_getitem(inp: torch.Tensor):
indices = torch.tensor([0, 2], dtype=torch.long)
return inp.__getitem__(indices)
inp = torch.rand(3, 4)
self.checkScript(tensor_getitem, (inp, ))
scripted = torch.jit.script(tensor_getitem)
FileCheck().check("aten::index").run(scripted.graph)
def test_getitem_invalid(self):
def tensor_getitem_invalid(inp: torch.Tensor):
return inp.__getitem__()
with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"):
torch.jit.script(tensor_getitem_invalid)
|