1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
import re
import torch
"""
Instructions:
1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""
# Read the pytest log produced by step 1 of the instructions above.
with open('result.txt') as f:
    lines = f.readlines()

# NOTE(review): `failed` is not referenced by the code below in this file.
failed = [line for line in lines if line.startswith('FAILED')]

# Matches pytest summary lines of the shape
#   FAILED test/test_<file>.py::<Class>::<test_name> ...
# and captures <test_name> (up to the first whitespace).
# Raw string so `\w` / `\S` are regex escapes rather than invalid string
# escapes (a SyntaxWarning on Python 3.12+); `.` is escaped so it only
# matches a literal dot.
p = re.compile(r'FAILED test/test_\w+\.py::\w+::(\S+)')
def get_failed_test(line):
    """Extract the test name from one line of the pytest log.

    Args:
        line: a single line read from the results file.

    Returns:
        The text captured by the module-level pattern ``p`` when it matches
        at the start of *line*, otherwise ``None``.
    """
    found = p.match(line)
    return None if found is None else found.group(1)
# Test-family prefixes to report on. Some prefixes extend others
# (e.g. 'test_vmapvjp_' and 'test_vmapvjp_has_batch_rule_'); belongs_to_base
# assigns a test only to the longest prefix it starts with.
base_names = {
    # grad / vjp / jvp families and their vmap compositions
    'test_grad_',
    'test_vjp_',
    'test_jvp_',
    'test_vjpvmap_',
    'test_vmapvjp_',
    'test_vmapjvp_',
    'test_vmapjvpall_',
    'test_jvpvjp_',
    'test_vjpvjp_',
    # families whose names mention a batch rule
    'test_vmapvjp_has_batch_rule_',
    'test_vmapjvpall_has_batch_rule_',
    'test_op_has_batch_rule_',
    # remaining families
    'test_decomposition_',
    'test_make_fx_exhaustive_',
    'test_vmap_exhaustive_',
    'test_vmap_autograd_grad_',
}
# Names captured from every matching FAILED line of the log, sorted.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)
# NOTE(review): not referenced by the code below in this file.
suggested_xfails = {}
def remove_device_dtype(test):
    """Drop the last two ``_``-separated fields (device and dtype) of *test*.

    Everything before the second-to-last underscore is kept, so
    ``'linalg_norm_cpu_float32'`` becomes ``'linalg_norm'``. Names with fewer
    than two underscores produce the empty string.
    """
    head, *suffix_fields = test.rsplit('_', 2)
    return head if len(suffix_fields) == 2 else ''
def belongs_to_base(test, base):
    """Tell whether *test* should be attributed to the family prefix *base*.

    True only when *test* starts with *base* and does not also start with
    some strictly longer prefix from ``base_names`` (which then takes
    precedence).
    """
    if not test.startswith(base):
        return False
    longer_prefixes = (name for name in base_names if len(name) > len(base))
    return not any(test.startswith(name) for name in longer_prefixes)
def parse_namespace(base):
    """Split a known torch sub-namespace prefix off a test-name fragment.

    Args:
        base: test-name fragment with the family prefix and device/dtype
            suffix removed, e.g. ``'nn_functional_relu'``.

    Returns:
        ``(namespace, rest)``: ``namespace`` is a dotted attribute path under
        ``torch`` (e.g. ``'nn.functional'``) or ``None`` when no listed prefix
        matches, and ``rest`` is *base* with that prefix stripped.
    """
    mappings = {
        'nn_functional_': 'nn.functional',
        'fft_': 'fft',
        'linalg_': 'linalg',
        '_masked_': '_masked',
        'sparse_': 'sparse',
        # Spelled correctly so special.* ops are actually recognized
        # (the misspelled 'speical_' key could never match).
        'special_': 'special',
    }
    for heading, namespace in mappings.items():
        if base.startswith(heading):
            return namespace, base[len(heading):]
    return None, base
def get_torch_module(namespace):
    """Resolve a dotted namespace path (as produced by parse_namespace) under torch.

    Args:
        namespace: ``None`` for the top-level ``torch`` module, or a dotted
            attribute path such as ``'linalg'`` or ``'nn.functional'``.

    Returns:
        The object reached by following each ``.``-separated attribute
        starting from ``torch``.

    Raises:
        AttributeError: if some component of *namespace* does not exist.
    """
    module = torch
    if namespace is None:
        return module
    # Walk the path one attribute at a time so any dotted namespace works,
    # not just the previously special-cased 'nn.functional'.
    for attr in namespace.split('.'):
        module = getattr(module, attr)
    return module
def parse_base(base):
    """Split a device/dtype-free test-name fragment into (namespace, api, variant).

    The namespace prefix is stripped by parse_namespace. The remainder is then
    matched against the attribute names (``dir``) of the resolved torch module,
    longest name first. A name only matches if it equals the whole remainder
    or is followed by ``_``. Without that boundary check, a short name such
    as ``ne`` would swallow the start of ``new_ones`` and produce the bogus
    variant ``w_ones``. Whatever follows the ``_`` becomes the variant. If no
    name matches, the whole remainder is the api and the variant is ``''``.

    Returns:
        ``(namespace, api, variant)``; ``namespace`` may be ``None``.
    """
    namespace, rest = parse_namespace(base)
    # Longest first so e.g. a longer op name wins over its shorter prefix.
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)
    api = rest
    variant = ''
    for candidate in apis:
        if rest == candidate or rest.startswith(candidate + '_'):
            api = candidate
            variant = rest[len(candidate) + 1:]
            break
    # Debug trace of how each fragment was split.
    print(base, namespace, api, variant)
    return namespace, api, variant
def any_starts_with(strs, thing):
    """Return True if at least one string in *strs* begins with the prefix *thing*.

    Stops at the first match; an empty *strs* yields False.
    """
    return any(candidate.startswith(thing) for candidate in strs)
def get_suggested_xfails(base, tests):
    """Build suggested ``xfail``/``skip`` entries for one test family.

    Args:
        base: family prefix from ``base_names`` (e.g. ``'test_vjp_'``).
        tests: failing test names, as captured from the pytest log.

    Returns:
        A list of source lines, one per failing op/variant, sorted by the
        op fragment so the output is reproducible:

        * ``xfail('api', 'variant'),`` when both the cpu and cuda float32
          tests failed;
        * ``xfail('api', 'variant', device_type='cpu'|'cuda'),`` when only
          one of them failed;
        * ``skip('api', 'variant'),`` when neither float32 test failed
          (only other device/dtype combinations did).
    """
    result = []
    # Strip the family prefix, keeping only tests attributed to this family
    # (and not to a longer, more specific prefix).
    suffixes = {test[len(base):] for test in tests if belongs_to_base(test, base)}
    op_bases = {remove_device_dtype(test) for test in suffixes}
    # Sorted: plain set iteration order varies between runs.
    for op_base in sorted(op_bases):
        namespace, api, variant = parse_base(op_base)
        if namespace is not None:
            api = f'{namespace}.{api}'
        cpu_failed = f'{op_base}_cpu_float32' in suffixes
        cuda_failed = f'{op_base}_cuda_float32' in suffixes
        if cpu_failed and cuda_failed:
            result.append(f"xfail('{api}', '{variant}'),")
        elif cpu_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif cuda_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # Closing paren was missing here, producing unbalanced output.
            result.append(f"skip('{api}', '{variant}'),")
    return result
# Compute and print suggestions per test family. base_names is a set whose
# iteration order can change between runs (str hash randomization), so the
# families are visited in sorted order for reproducible output.
result = {
    base: get_suggested_xfails(base, failed_tests)
    for base in sorted(base_names)
}
separator = '=' * 50
for k, v in result.items():
    print(separator)
    print(k)
    print(separator)
    print('\n'.join(v))
|