File: test_pytorch_helper.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (71 lines) | stat: -rw-r--r-- 2,691 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Owner(s): ["module: onnx"]

# Some standard imports
import unittest

import numpy as np

import torch.nn.init as init
import torch.onnx
from caffe2.python.core import workspace
from caffe2.python.model_helper import ModelHelper
from pytorch_helper import PyTorchModule
from torch import nn
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


class TestCaffe2Backend(common_utils.TestCase):
    @skipIfNoLapack
    @unittest.skip("test broken because Lapack was always missing.")
    def test_helper(self):
        class SuperResolutionNet(nn.Module):
            def __init__(self, upscale_factor, inplace=False):
                super().__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

            def _initialize_weights(self):
                init.orthogonal(self.conv1.weight, init.calculate_gain("relu"))
                init.orthogonal(self.conv2.weight, init.calculate_gain("relu"))
                init.orthogonal(self.conv3.weight, init.calculate_gain("relu"))
                init.orthogonal(self.conv4.weight)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # use ModelHelper to create a C2 net
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(["the_input"])
        # Embed the ONNX-converted pytorch net inside it
        (toutput,) = PyTorchModule(helper, torch_model, (fake_input,), [start])
        output = helper.Sigmoid(toutput)

        workspace.RunNetOnce(helper.InitProto())
        workspace.FeedBlob("the_input", fake_input.data.numpy())
        # print([ k for k in workspace.blobs ])
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)


if __name__ == "__main__":
    common_utils.run_tests()