File: test_python_autograd.py

Package: pytorch-cuda 2.6.0+dfsg-7
# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


"""
This is an example of a pure-Python version of autograd implemented by
@zdevito. It is a rather challenging test case that pushes the limits of
what TorchDynamo can do.
"""


_name: int = 0


def fresh_name() -> str:
    """create a new unique name for a variable: v0, v1, v2"""
    global _name
    r = f"v{_name}"
    _name += 1
    return r


class Variable:
    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    # This performs a pointwise multiplication of two Variables, tracking gradients.
    def __mul__(self, rhs: "Variable") -> "Variable":
        # defined later in this file
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        return operator_expand(self, sizes)


class TapeEntry(NamedTuple):
    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule
    propagate: "Callable[[List[Variable]], List[Variable]]"


gradient_tape: List[TapeEntry] = []
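# Illustrative note (editorial, not in the original file): after r = a * b the
# tape holds a single entry, roughly
#   TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=<closure>),
# where the closure maps [dL/dr] to [dL/dr * b, dL/dr * a].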


def reset_tape():
    gradient_tape.clear()
    global _name
    _name = 0


def grad(L, desired_results: List[Variable]) -> List[Variable]:
    # this map holds dL/dX for all values X
    dL_d: Dict[str, Variable] = {}
    # The backward pass starts by initializing the 'seed' gradient dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # look up dL_d entries. If a variable was never used to compute the loss,
    # its gradient is None rather than a tensor of zeros (see the zero-pathway
    # optimization below).
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # optimize for the case where every gradient pathway into this
            # entry is zero (represented as None): nothing to propagate.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input not in dL_d:
                dL_d[input] = dL_dinput
            else:
                dL_d[input].value += dL_dinput.value

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #    print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)
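
# Worked example (editorial note, not in the original file): for
# loss = (a * a).sum(), the tape holds one multiply entry and one sum entry.
# Walking it in reverse, the sum entry expands dL/dloss back to a's shape, and
# the multiply entry returns two contributions [dL/dr * a, dL/dr * a]; the
# accumulation loop in grad() sums them, giving dL/da = 2 * a as expected.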


def operator_mul(self: Variable, rhs: Variable) -> Variable:
    if isinstance(rhs, float) and rhs == 1.0:
        # peephole optimization
        return self

    # define forward
    r = Variable(self.value * rhs.value)
    # print(f'{r.name} = {self.name} * {rhs.name}')

    # record what the inputs and outputs of the op were
    inputs = [self.name, rhs.name]
    outputs = [r.name]

    # define backprop
    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs

        dr_dself = rhs  # partial derivative of r = self*rhs
        dr_drhs = self  # partial derivative of r = self*rhs

        # chain rule propagation from outputs to inputs of multiply
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        dL_dinputs = [dL_dself, dL_drhs]
        return dL_dinputs

    # finally, we record the compute we did on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return r


def operator_add(self: Variable, rhs: Variable) -> Variable:
    # Add follows a similar pattern to Mul, but it doesn't end up
    # capturing any variables.
    r = Variable(self.value + rhs.value)
    # print(f'{r.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        dr_dself = 1.0
        dr_drhs = 1.0
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        return [dL_dself, dL_drhs]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r


def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    r = Variable(torch.sum(self.value), name=name)
    # print(f'{r.name} = {self.name}.sum()')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        size = self.value.size()
        # expand the scalar gradient back to the input's shape
        return [dL_dr.expand(size)]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r


def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    assert self.value.dim() == 0  # only works for scalars
    r = Variable(self.value.expand(sizes))
    # print(f'{r.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        return [dL_dr.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r
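
# Note (editorial): operator_sum and operator_expand are each other's backward:
# summing backpropagates by expanding the scalar output gradient to the input's
# shape, and expanding a scalar backpropagates by summing the output gradient.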


def simple(a, b):
    t = a + b
    return t * b
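

# Illustrative usage sketch (editorial addition, not exercised by the tests
# below); the helper name _example_usage is ours, not part of the original
# file. It shows how the pieces above fit together: a forward pass records
# TapeEntry objects, and grad() replays the tape in reverse.
def _example_usage():
    reset_tape()
    a = Variable.constant(torch.randn(4), name="a")
    b = Variable.constant(torch.randn(4), name="b")
    # forward: simple(a, b) = (a + b) * b, recorded on the gradient tape
    loss = simple(a, b).sum()
    # reverse-mode pass over the tape: d(loss)/da = b, d(loss)/db = a + 2 * b
    da, db = grad(loss, [a, b])
    return da, db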


class TestPythonAutograd(TestCase):
    def _common(self, fn, expected_ops):
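        # Compile fn with TorchDynamo (optimize_assert fails if the frame cannot
        # be captured as a single graph), run it on two sets of inputs, check the
        # results against eager execution, and verify that exactly one frame was
        # compiled with the expected number of ops.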
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)


if __name__ == "__main__":
    run_tests()