1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
# Owner(s): ["module: onnx"]
import functools
import os
import sys
import unittest
import torch
from torch.autograd import function
# Directory two levels above the resolved (symlink-free) path of this file.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Index -1 inserts just before the current last entry of sys.path (it does not
# append), so this directory is searched after nearly all pre-existing entries.
sys.path.insert(-1, pytorch_test_dir)
# Import-time side effect: sets the default torch tensor type to FloatTensor.
torch.set_default_tensor_type("torch.FloatTensor")
# Size constants exposed at module level.
# NOTE(review): presumably consumed by tests that import this module; the
# names suggest batch / RNN dimensions -- confirm against callers.
BATCH_SIZE = 2
RNN_BATCH_SIZE = 7
RNN_SEQUENCE_LENGTH = 11
RNN_INPUT_SIZE = 5
RNN_HIDDEN_SIZE = 3
def _skipper(condition, reason):
    """Build a decorator that skips the wrapped callable when ``condition()`` is truthy.

    Args:
        condition: Zero-argument callable. It is evaluated on every call of
            the decorated function, not at decoration time.
        reason: Message passed to ``unittest.SkipTest``.

    Returns:
        A decorator. The decorated function raises ``unittest.SkipTest(reason)``
        when ``condition()`` is truthy and otherwise forwards all arguments to
        the original function and returns its result.
    """

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)

        return wrapper

    return decorator


# Skip when torch reports that CUDA is not available.
skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available")
# Skip when the TRAVIS environment variable is set to a non-empty value.
skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip In Travis")
# Skip when bfloat16 on CUDA is unavailable. CUDA availability is checked
# first so that CPU-only hosts skip cleanly: on some torch releases
# is_bf16_supported() queries the current CUDA device and raises without one.
skipIfNoBFloat16Cuda = _skipper(
    lambda: not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    "BFloat16 CUDA is not available",
)
# Skips tests for all opset versions below min_opset_version.
# If exporting an op is only supported from a specific version onwards,
# add this decorator so the test does not run for opset versions smaller
# than the one currently being tested.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    """Return a decorator that skips a test method on too-old opset versions.

    Args:
        min_opset_version: Lowest opset version for which the test runs.

    Returns:
        A decorator for methods whose ``self`` exposes ``opset_version``.
        The decorated method raises ``unittest.SkipTest`` when
        ``self.opset_version < min_opset_version``; otherwise it calls the
        original method and returns its result.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            current = self.opset_version
            if current < min_opset_version:
                message = f"Unsupported opset_version: {current} < {min_opset_version}"
                raise unittest.SkipTest(message)
            return test_method(self, *args, **kwargs)

        return guarded

    return decorate
# Skips tests for all opset versions above max_opset_version.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    """Return a decorator that skips a test method on too-new opset versions.

    Args:
        max_opset_version: Highest opset version for which the test runs.

    Returns:
        A decorator for methods whose ``self`` exposes ``opset_version``.
        The decorated method raises ``unittest.SkipTest`` when
        ``self.opset_version > max_opset_version``; otherwise it calls the
        original method and returns its result.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            tested = self.opset_version
            if tested > max_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {tested} > {max_opset_version}"
                )
            outcome = test_method(self, *args, **kwargs)
            return outcome

        return guarded

    return decorate
