File: seq2seq_model_helper_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (70 lines) | stat: -rw-r--r-- 1,838 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70





from caffe2.python.models.seq2seq import seq2seq_model_helper
from caffe2.python import scope, test_util


class Seq2SeqModelHelperTest(test_util.TestCase):
    """Unit tests for ``seq2seq_model_helper.Seq2SeqModelHelper``.

    Exercises the constructor attributes and the parameter bookkeeping
    methods ``AddParam``, ``GetNonTrainableParams`` and ``GetAllParams``.
    """

    def testConstuctor(self):
        """A named helper reports its name, init_params and arg_scope."""
        expected_name = 'TestModel'
        expected_arg_scope = {
            'use_cudnn': True,
            'cudnn_exhaustive_search': False,
            'order': 'NHWC'
        }

        helper = seq2seq_model_helper.Seq2SeqModelHelper(name=expected_name)

        self.assertEqual(helper.name, expected_name)
        self.assertEqual(helper.init_params, True)
        self.assertEqual(helper.arg_scope, expected_arg_scope)

    def testAddParam(self):
        """The value returned by AddParam stringifies to the given name."""
        helper = seq2seq_model_helper.Seq2SeqModelHelper()
        requested_name = 'test_param'

        added = helper.AddParam(requested_name, init_value=1)

        self.assertEqual(str(added), requested_name)

    def testGetNonTrainableParams(self):
        """GetNonTrainableParams is checked outside and inside name scope 'A'.

        Inside the scope only the parameter added there is expected; once
        the scope is left, both non-trainable parameters are expected.
        """
        helper = seq2seq_model_helper.Seq2SeqModelHelper()

        # The trainable parameter must never show up in the results below.
        helper.AddParam('test_param1', init_value=1, trainable=True)
        outer_frozen = helper.AddParam(
            'test_param2', init_value=2, trainable=False)

        self.assertEqual(helper.GetNonTrainableParams(), [outer_frozen])

        with scope.NameScope('A', reset=True):
            scoped_frozen = helper.AddParam(
                'test_param3', init_value=3, trainable=False)
            self.assertEqual(helper.GetNonTrainableParams(), [scoped_frozen])

        self.assertEqual(
            helper.GetNonTrainableParams(),
            [outer_frozen, scoped_frozen]
        )

    def testGetAllParams(self):
        """GetAllParams lists trainable and non-trainable params in order."""
        helper = seq2seq_model_helper.Seq2SeqModelHelper()

        trainable = helper.AddParam(
            'test_param1', init_value=1, trainable=True)
        frozen = helper.AddParam(
            'test_param2', init_value=2, trainable=False)

        self.assertEqual(helper.GetAllParams(), [trainable, frozen])


if __name__ == "__main__":
    # Script entry point: seed the stdlib RNG with a fixed value, then hand
    # control to the unittest runner.
    import random
    import unittest

    random.seed(2221)
    unittest.main()