File: hop_db.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (346 lines) | stat: -rw-r--r-- 10,780 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# mypy: ignore-errors

import functools
import unittest

import torch
from functorch.experimental.control_flow import map
from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import onlyCUDA
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield the single sample shared by the ``map`` OpInfo entries.

    The sample input is a list of two ``(2, 2, 2)`` tensors, and the extra
    positional args are two 1-element tensors. Every tensor is created by
    ``make_tensor`` on ``device``/``dtype`` with bounds ``low=0.1``/``high=2``
    and the caller's ``requires_grad``. ``opinfo`` and ``kwargs`` are unused.
    """

    def tensor_of(*shape):
        # One tensor of the requested shape with the shared creation options.
        return make_tensor(
            *shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=0.1,
            high=2,
        )

    mapped_over = [tensor_of(2, 2, 2), tensor_of(2, 2, 2)]
    extra_args = (tensor_of(1), tensor_of(1))
    yield SampleInput(mapped_over, args=extra_args)


def inner_f(x, y0, y1):
    """Per-slice body shared by the ``map`` OpInfo variants.

    Args:
        x: indexable pair of tensors; only ``x[0]`` and ``x[1]`` are read.
        y0: tensor multiplied into ``cos(x[0]) + 1``.
        y1: tensor whose ``sin`` is added to ``x[1]`` before taking ``cos``.

    Returns:
        A two-element list ``[(cos(x[0]) + 1) * y0, cos(x[1] + sin(y1))]``.
        The second entry is viewed back to ``x[1]``'s size.
    """
    first = x[0]
    second = x[1]
    # add_/cos_ run on freshly computed temporaries, so the inputs are never
    # mutated. Presumably the in-place ops exist to exercise in-place
    # handling inside the traced map body.
    scaled = first.cos().add_(1.0) * y0
    shifted = (second + y1.sin()).cos_().view(second.size())
    return [scaled, shifted]


def simple_map(xs, y0, y1):
    """Run ``inner_f`` through a single level of the control-flow ``map`` HOP.

    ``map`` here is the module-level import from
    ``functorch.experimental.control_flow`` (it shadows the builtin).
    ``y0`` and ``y1`` are forwarded unchanged to every body invocation.
    """

    def per_slice(x, y0, y1):
        return inner_f(x, y0, y1)

    mapped = map(per_slice, xs, y0, y1)
    return mapped


def nested_map(xs, y0, y1):
    """Two nested levels of the control-flow ``map`` HOP around ``inner_f``.

    The outer body itself calls ``map``. The inner body is defined inside
    the outer one, so each outer invocation builds its own inner closure.
    """

    def outer_body(xx, y0, y1):
        def leaf(x, y0, y1):
            return inner_f(x, y0, y1)

        return map(leaf, xx, y0, y1)

    return map(outer_body, xs, y0, y1)


def triple_nested_map(xs, y0, y1):
    """Three nested levels of the control-flow ``map`` HOP around ``inner_f``.

    Each level's body is defined inside the enclosing body and calls ``map``
    on the next level down. ``inner_f`` runs at the innermost level.
    """

    def level_one(xs, y0, y1):
        def level_two(xx, y0, y1):
            def level_three(x, y0, y1):
                return inner_f(x, y0, y1)

            return map(level_three, xx, y0, y1)

        return map(level_two, xs, y0, y1)

    return map(level_one, xs, y0, y1)


# Please consult with torch.export team before
# adding new entry to this list.
#
# Names of higher-order ops that are allowed to have no OpInfo entry in
# ``hop_db`` (per the variable name). Presumably a test elsewhere checks that
# every registered HOP is either covered by ``hop_db`` or listed here —
# TODO confirm against the consuming test.
hop_that_doesnt_have_opinfo_test_allowlist = [
    "custom_function_call",
    "autograd_function_apply",
    "run_and_save_rng_state",
    "run_with_rng_state",
    "out_dtype",
    "trace_wrapped",
    "map",  # T183144629
    "map_impl",
    "with_effects",
    "strict_mode",
    "_export_tracepoint",
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
]

torch.library.define(
    "testlib::mutating_custom_op",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
    """CPU kernel for ``testlib::mutating_custom_op``.

    Adds 5 to ``x`` and then to ``z``, both in place. Returns ``(x, z, x + z)``.
    The first two outputs are the (now mutated) input tensors themselves.
    """
    for operand in (x, z):
        operand.add_(5)
    combined = x + z
    return x, z, combined


@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
    """CUDA kernel for ``testlib::mutating_custom_op``.

    Same contract as the CPU kernel. It adds 5 to ``x`` and then to ``z``,
    both in place, and returns ``(x, z, x + z)`` with the mutated inputs as
    the first two outputs.
    """
    for operand in (x, z):
        operand.add_(5)
    combined = x + z
    return x, z, combined


@torch.library.register_fake("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
    """Fake implementation of ``testlib::mutating_custom_op``.

    It mirrors the real kernels' outputs without performing the in-place
    ``add_``. The first two outputs are the input tensors themselves, and the
    third is a new tensor from ``x + z``.
    """
    combined = x + z
    return x, z, combined


def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one ``(2, 2, 2)`` tensor as the sole operand for the ``cond`` OpInfo.

    The tensor is created by ``make_tensor`` on ``device``/``dtype`` with
    ``low=0.1``/``high=2`` and the caller's ``requires_grad``.
    ``opinfo`` and ``kwargs`` are unused.
    """
    operand = make_tensor(
        2,
        2,
        2,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=0.1,
        high=2,
    )
    yield SampleInput(operand)