# Skips tests for all opset versions.
def skipForAllOpsetVersions():
    """Return a decorator that skips a test method whenever an opset is set.

    Returns:
        A decorator for methods whose ``self`` exposes ``opset_version``.
        The decorated method raises ``unittest.SkipTest`` when
        ``self.opset_version`` is truthy; when it is falsy the original
        method is called and its result returned.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            if not self.opset_version:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for unsupported opset_version")

        return guarded

    return decorate
def skipTraceTest(min_opset_version=float("inf")):
    """Return a decorator that skips the trace variant of a test method.

    On every call the decorated method first stores
    ``self.opset_version >= min_opset_version`` in
    ``self.is_trace_test_enabled``. With the default of infinity that flag
    is therefore false for any finite opset version.

    Args:
        min_opset_version: Opset version from which trace testing is enabled.

    Returns:
        A decorator. The decorated method raises ``unittest.SkipTest`` when
        trace testing is disabled and ``self.is_script`` is falsy; otherwise
        it calls the original method and returns its result.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            trace_enabled = self.opset_version >= min_opset_version
            self.is_trace_test_enabled = trace_enabled
            # `or` short-circuits: is_script is only read when tracing is disabled.
            if not (trace_enabled or self.is_script):
                raise unittest.SkipTest("Skip verify test for torch trace")
            return test_method(self, *args, **kwargs)

        return guarded

    return decorate
def skipScriptTest(min_opset_version=float("inf")):
    """Return a decorator that skips the TorchScript variant of a test method.

    On every call the decorated method first stores
    ``self.opset_version >= min_opset_version`` in
    ``self.is_script_test_enabled``. With the default of infinity that flag
    is therefore false for any finite opset version.

    Args:
        min_opset_version: Opset version from which script testing is enabled.

    Returns:
        A decorator. The decorated method raises ``unittest.SkipTest`` when
        script testing is disabled and ``self.is_script`` is truthy; otherwise
        it calls the original method and returns its result.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            script_enabled = self.opset_version >= min_opset_version
            self.is_script_test_enabled = script_enabled
            # `or` short-circuits: is_script is only read when scripting is disabled.
            if script_enabled or not self.is_script:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for TorchScript")

        return guarded

    return decorate
# Skips tests for the opset versions listed in unsupported_opset_versions.
# If the caffe2 test cannot be run for a specific version, add this decorator
# (for example, an op was modified but the change is not supported in caffe2).
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    """Return a decorator that skips a test method on listed opset versions.

    Args:
        unsupported_opset_versions: Container tested with ``in`` against
            ``self.opset_version``.

    Returns:
        A decorator. The decorated method raises ``unittest.SkipTest`` when
        ``self.opset_version`` is in the container; otherwise it calls the
        original method and returns its result.
    """

    def decorate(test_method):
        @functools.wraps(test_method)
        def guarded(self, *args, **kwargs):
            if self.opset_version not in unsupported_opset_versions:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for unsupported opset_version")

        return guarded

    return decorate
def skipShapeChecking(func):
    """Decorate a test method so it runs with ``self.check_shape`` set to False.

    Args:
        func: Method taking ``self`` as its first argument.

    Returns:
        A wrapper that assigns ``self.check_shape = False`` before calling
        ``func`` and returns whatever ``func`` returns. The flag is not
        restored afterwards.
    """

    @functools.wraps(func)
    def without_shape_check(self, *args, **kwargs):
        setattr(self, "check_shape", False)
        result = func(self, *args, **kwargs)
        return result

    return without_shape_check
def skipDtypeChecking(func):
    """Decorate a test method so it runs with ``self.check_dtype`` set to False.

    Args:
        func: Method taking ``self`` as its first argument.

    Returns:
        A wrapper that assigns ``self.check_dtype = False`` before calling
        ``func`` and returns whatever ``func`` returns. The flag is not
        restored afterwards.
    """

    @functools.wraps(func)
    def without_dtype_check(self, *args, **kwargs):
        setattr(self, "check_dtype", False)
        result = func(self, *args, **kwargs)
        return result

    return without_dtype_check
def flatten(x):
    """Collect every ``torch.Tensor`` found in ``x`` into a flat tuple.

    Traversal is delegated to ``torch.autograd.function._iter_filter``,
    called with a predicate that accepts ``torch.Tensor`` instances.

    Args:
        x: A tensor or a structure containing tensors.

    Returns:
        A tuple of the tensors yielded by the filter, in iteration order.
    """

    def is_tensor(candidate):
        return isinstance(candidate, torch.Tensor)

    tensor_iter = function._iter_filter(is_tensor)
    return tuple(tensor_iter(x))
|