File: control_ops_grad_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (49 lines) | stat: -rw-r--r-- 1,752 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49





import unittest
from caffe2.python import core, test_util, workspace
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
from caffe2.python.model_helper import ModelHelper
import numpy as np


class TestControl(test_util.TestCase):
    """Test case for helpers from ``caffe2.python.control_ops_grad``."""

    def test_disambiguate_grad_if_op_output(self):
        """Check output renaming on an ``If`` gradient operator.

        Builds an ``If`` operator whose then/else nets both write
        ``input_grad``, asks ``disambiguate_grad_if_op_output`` to rename
        output 0, and verifies the new name on the operator itself and on
        the second op stored in its ``else_net`` argument. Any other
        argument of the operator is only required to be named ``then_net``.
        """
        for blob_name, blob_value in (
            ("cond", np.array(True)),
            ("then_grad", np.array(1)),
            ("else_grad", np.array(2)),
        ):
            workspace.FeedBlob(blob_name, blob_value)

        then_branch = ModelHelper(name="then_test_model")
        then_branch.net.Copy("then_grad", "input_grad")

        else_branch = ModelHelper(name="else_test_model")
        else_branch.net.Copy("else_grad", "else_temp_grad")
        # NOTE(review): "else_temp" is not fed anywhere in this test; the
        # nets are only inspected as protos in this method, not run here.
        else_branch.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators, the forward pass needs the else temp
        # blob as one of the outputs, which later yields a gradient operator
        # shaped like the one assembled below.
        if_inputs = ["cond", "then_grad", "else_grad"]
        if_outputs = ["input_grad", "else_temp_grad"]
        if_grad_op = core.CreateOperator(
            "If",
            if_inputs,
            if_outputs,
            then_net=then_branch.net.Proto(),
            else_net=else_branch.net.Proto(),
        )

        # In certain cases another branch of the net also generates
        # input_grad, and _DisambiguateGradOpOutput in core.py gets called.
        renamed_output = "".join(("input_grad", "_autosplit_", "0"))
        disambiguate_grad_if_op_output(if_grad_op, 0, renamed_output)

        self.assertEqual(if_grad_op.output[0], renamed_output)
        for net_arg in if_grad_op.arg:
            if net_arg.name != "else_net":
                # The only other argument expected on the operator.
                self.assertEqual(net_arg.name, "then_net")
                continue
            # Second op of the else net is the Copy writing input_grad.
            self.assertEqual(net_arg.n.op[1].output[0], renamed_output)


# Hand control to unittest's runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()