File: caffe_translator_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (90 lines) | stat: -rw-r--r-- 3,553 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# This is a large test that goes through the translation of the BVLC CaffeNet
# model, runs an example through the whole model, and verifies numerically
# that all the results look right. By default it is disabled; it only runs
# when you explicitly ask for it.

from google.protobuf import text_format
import numpy as np
import os
import sys

CAFFE_FOUND = False
try:
    from caffe.proto import caffe_pb2
    from caffe2.python import caffe_translator
    CAFFE_FOUND = True
except Exception as e:
    # Caffe is a separately-installed, optional dependency: any failure to
    # import it (or the translator) only downgrades to skipping the tests
    # below, it never fails module import. A missing ``caffe`` module is the
    # expected case and gets a friendly hint; any other error is reported
    # verbatim so a genuinely broken install is not silently hidden.
    if "'caffe'" in str(e):
        print(
            "PyTorch/Caffe2 now requires a separate installation of caffe. "
            "Right now, this is not found, so we will skip the caffe "
            "translator test.")
    else:
        print(
            "Failed to import caffe / caffe_translator ({}: {}); skipping "
            "the caffe translator test.".format(type(e).__name__, e))

from caffe2.python import utils, workspace, test_util
import unittest

def setUpModule():
    """Translate the reference CaffeNet and run it once in the global workspace.

    Does nothing unless the caffe imports succeeded and the
    ``data/testdata/caffe_translator`` directory exists (relative to the
    current working directory). Otherwise it parses the deploy prototxt and
    the pretrained ``.caffemodel``, translates them with
    ``caffe_translator.TranslateModel`` (once with ``remove_legacy_pad=True``
    and once with ``False``), feeds the pretrained parameters and the input
    dump ``data_dump.npy`` (as blob ``data``), and runs the translated net.
    The resulting blobs stay in the global workspace for the test class.
    """
    test_dir = 'data/testdata/caffe_translator'
    # Do nothing if caffe and test data is not found
    if not (CAFFE_FOUND and os.path.exists(test_dir)):
        return
    # We will do all the computation stuff in the global space.
    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    with open(os.path.join(test_dir, 'deploy.prototxt')) as f:
        text_format.Merge(f.read(), caffenet)
    # The .caffemodel is a binary serialized protobuf: it must be read in
    # binary mode so ParseFromString receives bytes. Text mode would try to
    # decode it and hand back str.
    with open(os.path.join(test_dir, 'bvlc_reference_caffenet.caffemodel'),
              'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    # NOTE: both variants run against the same global workspace, so any blob
    # written by both runs ends up holding the remove_legacy_pad=False result;
    # the translated-model file is likewise overwritten by the last iteration.
    for remove_legacy_pad in [True, False]:
        net, pretrained_params = caffe_translator.TranslateModel(
            caffenet, caffenet_pretrained, is_test=True,
            remove_legacy_pad=remove_legacy_pad
        )
        with open(os.path.join(test_dir,
                               'bvlc_reference_caffenet.translatedmodel'),
                  'w') as fid:
            fid.write(str(net))
        for param in pretrained_params.protos:
            workspace.FeedBlob(param.name, utils.Caffe2TensorToNumpyArray(param))
        # Let's also feed in the data from the Caffe test code.
        data = np.load(os.path.join(test_dir, 'data_dump.npy')).astype(
            np.float32)
        workspace.FeedBlob('data', data)
        # Actually running the test.
        workspace.RunNetOnce(net.SerializeToString())


@unittest.skipIf(not CAFFE_FOUND,
                 'No Caffe installation found.')
@unittest.skipIf(not os.path.exists('data/testdata/caffe_translator'),
                 'No testdata existing for the caffe translator test. Exiting.')
class TestNumericalEquivalence(test_util.TestCase):
    """Compare blobs produced by the translated Caffe2 net to Caffe dumps.

    Relies on ``setUpModule`` having run the translated network, so the blobs
    checked here are read from the global workspace.
    """

    def testBlobs(self):
        """Check each blob's shape and scaled values against its ``*_dump.npy``.

        Both arrays are divided by the largest magnitude in the Caffe2 result
        before comparing to 5 decimal places, so the tolerance is relative to
        each blob's dynamic range.
        """
        names = [
            "conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3",
            "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob"
        ]
        for name in names:
            print('Verifying {}'.format(name))
            caffe2_result = workspace.FetchBlob(name)
            reference = np.load(
                'data/testdata/caffe_translator/' + name + '_dump.npy'
            )
            self.assertEqual(caffe2_result.shape, reference.shape,
                             'Shape mismatch for blob {}'.format(name))
            # Normalize by the largest *magnitude*. A plain np.max can be zero
            # (division yields NaN/inf, and matching NaNs compare equal in
            # assert_almost_equal) or far smaller than the largest magnitude
            # for a mostly-negative blob, distorting the tolerance. Fall back
            # to an absolute comparison for an all-zero blob.
            scale = np.max(np.abs(caffe2_result))
            if scale == 0:
                scale = 1.0
            np.testing.assert_almost_equal(
                caffe2_result / scale,
                reference / scale,
                decimal=5,
                err_msg='Value mismatch for blob {}'.format(name),
            )


if __name__ == '__main__':
    # The test is opt-in: with no command-line arguments, explain how to run
    # it and exit successfully without running anything.
    if len(sys.argv) == 1:
        print(
            'If you do not explicitly ask to run this test, I will not run it. '
            'Pass in any unittest argument (e.g. -v) to have the test run '
            'for you.'
        )
        sys.exit(0)
    # unittest.main() parses sys.argv itself, so the opt-in arguments must be
    # ones it understands (options such as -v, or test names); an arbitrary
    # word would be looked up as a test name and fail.
    unittest.main()