1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
|
# Owner(s): ["module: inductor"]
from torch._inductor import config
from torch._inductor.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CPU, TRITON_HAS_CPU
# Prefer the package-relative import; fall back to a top-level import when
# this module has no parent package (e.g. the file is executed directly).
try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor
# The Triton-CPU test classes only exist when both HAS_CPU and TRITON_HAS_CPU
# are truthy; otherwise this module defines no test classes at all.
if HAS_CPU and TRITON_HAS_CPU:

    @config.patch(cpu_backend="triton")
    class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest):
        """SweepInputsCpuTest, re-run with ``cpu_backend="triton"`` patched in."""

    @config.patch(cpu_backend="triton")
    class CpuTritonTests(test_torchinductor.TestCase):
        """Shared inductor template tests, run with ``cpu_backend="triton"``.

        The class body only supplies the device and the ``common`` checker;
        the test methods themselves are attached by ``copy_tests`` below.
        """

        device = "cpu"
        # Stored as a plain function attribute, so ``self.common(...)``
        # calls check_model with the test instance bound as first argument.
        common = test_torchinductor.check_model

    # Copy the CommonTemplate tests onto CpuTritonTests. The "cpu" argument
    # and xfail_prop are forwarded to copy_tests; presumably tests carrying
    # the ``_expected_failure_triton_cpu`` attribute are marked as expected
    # failures — TODO confirm against copy_tests.
    test_torchinductor.copy_tests(
        test_torchinductor.CommonTemplate,
        CpuTritonTests,
        "cpu",
        xfail_prop="_expected_failure_triton_cpu",
    )
# Run the suite only when executed as a script, and only under the same
# HAS_CPU / TRITON_HAS_CPU condition that gates the test class definitions.
# needs="filelock" is passed through to run_tests (presumably a required
# package for the runner — TODO confirm).
if __name__ == "__main__" and HAS_CPU and TRITON_HAS_CPU:
    run_tests(needs="filelock")
|