File: test_operators.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (27 lines) | stat: -rw-r--r-- 816 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# mypy: allow-untyped-defs
import torch.library
from torch import Tensor
from torch.autograd import Function


if not torch._running_with_deploy():
    _test_lib_def = torch.library.Library("_inductor_test", "DEF")
    _test_lib_def.define(
        "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag
    )

    _test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
    for dispatch_key in ("CPU", "CUDA", "Meta"):
        _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)

    class Realize(Function):
        @staticmethod
        def forward(ctx, x):
            return torch.ops._inductor_test.realize(x)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    def realize(x: Tensor) -> Tensor:
        return Realize.apply(x)