def simple_cond(x):
    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])


def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield the single sample for the ``invoke_subgraph`` OpInfo.

    The sample is one ``(2, 2, 2)`` tensor created by ``make_tensor`` on
    ``device``/``dtype`` with ``low=0.1``/``high=2`` and the caller's
    ``requires_grad``. ``opinfo`` and ``kwargs`` are unused.
    """
    creation_options = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_tensor(2, 2, 2, low=0.1, high=2, **creation_options))


@mark_compile_region
def fn_for_invoke_subgraph(x):
    """Subgraph body for the ``invoke_subgraph`` OpInfo: elementwise ``torch.sin``.

    The decorator comes from ``torch._higher_order_ops.invoke_subgraph``.
    Presumably it makes tracing capture this function as an
    ``invoke_subgraph`` region. Keep the body unchanged, since it is
    presumably the graph the HOP tests observe.
    """
    return torch.sin(x)

def simple_invoke_subgraph(x):
    """Entry point for the ``invoke_subgraph`` OpInfo.

    It calls the ``mark_compile_region``-decorated ``fn_for_invoke_subgraph``
    on ``x`` and returns its result unchanged.
    """
    result = fn_for_invoke_subgraph(x)
    return result


def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield ``(x, z)`` for the ``auto_functionalize`` OpInfo.

    Both are ``(2, 2, 2)`` tensors created by ``make_tensor`` on
    ``device``/``dtype`` with ``low=0.1``/``high=2``. ``requires_grad`` is
    deliberately ignored and the tensors never require grad. Presumably this
    is because the op under test mutates its inputs in place, which autograd
    rejects on leaf tensors that require grad. The entry also sets
    ``supports_autograd=False``. ``opinfo`` and ``kwargs`` are unused.
    """

    def fresh_operand():
        return make_tensor(
            2,
            2,
            2,
            device=device,
            dtype=dtype,
            requires_grad=False,
            low=0.1,
            high=2,
        )

    x = fresh_operand()
    z = fresh_operand()
    yield SampleInput(x, z)


def simple_auto_functionalize(x, z):
    """Call the in-place ``testlib::mutating_custom_op`` on ``x`` and ``z``.

    Used by the ``auto_functionalize`` OpInfo. Presumably tracing this
    mutating custom op is what exercises the ``auto_functionalize`` HOP.
    Returns the op's three outputs.
    """
    mutating_op = torch.ops.testlib.mutating_custom_op
    return mutating_op(x, z)


def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield ``(q, k, v, score_mod, block_mask)`` for the flex_attention OpInfos.

    ``q``, ``k`` and ``v`` are ``(2, 2, 128, 8)`` tensors created in that
    order by ``make_tensor`` with ``low=0.1``/``high=2``. The layout is
    presumably (batch, heads, seq_len, head_dim) — confirm against the
    flex_attention docs. ``score_mod`` adds the head index ``h`` to each
    score. ``block_mask`` comes from the private helper
    ``_create_empty_block_mask(q, k)``. ``opinfo`` and ``kwargs`` are unused.
    """

    def score_mod(score, b, h, m, n):
        return score + h

    def qkv_tensor():
        return make_tensor(
            2,
            2,
            128,
            8,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=0.1,
            high=2,
        )

    q = qkv_tensor()
    k = qkv_tensor()
    v = qkv_tensor()
    block_mask = _create_empty_block_mask(q, k)
    yield SampleInput(q, k, v, score_mod, block_mask)


def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield ``(iter_t, x)`` for the ``while_loop`` OpInfo.

    ``iter_t`` is the integer scalar tensor ``3``. It is made with plain
    ``torch.tensor`` and so lives on the default device, not on ``device``.
    ``x`` is a ``(2, 3, 4)`` tensor created by ``make_tensor`` on
    ``device``/``dtype`` with ``low=0.1``/``high=2``. ``requires_grad`` is
    ignored, so ``x`` never requires grad. ``opinfo`` and ``kwargs`` are unused.
    """
    loop_counter = torch.tensor(3)
    carried = make_tensor(
        2,
        3,
        4,
        device=device,
        dtype=dtype,
        requires_grad=False,
        low=0.1,
        high=2,
    )
    yield SampleInput(loop_counter, carried)


