File: counter_ops_test.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (84 lines) | stat: -rw-r--r-- 3,348 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84





from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
import tempfile


class TestCounterOps(TestCase):
    """Tests for the caffe2 counter operators (CreateCounter, CountDown, ...)."""

    def test_counter_ops(self):
        """Drive counters through their operators and check every output.

        Covers CountDown / CountUp / RetrieveCount / ResetCounter on one
        counter. It checks CountDown's boolean outputs against ConstantFill
        BOOL tensors and the ``And`` operator. It also verifies that a counter
        saved to a minidb file and later loaded back has its saved value
        again.

        All checks use ``self.assert*`` rather than bare ``assert``. Bare
        asserts are stripped under ``python -O``, which would make this test
        pass vacuously and would even skip the operator calls that used to be
        wrapped inside an ``assert``.
        """

        def run_op(op_type, inputs, outputs, **kwargs):
            # Build and run one operator; returns RunOperatorOnce's result.
            return workspace.RunOperatorOnce(
                core.CreateOperator(op_type, inputs, outputs, **kwargs))

        run_op('CreateCounter', [], ['c'], init_count=1)

        # CountDown outputs False when counting down from 1 and True once
        # the counter was already at 0 (see the assertions below).
        run_op('CountDown', ['c'], ['t1'])  # 1 -> 0
        self.assertFalse(workspace.FetchBlob('t1'))

        run_op('CountDown', ['c'], ['t2'])  # 0 -> -1
        self.assertTrue(workspace.FetchBlob('t2'))

        # CountUp's output is the count before the increment.
        run_op('CountUp', ['c'], ['t21'])  # -1 -> 0
        self.assertEqual(workspace.FetchBlob('t21'), -1)
        run_op('RetrieveCount', ['c'], ['t22'])
        self.assertEqual(workspace.FetchBlob('t22'), 0)

        run_op('ResetCounter', ['c'], [], init_count=1)  # -> 1
        run_op('CountDown', ['c'], ['t3'])  # 1 -> 0
        self.assertFalse(workspace.FetchBlob('t3'))

        # ResetCounter's optional output is the count before the reset.
        run_op('ResetCounter', ['c'], ['t31'], init_count=5)  # 0 -> 5
        self.assertEqual(workspace.FetchBlob('t31'), 0)
        run_op('ResetCounter', ['c'], ['t32'])  # 5 -> 0
        self.assertEqual(workspace.FetchBlob('t32'), 5)

        # CountDown's outputs compare equal to scalar BOOL ConstantFills.
        run_op('ConstantFill', [], ['t4'], value=False, shape=[],
               dtype=core.DataType.BOOL)
        self.assertEqual(workspace.FetchBlob('t4'), workspace.FetchBlob('t1'))

        run_op('ConstantFill', [], ['t5'], value=True, shape=[],
               dtype=core.DataType.BOOL)
        self.assertEqual(workspace.FetchBlob('t5'), workspace.FetchBlob('t2'))

        # The op call sits inside a method call (not an `assert` statement),
        # so it still executes when Python runs with -O.
        self.assertTrue(run_op('And', ['t1', 't2'], ['t6']))
        self.assertFalse(workspace.FetchBlob('t6'))  # False && True

        self.assertTrue(run_op('And', ['t2', 't5'], ['t7']))
        self.assertTrue(workspace.FetchBlob('t7'))  # True && True

        run_op('CreateCounter', [], ['serialized_c'], init_count=22)
        with tempfile.NamedTemporaryFile() as tmp:
            run_op('Save', ['serialized_c'], [], absolute_path=1,
                   db_type='minidb', db=tmp.name)
            # Mutate the live counter after saving: 22 -> 12.
            for _ in range(10):
                run_op('CountDown', ['serialized_c'], ['t8'])
            run_op('RetrieveCount', ['serialized_c'], ['t8'])
            self.assertEqual(workspace.FetchBlob('t8'), 12)
            # Loading the saved blob brings back the value at save time.
            run_op('Load', [], ['serialized_c'], absolute_path=1,
                   db_type='minidb', db=tmp.name)
            run_op('RetrieveCount', ['serialized_c'], ['t8'])
            self.assertEqual(workspace.FetchBlob('t8'), 22)

if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs the TestCase classes in this module.
    from unittest import main as run_tests
    run_tests()