File: test_codegen_unboxing.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (219 lines) | stat: -rw-r--r-- 7,287 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/torch.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
// Cover codegen'd unboxing logic for these types:
//'Device',
//'Device?',
//'Dimname',
//'Dimname[1]',
//'Dimname[]',
//'Dimname[]?',
//'Generator?',
//'Layout?',
//'MemoryFormat',
//'MemoryFormat?',
//'Scalar',
//'Scalar?',
//'ScalarType',
//'ScalarType?',
//'Scalar[]',
//'Storage',
//'Stream',
//'Tensor',
//'Tensor(a!)',
//'Tensor(a!)[]',
//'Tensor(a)',
//'Tensor(b!)',
//'Tensor(c!)',
//'Tensor(d!)',
//'Tensor?',
//'Tensor?[]',
//'Tensor[]',
//'bool',
//'bool?',
//'bool[2]',
//'bool[3]',
//'bool[4]',
//'float',
//'float?',
//'float[]?',
//'int',
//'int?',
//'int[1]',
//'int[1]?',
//'int[2]',
//'int[2]?',
//'int[3]',
//'int[4]',
//'int[5]',
//'int[6]',
//'int[]',
//'int[]?',
//'str',
//'str?'
namespace torch {
namespace jit {
namespace mobile {

// Each test below loads a pre-exported lite-interpreter model (.ptl) by a
// relative path (i.e. resolved against the current working directory), runs
// `forward`, and checks the output. The Python definition of each model is
// sketched in the test's comment for reference.
//
// Checks use gtest macros (ASSERT_TRUE / ASSERT_EQ) rather than AT_ASSERT so a
// mismatch is reported as an ordinary test failure naming the failing
// expression, instead of as an internal-assert exception.

// covers int[], ScalarType?, Layout?, Device?, bool?
TEST(LiteInterpreterTest, Ones) {
  // Load check in model: ModelWithDTypeDeviceLayoutPinMemory.ptl
  auto testModelFile = "ModelWithDTypeDeviceLayoutPinMemory.ptl";

  //  class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
  //    def forward(self, x: int):
  //        a = torch.ones([3, x], dtype=torch.int64, layout=torch.strided, device="cpu")
  //        return a
  Module bc = _load_for_mobile(testModelFile);
  std::vector<c10::IValue> input{c10::IValue(4)};
  const auto result = bc.forward(input);
  const auto& output = result.toTensor();
  ASSERT_EQ(output.size(0), 3);
  ASSERT_EQ(output.size(1), 4);
  // The model passes dtype=torch.int64; verify it survived the ScalarType?
  // unboxing path instead of only checking the shape.
  ASSERT_EQ(output.scalar_type(), c10::ScalarType::Long);
}

TEST(LiteInterpreterTest, Index) {
  // Load check in model: ModelWithTensorOptional.ptl
  auto testModelFile = "ModelWithTensorOptional.ptl";

  //    class ModelWithTensorOptional(torch.nn.Module):
  //      def forward(self, index):
  //        a = torch.zeros(2, 2)
  //        a[0][1] = 1
  //        a[1][0] = 2
  //        a[1][1] = 3
  //        return a[index]
  Module bc = _load_for_mobile(testModelFile);
  const int64_t ind_1 = 0;

  const auto result_1 = bc.forward({at::tensor(ind_1)});

  // Row 0 of `a` is [0, 1]. The expected shape is {1, 2}: at::tensor(int64_t)
  // produces a one-element 1-D index tensor, so the indexed dim is kept.
  const at::Tensor expected =
      at::tensor({0., 1.}, c10::TensorOptions(c10::ScalarType::Float))
          .reshape({1, 2});

  ASSERT_TRUE(result_1.toTensor().equal(expected));
}

TEST(LiteInterpreterTest, Gradient) {
  // Load check in model: ModelWithScalarList.ptl
  auto testModelFile = "ModelWithScalarList.ptl";

  //    class ModelWithScalarList(torch.nn.Module):
  //      def forward(self, a: int):
  //        values = torch.tensor([4., 1., 1., 16.], )
  //        if a == 0:
  //          return torch.gradient(values, spacing=torch.scalar_tensor(2., dtype=torch.float64))
  //        elif a == 1:
  //          return torch.gradient(values, spacing=[torch.tensor(1.).item()])
  Module bc = _load_for_mobile(testModelFile);

  // Spacing 2: one-sided differences at the edges ((1-4)/2, (16-1)/2) and
  // central differences inside ((1-4)/4, (16-1)/4).
  const auto result_1 = bc.forward({0});
  at::Tensor expected_1 = at::tensor({-1.5, -0.75, 3.75, 7.5}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result_1.toList().get(0).toTensor().equal(expected_1));

  // Spacing 1 (passed as a Scalar[]): the same differences, doubled.
  const auto result_2 = bc.forward({1});
  at::Tensor expected_2 = at::tensor({-3.0, -1.5, 7.5, 15.0}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result_2.toList().get(0).toTensor().equal(expected_2));
}

TEST(LiteInterpreterTest, Upsample) {
  // Load check in model: ModelWithFloatList.ptl
  auto testModelFile = "ModelWithFloatList.ptl";

  // model = torch.nn.Upsample(scale_factor=(2.0,), mode="linear")
  Module bc = _load_for_mobile(testModelFile);

  // Upsampling a constant input by 2 along the last dim keeps it constant and
  // doubles that dim: {1, 2, 3} -> {1, 2, 6}.
  const auto result_1 = bc.forward({at::ones({1, 2, 3})});
  at::Tensor expected_1 = at::ones({1, 2, 6}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result_1.toTensor().equal(expected_1));
}

TEST(LiteInterpreterTest, IndexTensor) {
  // Load check in model: ModelWithListOfOptionalTensors.ptl
  auto testModelFile = "ModelWithListOfOptionalTensors.ptl";

  // class ModelWithListOfOptionalTensors(torch.nn.Module):
  //   def forward(self, index):
  //      values = torch.tensor([4., 1., 1., 16.], )
  //      return values[[index, torch.tensor(0)]]
  Module bc = _load_for_mobile(testModelFile);
  const auto result_1 = bc.forward({at::tensor({1}, c10::TensorOptions(c10::ScalarType::Long))});

  at::Tensor expected_1 = at::tensor({1.}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result_1.toTensor().equal(expected_1));
}

TEST(LiteInterpreterTest, Conv2d) {
  // Load check in model: ModelWithArrayOfInt.ptl
  auto testModelFile = "ModelWithArrayOfInt.ptl";

  // model = torch.nn.Conv2d(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))
  Module bc = _load_for_mobile(testModelFile);
  const auto result_1 = bc.forward({at::ones({1, 1, 1, 1})});

  // 2 output channels; each spatial dim is (1 + 2*1 - 2) / 1 + 1 = 2.
  ASSERT_EQ(result_1.toTensor().sizes(), c10::IntArrayRef({1, 2, 2, 2}));
}

TEST(LiteInterpreterTest, AddTensor) {
  // Load check in model: ModelWithTensors.ptl
  auto testModelFile = "ModelWithTensors.ptl";

  //  class ModelWithTensors(torch.nn.Module):
  //    def forward(self, a):
  //      values = torch.ones(size=[2, 3], names=['N', 'C'])
  //      values[0][0] = a[0]
  //      return values
  //
  // NOTE(review): the model sketched above would return a 2x3 float tensor,
  // but the assertion below expects `a + 1` as a 1-D Long tensor. This comment
  // looks stale relative to the exported .ptl - confirm against the script
  // that generates ModelWithTensors.ptl.
  Module bc = _load_for_mobile(testModelFile);
  const auto result_1 = bc.forward({at::tensor({1, 2, 3}, c10::TensorOptions(c10::ScalarType::Long))});

  at::Tensor expected_1 = at::tensor({2, 3, 4}, c10::TensorOptions(c10::ScalarType::Long));
  ASSERT_TRUE(result_1.toTensor().equal(expected_1));
}

TEST(LiteInterpreterTest, DivideTensor) {
  // Load check in model: ModelWithStringOptional.ptl
  auto testModelFile = "ModelWithStringOptional.ptl";

  //  class ModelWithStringOptional(torch.nn.Module):
  //    def forward(self, b):
  //      a = torch.tensor(3, dtype=torch.int64)
  //      out = torch.empty(size=[1], dtype=torch.float)
  //      torch.div(b, a, out=out)
  //      return [torch.div(b, a, rounding_mode='trunc'), out]
  Module bc = _load_for_mobile(testModelFile);
  const auto result_1 = bc.forward({at::tensor({-12}, c10::TensorOptions(c10::ScalarType::Long))});

  // -12 / 3 == -4 exactly: the 'trunc' overload keeps the Long dtype, while
  // the out= overload writes true division into the Float `out` tensor.
  at::Tensor expected_1 = at::tensor({-4}, c10::TensorOptions(c10::ScalarType::Long));
  at::Tensor expected_2 = at::tensor({-4.}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result_1.toList().get(0).toTensor().equal(expected_1));
  ASSERT_TRUE(result_1.toList().get(1).toTensor().equal(expected_2));
}

TEST(LiteInterpreterTest, MultipleOps) {
  // Load check in model: ModelWithMultipleOps.ptl
  auto testModelFile = "ModelWithMultipleOps.ptl";

  // class ModelWithMultipleOps(torch.nn.Module):
  //     def __init__(self):
  //         super(ModelWithMultipleOps, self).__init__()
  //         self.ops = torch.nn.Sequential(
  //             torch.nn.ReLU(),
  //             torch.nn.Flatten(),
  //         )
  //     def forward(self, x):
  //         x[1] = -2
  //         return self.ops(x)

  Module bc = _load_for_mobile(testModelFile);
  auto b = at::ones({2, 2, 2, 2});
  const auto result = bc.forward({b});

  // Batch entry 1 is set to -2 and then zeroed by ReLU; Flatten collapses the
  // trailing dims, giving shape {2, 8}.
  at::Tensor expected = torch::tensor({{1, 1, 1, 1, 1, 1, 1, 1}, {0, 0, 0, 0, 0, 0, 0, 0}}, c10::TensorOptions(c10::ScalarType::Float));
  ASSERT_TRUE(result.toTensor().equal(expected));
}
} // namespace mobile
} // namespace jit
} // namespace torch