def simple_while_loop(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))


def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield ``(init, xs)`` for the ``scan`` OpInfo.

    ``init`` has shape ``(2, 2)`` and ``xs`` has shape ``(2, 2, 2)``. Both are
    created by ``make_tensor`` on ``device``/``dtype`` with
    ``low=0.1``/``high=2`` and the caller's ``requires_grad``.
    ``opinfo`` and ``kwargs`` are unused.
    """

    def draw(*shape):
        return make_tensor(
            *shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=0.1,
            high=2,
        )

    init = draw(2, 2)
    xs = draw(2, 2, 2)
    yield SampleInput(init, xs)


def simple_scan(init, xs):

    def combine_fn(carry, x):
        result = carry @ x + x
        return result, carry.clone()

    return torch._higher_order_ops.scan(combine_fn, init, xs)


def _hop_opinfo(
    name,
    variant_test_name,
    op,
    sample_inputs_func,
    *,
    dtypes=None,
    **extra_kwargs,
):
    """Build one ``OpInfo`` entry for ``hop_db``.

    Every entry in ``hop_db`` sets ``supports_out=False`` and disables all four
    batched-gradient checks, so those keywords are filled in here.
    ``dtypes`` defaults to ``all_types_and(torch.bool, torch.half)``. Any
    other ``OpInfo`` keyword (``supports_autograd``, ``supports_gradgrad``,
    ``skips``, ``decorators``, ...) is forwarded unchanged via
    ``extra_kwargs``.
    """
    if dtypes is None:
        dtypes = all_types_and(torch.bool, torch.half)
    return OpInfo(
        name=name,
        variant_test_name=variant_test_name,
        op=op,
        sample_inputs_func=sample_inputs_func,
        dtypes=dtypes,
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        **extra_kwargs,
    )


def _flex_attention_expected_failures():
    """Return fresh ``DecorateInfo`` objects for the flex_attention entries.

    They mark these ``TestHOP`` export tests as ``unittest.expectedFailure``.
    A new tuple is built on each call so the two entries do not share
    objects.
    """
    return tuple(
        DecorateInfo(unittest.expectedFailure, "TestHOP", test_name)
        for test_name in (
            "test_aot_export",
            "test_pre_dispatch_export",
            "test_serialize_export",
            "test_retrace_export",
        )
    )


hop_db = [
    _hop_opinfo(
        "scan",
        "simple",
        simple_scan,
        sample_inputs_scan,
        supports_autograd=False,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    _hop_opinfo(
        "invoke_subgraph",
        "simple",
        simple_invoke_subgraph,
        sample_inputs_invoke_subgraph,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    _hop_opinfo("map", "simple", simple_map, sample_inputs_map),
    _hop_opinfo("map", "nested", nested_map, sample_inputs_map),
    _hop_opinfo("map", "triple_nested", triple_nested_map, sample_inputs_map),
    _hop_opinfo(
        "cond",
        "simple",
        simple_cond,
        sample_inputs_cond,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    _hop_opinfo(
        "while_loop",
        "simple",
        simple_while_loop,
        sample_inputs_while_loop,
        supports_autograd=False,
    ),
    _hop_opinfo(
        "auto_functionalize",
        "simple",
        simple_auto_functionalize,
        sample_inputs_auto_functionalize,
        supports_autograd=False,
    ),
    _hop_opinfo(
        "flex_attention",
        "simple",
        flex_attention,
        sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        skips=_flex_attention_expected_failures(),
        decorators=[onlyCUDA],
    ),
    # Uses the same forward op as "flex_attention". Presumably the backward
    # HOP is exercised through autograd on that op — confirm in TestHOP.
    _hop_opinfo(
        "flex_attention_backward",
        "simple",
        flex_attention,
        sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        skips=_flex_attention_expected_failures(),
        decorators=[onlyCUDA],
    ),
]