File: test_weights.cpp

package info (click to toggle)
pytorch 2.9.1%2Bdfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (92 lines) | stat: -rw-r--r-- 3,054 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <gtest/gtest.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <memory>
#include <utility>

#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/graph/Graph.h>

namespace torch::nativert {
// Fixture shared by the Weights tests. Before every test case it parses a
// small textual graph and creates a Placement targeting the CPU device.
class WeightsTest : public ::testing::Test {
 protected:
  // Populates `graph` and `placement`; gtest runs this before each TEST_F.
  void SetUp() override {
    // Graph text: three inputs (%foo, %bar, %baz), one aten.foo node, and
    // outputs %o2 and %baz.
    static constexpr std::string_view kGraphText =
        R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
    graph = stringToGraph(kGraphText);

    const c10::Device cpuDevice(c10::DeviceType::CPU);
    placement = std::make_unique<Placement>(cpuDevice);
  }

  // Graph parsed from the text above; handed to Weights by raw pointer.
  std::shared_ptr<Graph> graph;
  // CPU placement created in SetUp().
  std::unique_ptr<Placement> placement;
};
// Building Weights from an empty state dict must produce an object that
// reports no parameters, no buffers, and does not contain an arbitrary name.
TEST_F(WeightsTest, ConstructEmptyStateDict) {
  std::unordered_map<std::string, c10::IValue> emptyStateDict;
  Weights emptyWeights(graph.get(), emptyStateDict);

  // Nothing was supplied, so both collections are expected to be empty.
  EXPECT_TRUE(emptyWeights.parameters().empty());
  EXPECT_TRUE(emptyWeights.buffers().empty());

  // A lookup for a name that was never registered must report absence.
  EXPECT_FALSE(emptyWeights.contains("non_existent_weight"));
}
// A tensor stored with setValue() must afterwards be reported by contains()
// and be retrievable through at() with the same sizes.
TEST_F(WeightsTest, SetAndGetValue) {
  std::unordered_map<std::string, c10::IValue> emptyStateDict;
  Weights weights(graph.get(), emptyStateDict);

  // Register a 2x2 tensor of ones under a fresh name.
  const std::string weightName = "added_weight";
  at::Tensor expected = at::ones({2, 2});
  weights.setValue(weightName, expected);

  EXPECT_TRUE(weights.contains(weightName));

  // Only the sizes are compared here, not the element values.
  const at::Tensor stored = weights.at(weightName);
  EXPECT_EQ(stored.sizes(), expected.sizes());
}

} // namespace torch::nativert

using namespace ::testing;
// Minimal custom class used by the pickling test below. It owns a single
// tensor and serializes it as a one-entry dict keyed by "init_tensor".
struct ContainsTensorDict : torch::CustomClassHolder {
  // Wraps `t`. The parameter is taken by value as a sink, so it is moved into
  // the member rather than copied a second time.
  explicit ContainsTensorDict(at::Tensor t) : t_(std::move(t)) {}

  // Rebuilds the object from its serialized form (see serialize()): the
  // tensor is read from the "init_tensor" entry of `dict`. The member is
  // initialized directly instead of being default-constructed and reassigned.
  explicit ContainsTensorDict(c10::Dict<std::string, at::Tensor> dict)
      : t_(dict.at(std::string("init_tensor"))) {}

  // Returns the serialized form: a dict holding `t_` under "init_tensor".
  c10::Dict<std::string, at::Tensor> serialize() const {
    c10::Dict<std::string, at::Tensor> dict;
    dict.insert(std::string("init_tensor"), t_);
    return dict;
  }

  // The wrapped tensor; public so tests can inspect it directly.
  at::Tensor t_;
};

// Registers ContainsTensorDict as the custom class "ContainsTensorDict" in the
// "testing" namespace, with a tensor constructor and pickle support.
static auto reg =
    torch::class_<ContainsTensorDict>("testing", "ContainsTensorDict")
        .def(torch::init<at::Tensor>())
        .def_pickle(
            // __getstate__: the object's state is its serialized dict
            // (c10::Dict<std::string, at::Tensor>).
            [](const c10::intrusive_ptr<ContainsTensorDict>& obj) {
              return obj->serialize();
            },
            // __setstate__: rebuild the object from that same dict type via
            // the dict constructor.
            [](c10::Dict<std::string, at::Tensor> state) {
              return c10::make_intrusive<ContainsTensorDict>(std::move(state));
            });

// Round-trips a ContainsTensorDict through pickle_save / pickle_load_obj and
// checks that the first element of the wrapped tensor survives.
TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) {
  // Serialize a custom object wrapping the tensor [1, 2, 3].
  auto wrapped =
      c10::make_intrusive<ContainsTensorDict>(torch::tensor({1, 2, 3}));
  const auto pickled =
      torch::jit::pickle_save(c10::IValue(std::move(wrapped)));

  // Deserialize from a string copy of the pickled bytes.
  const std::string payload(pickled.begin(), pickled.end());
  const auto restored = torch::jit::pickle_load_obj(payload);
  EXPECT_TRUE(restored.isObject());

  // Recover the typed holder and inspect the first tensor element.
  const auto holder = restored.to<c10::intrusive_ptr<ContainsTensorDict>>();
  EXPECT_EQ(holder->t_[0].item<int>(), 1);
}