File: net_printer_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (99 lines) | stat: -rw-r--r-- 3,190 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99





from caffe2.python import net_printer
from caffe2.python.checkpoint import Job
from caffe2.python.net_builder import ops
from caffe2.python.task import Task, final_output, WorkspaceType
import unittest


def example_loop():
    """Declare a Task holding a two-level nested loop with If/Elif/Else.

    The outer loop is declared with ``ops.loop(10)``; the inner loop is
    declared with ``ops.loop(loop.iter())``, i.e. its argument is the outer
    loop's iteration blob. For each inner step ``val`` is built as
    ``outer_iter * 10 + inner_iter`` and accumulated into one of three
    buckets, chosen by comparing ``val`` against 80 and 50, and then into
    ``total``.

    Nothing is returned. The ops are declared inside a ``Task()`` context;
    presumably that task attaches itself to whichever job/group context is
    active at call time (``example_job`` calls this inside
    ``job.init_group``) -- TODO confirm against ``caffe2.python.task``.
    """
    with Task():
        # Four accumulators, each created from Const(0).
        total = ops.Const(0)
        total_large = ops.Const(0)
        total_small = ops.Const(0)
        total_tiny = ops.Const(0)
        with ops.loop(10) as loop:
            # outer = <outer iteration> * 10
            outer = ops.Mul([loop.iter(), ops.Const(10)])
            with ops.loop(loop.iter()) as inner:
                # val = outer + <inner iteration>
                val = ops.Add([outer, inner.iter()])
                # Bucket 1: val >= 80.
                with ops.If(ops.GE([val, ops.Const(80)])) as c:
                    ops.Add([total_large, val], [total_large])
                # Bucket 2: val >= 50 (reached only when the If above did
                # not match). Note that `c` is rebound to the Elif's context
                # object, which is what `.Else()` is then called on.
                with c.Elif(ops.GE([val, ops.Const(50)])) as c:
                    ops.Add([total_small, val], [total_small])
                # Bucket 3: everything else.
                with c.Else():
                    ops.Add([total_tiny, val], [total_tiny])
                # Unconditional accumulation. The second argument is passed
                # as a bare blob here rather than a one-element list as in
                # the bucket updates above; presumably both forms name the
                # output blob -- TODO confirm.
                ops.Add([total, val], total)


def example_task():
    """Declare two example Tasks and return the first task's final outputs.

    The first ``Task()`` interleaves ``ops.task_init()`` and
    ``ops.task_exit()`` blocks with ordinary ops and registers three results
    through ``final_output``. The second task is created with
    ``num_instances=2`` and uses both ``ops.task_init()`` and
    ``ops.task_instance_init()``.

    Going by the inline comments and the variable names, the three outputs
    are expected to evaluate to 6, 7 and 7 once the task is actually run;
    this function itself only declares the ops.

    Returns:
        The tuple ``(o6, o7_1, o7_2)`` of objects returned by
        ``final_output`` for ``six``, ``seven_1`` and ``seven_2``.
    """
    with Task():
        with ops.task_init():
            one = ops.Const(1)
        two = ops.Add([one, one])
        # A second task_init block, declared after a regular op.
        with ops.task_init():
            three = ops.Const(3)
        accum = ops.Add([two, three])
        # here, accum should be 5
        with ops.task_exit():
            # here, accum should be 6, since this executes after lines below
            seven_1 = ops.Add([accum, one])
        six = ops.Add([accum, one])
        # Writes the sum back into `accum` (output list is `[accum]`).
        ops.Add([accum, one], [accum])
        seven_2 = ops.Add([accum, one])
        o6 = final_output(six)
        o7_1 = final_output(seven_1)
        o7_2 = final_output(seven_2)

    # Second task: `one` is rebound to a new Const here; the outputs
    # captured above are not affected by this rebinding.
    with Task(num_instances=2):
        with ops.task_init():
            one = ops.Const(1)
        with ops.task_instance_init():
            local = ops.Const(2)
        ops.Add([one, local], [one])
        ops.LogInfo('ble')

    return o6, o7_1, o7_2

def example_job():
    """Assemble the sample job shared by the tests in this file.

    ``example_loop`` is invoked inside the job's ``init_group`` context and
    ``example_task`` is invoked directly inside the job context.

    Returns:
        The object bound by ``with Job() as ...``, after the context has
        been exited.
    """
    with Job() as sample_job:
        init_group = sample_job.init_group
        with init_group:
            example_loop()
        example_task()
    return sample_job


class TestNetPrinter(unittest.TestCase):
    """Tests for ``net_printer.to_string`` and ``net_printer.analyze``.

    Each test starts from a fresh ``example_job()`` and, where needed,
    re-enters the job context to append extra tasks before rendering or
    analysing it.
    """

    def test_print(self):
        """Rendering the example job yields a non-empty string."""
        rendered = net_printer.to_string(example_job())
        self.assertGreater(len(rendered), 0)

    def test_valid_job(self):
        """A job extended with ``distributed_ctx_init_*`` inputs renders."""
        job = example_job()
        with job:
            with Task():
                # distributed_ctx_init_* ignored by analyzer
                ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
        # NOTE(review): the net_printer.analyze(...) check for this case was
        # already disabled; confirm it passes on this job before enabling it.
        # Render the job that was extended above. Previously a brand-new
        # example_job() was printed instead, leaving `job` unused and the
        # test without any assertion.
        self.assertGreater(len(net_printer.to_string(job)), 0)

    def test_undefined_blob(self):
        """``analyze`` rejects an op whose inputs were never defined."""
        job = example_job()
        with job:
            with Task():
                ops.Add(['a', 'b'])
        with self.assertRaises(AssertionError) as ctx:
            net_printer.analyze(job)
        self.assertEqual("Blob undefined: a", str(ctx.exception))

    def test_multiple_definition(self):
        """``analyze`` rejects two GLOBAL-workspace tasks writing 'out1'."""
        job = example_job()
        with job:
            with Task(workspace_type=WorkspaceType.GLOBAL):
                ops.Add([ops.Const(0), ops.Const(1)], 'out1')
            with Task(workspace_type=WorkspaceType.GLOBAL):
                ops.Add([ops.Const(2), ops.Const(3)], 'out1')
        with self.assertRaises(AssertionError):
            net_printer.analyze(job)