File: test_reuse_ir.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (129 lines) | stat: -rw-r--r-- 4,370 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Owner(s): ["oncall: jit"]

import torch
import torch._lazy
import torch._lazy.config
import torch._lazy.ir_cache
import torch._lazy.ts_backend
import torch._lazy.metrics as metrics
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
import os
import unittest

# Module-level setup: both calls run once at import time, before any test in
# this file executes.
# NOTE(review): presumably this initializes the TorchScript lazy backend and
# turns on IR-node reuse, which the "IrNodeReused_*" counters checked by the
# tests rely on -- confirm against the torch._lazy documentation.
torch._lazy.ts_backend.init()
torch._lazy.config.set_reuse_ir(True)

def get_test_device():
    """Return the device name used for the eager reference computation.

    The result is 'cuda' whenever the ``LTC_TS_CUDA`` environment variable is
    present (its value is not inspected, so an empty string also counts) and
    'cpu' otherwise.
    """
    if 'LTC_TS_CUDA' in os.environ:
        return 'cuda'
    return 'cpu'

@unittest.skipIf(IS_WINDOWS, "To be fixed")
class TestLazyReuseIr(TestCase):
    """Compare eager and 'lazy'-device results and check IR-node reuse counters.

    Every test runs the same computation on the eager device returned by
    ``get_test_device()`` and on the 'lazy' device, asserts that both results
    are close, and then asserts a lower bound on an ``IrNodeReused_*`` metrics
    counter.
    """

    def tearDown(self):
        """Reset the lazy metrics and the IR cache after every test."""
        # The resets used to be the last statements of each test, so a failing
        # assertion skipped them and leaked counters / cached IR into whichever
        # test ran next. tearDown runs whether the test passed or failed.
        metrics.reset()
        torch._lazy.ir_cache.reset()
        super().tearDown()

    @staticmethod
    def _to_lazy(tensor):
        """Return a detached clone of ``tensor`` moved to the 'lazy' device."""
        return tensor.detach().clone().to(device='lazy')

    def _make_add_inputs(self):
        """Build the (x, y, z) operands on the eager device plus lazy copies.

        Returns:
            A pair ``((x, y, z), (x_lazy, y_lazy, z_lazy))`` where ``x`` and
            ``y`` are random 2x3x4 tensors, ``z`` is a zero 2x3x4 accumulator,
            and the second tuple holds their clones on the 'lazy' device.
        """
        device = get_test_device()
        x = torch.randn(2, 3, 4, device=device)
        y = torch.randn(2, 3, 4, device=device)
        z = torch.zeros(2, 3, 4, device=device)
        eager = (x, y, z)
        lazy = tuple(self._to_lazy(t) for t in eager)
        return eager, lazy

    def _check_add_then_sub(self):
        """Accumulate ``x + y`` five times, then ``x - y`` five times.

        Runs the sequence on the eager device and on the lazy device (calling
        ``torch._lazy.mark_step()`` after each lazy iteration), compares the
        accumulators, and requires the AddTensor reuse counter to be >= 8.
        """
        (x, y, z), (x_lazy, y_lazy, z_lazy) = self._make_add_inputs()

        for step in range(10):
            if step < 5:
                z += (x + y)
            else:
                z += (x - y)

        for step in range(10):
            if step < 5:
                z_lazy += (x_lazy + y_lazy)
            else:
                z_lazy += (x_lazy - y_lazy)
            torch._lazy.mark_step()

        torch.testing.assert_close(z.cpu(), z_lazy.cpu())
        # unittest assertion instead of a bare `assert`: it is not stripped
        # under `python -O` and reports the actual counter value on failure.
        self.assertGreaterEqual(
            metrics.counter_value("IrNodeReused_torch::lazy::AddTensor"), 8)

    def testAdd(self):
        """Accumulate ``x + y`` ten times on both devices and compare."""
        (x, y, z), (x_lazy, y_lazy, z_lazy) = self._make_add_inputs()

        for _ in range(10):
            z += (x + y)

        for _ in range(10):
            z_lazy += (x_lazy + y_lazy)
            torch._lazy.mark_step()

        torch.testing.assert_close(z.cpu(), z_lazy.cpu())
        self.assertGreaterEqual(
            metrics.counter_value("IrNodeReused_torch::lazy::AddTensor"), 14)

    def testAddSub(self):
        """Run the add-then-sub sequence with the default lazy configuration."""
        self._check_add_then_sub()

    def testAddSubFallback(self):
        """Run the add-then-sub sequence with ``aten::sub`` forced to fall back."""
        # Register the restore before changing the setting so the forced
        # fallback is cleared even if the test body raises; cleanups run after
        # tearDown, which keeps the original reset order.
        self.addCleanup(torch._lazy.config.set_force_fallback, "")
        torch._lazy.config.set_force_fallback("aten::sub")
        self._check_add_then_sub()

    def testBatchNorm(self):
        """Run training-mode ``native_batch_norm`` ten times on both devices."""
        device = get_test_device()
        x = torch.randn(16, 3, 224, 224, device=device)
        weight = torch.randn(3, device=device)
        bias = torch.randn(3, device=device)

        for _step in range(10):
            # BatchNorm2d does extra checks on dimensions which SymInts don't support yet
            # so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
            z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5)

        x_lazy = self._to_lazy(x)
        weight_lazy = self._to_lazy(weight)
        bias_lazy = self._to_lazy(bias)
        for _step in range(10):
            z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
            torch._lazy.mark_step()

        torch.testing.assert_close(z.cpu(), z_lazy.cpu())
        self.assertGreaterEqual(
            metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm"), 7)

if __name__ == '__main__':
    run_tests()