File: xfail_suggester.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (145 lines) | stat: -rw-r--r-- 3,761 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import re
import torch

"""
Instructions:

1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""

# Raw pytest output captured by step 1 of the instructions above. Decode
# leniently: only the ASCII "FAILED ..." summary lines matter here.
with open('result.txt', encoding='utf-8', errors='replace') as f:
    lines = f.readlines()

# Every pytest "FAILED" summary line (currently not referenced elsewhere in
# this script; failed_tests below re-filters `lines` through get_failed_test).
failed = [line for line in lines if line.startswith('FAILED')]

# Captures the bare test name from summary lines of the form
# "FAILED test/test_<file>.py::<TestClass>::<test_name> ...".
# Raw string so `\w`/`\S` are regex escapes rather than invalid string escapes,
# and `\.` so only a literal ".py" matches.
p = re.compile(r'FAILED test/test_\w+\.py::\w+::(\S+)')


def get_failed_test(line):
    """Return the test name captured from a pytest ``FAILED`` summary line.

    Uses the module-level compiled pattern ``p``; returns None when ``line``
    does not match it (match is anchored at the start of the line).
    """
    found = p.match(line)
    if found:
        return found.group(1)
    return None


# Test-name prefixes of the families summarised in the final report. Some
# entries extend others (e.g. 'test_vmapvjp_' / 'test_vmapvjp_has_batch_rule_');
# belongs_to_base() assigns each test to the longest matching prefix.
base_names = {
    'test_grad_',
    'test_vjp_',
    'test_vmapvjp_',
    'test_vmapvjp_has_batch_rule_',
    'test_vjpvmap_',
    'test_jvp_',
    'test_vmapjvp_',
    'test_vmapjvpall_has_batch_rule_',
    'test_vmapjvpall_',
    'test_jvpvjp_',
    'test_vjpvjp_',
    'test_decomposition_',
    'test_make_fx_exhaustive_',
    'test_vmap_exhaustive_',
    'test_op_has_batch_rule_',
    'test_vmap_autograd_grad_',
}

# Names of all failed tests found in result.txt, in alphabetical order;
# lines that are not FAILED summary lines map to None and are dropped.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)

suggested_xfails = {}


def remove_device_dtype(test):
    """Drop the last two ``_``-separated fields (device and dtype) of ``test``.

    For example ``'add_cpu_float32'`` becomes ``'add'``. Names with fewer than
    three fields collapse to the empty string.
    """
    pieces = test.split('_')
    del pieces[-2:]
    return '_'.join(pieces)


def belongs_to_base(test, base):
    """Return True if ``test`` belongs to the ``base`` family.

    Some entries of the module-level ``base_names`` extend others, so a test
    only counts for ``base`` when it starts with ``base`` and does not also
    start with a longer entry of ``base_names``.
    """
    if not test.startswith(base):
        return False
    longer_bases = (other for other in base_names if len(other) > len(base))
    return not any(test.startswith(other) for other in longer_bases)


def parse_namespace(base):
    """Split a known torch sub-namespace prefix off an op-name stem.

    The table maps underscore-joined prefixes back to dotted namespaces, e.g.
    ``'nn_functional_relu'`` -> ``('nn.functional', 'relu')``.

    Returns:
        ``(namespace, rest)``: ``namespace`` is the dotted namespace, or None
        when no known prefix matches (then ``rest`` is ``base`` unchanged).
    """
    mappings = {
        'nn_functional_': 'nn.functional',
        'fft_': 'fft',
        'linalg_': 'linalg',
        '_masked_': '_masked',
        'sparse_': 'sparse',
        # Was misspelled 'speical_', so special ops were never recognised.
        'special_': 'special',
    }
    for prefix, namespace in mappings.items():
        if base.startswith(prefix):
            return namespace, base[len(prefix):]
    return None, base


def get_torch_module(namespace):
    """Return the torch module object for a dotted ``namespace``.

    None means top-level ``torch``; ``'nn.functional'`` is resolved
    explicitly; any other value is looked up as a single attribute of torch.
    """
    if namespace is None:
        module = torch
    elif namespace == 'nn.functional':
        module = torch.nn.functional
    else:
        module = getattr(torch, namespace)
    return module


def split_api_variant(rest, apis):
    """Split ``rest`` into ``(api, variant)`` using names from ``apis``.

    The longest name in ``apis`` that equals ``rest`` or is followed in
    ``rest`` by an underscore wins; the text after that underscore is the
    variant. Requiring the underscore boundary stops short names from
    swallowing longer ops (e.g. ``'ne'`` must not match ``'new_ones'``).

    Returns:
        ``(api, variant)``, or ``(rest, '')`` when nothing matches.
    """
    for candidate in sorted(apis, key=len, reverse=True):
        if rest == candidate:
            return candidate, ''
        if rest.startswith(candidate + '_'):
            return candidate, rest[len(candidate) + 1:]
    return rest, ''


def parse_base(base):
    """Split an op-name stem into ``(namespace, api, variant)``.

    The namespace prefix is peeled off by parse_namespace(); the remainder is
    matched against ``dir()`` of the corresponding torch module. The parsed
    result is also printed as a debugging trace.
    """
    namespace, rest = parse_namespace(base)
    api, variant = split_api_variant(rest, dir(get_torch_module(namespace)))
    print(base, namespace, api, variant)
    return namespace, api, variant


def any_starts_with(strs, thing):
    """Return True if at least one string in ``strs`` starts with ``thing``."""
    return any(s.startswith(thing) for s in strs)


def format_suggestion(api, variant, fails_on_cpu, fails_on_cuda):
    """Render one suggestion line for an op/variant pair.

    Args:
        api: Dotted op name, e.g. ``'nn.functional.relu'``.
        variant: Variant name (``''`` when there is none).
        fails_on_cpu: Whether the ``_cpu_float32`` instance failed.
        fails_on_cuda: Whether the ``_cuda_float32`` instance failed.

    Returns:
        An unrestricted ``xfail(...)`` when both float32 instances failed, a
        device-restricted ``xfail(...)`` when exactly one did, and
        ``skip(...)`` when only other device/dtype instances failed.
    """
    if fails_on_cpu and fails_on_cuda:
        return f"xfail('{api}', '{variant}'),"
    if fails_on_cpu:
        return f"xfail('{api}', '{variant}', device_type='cpu'),"
    if fails_on_cuda:
        return f"xfail('{api}', '{variant}', device_type='cuda'),"
    # Closing parenthesis included so the line pastes as valid Python.
    return f"skip('{api}', '{variant}'),"


def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestions for the failed tests of one family.

    Args:
        base: Family prefix from ``base_names`` (e.g. ``'test_grad_'``).
        tests: Names of all failed tests.

    Returns:
        One suggestion string per distinct op/variant of the family, ordered
        by the op stem (test name minus prefix and device/dtype suffix) so
        the output is stable between runs.
    """
    suffixes = {test[len(base):] for test in tests if belongs_to_base(test, base)}
    op_stems = sorted({remove_device_dtype(suffix) for suffix in suffixes})

    result = []
    for op_stem in op_stems:
        namespace, api, variant = parse_base(op_stem)
        if namespace is not None:
            api = f'{namespace}.{api}'
        result.append(format_suggestion(
            api,
            variant,
            op_stem + '_cpu_float32' in suffixes,
            op_stem + '_cuda_float32' in suffixes,
        ))
    return result


# Iterate families in sorted order: base_names is a set of strings, whose
# iteration order depends on per-process hash randomization, so iterating it
# directly made the report order differ from run to run.
result = {base: get_suggested_xfails(base, failed_tests) for base in sorted(base_names)}
for k, v in result.items():
    print('=' * 50)
    print(k)
    print('=' * 50)
    print('\n'.join(v))