File: test_graph_signature.cpp

package info (click to toggle)
pytorch 2.9.1%2Bdfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (77 lines) | stat: -rw-r--r-- 2,830 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <gtest/gtest.h>
#include <torch/nativert/graph/GraphSignature.h>

namespace torch::nativert {

class GraphSignatureTest : public ::testing::Test {
 protected:
  // Member to hold the GraphSignature object
  GraphSignature graph_sig;

  void SetUp() override {
    torch::_export::TensorArgument param_tensor_arg;
    param_tensor_arg.set_name("param");
    torch::_export::InputToParameterSpec param_input_spec;
    param_input_spec.set_arg(param_tensor_arg);
    param_input_spec.set_parameter_name("param");
    torch::_export::InputSpec input_spec_0;
    input_spec_0.set_parameter(param_input_spec);

    torch::_export::TensorArgument input_tensor_arg;
    input_tensor_arg.set_name("input");
    torch::_export::Argument input_arg;
    input_arg.set_as_tensor(input_tensor_arg);
    torch::_export::UserInputSpec user_input_spec;
    user_input_spec.set_arg(input_arg);
    torch::_export::InputSpec input_spec_1;
    input_spec_1.set_user_input(user_input_spec);

    torch::_export::TensorArgument loss_tensor_arg;
    loss_tensor_arg.set_name("loss");
    torch::_export::LossOutputSpec loss_output_spec;
    loss_output_spec.set_arg(loss_tensor_arg);
    torch::_export::OutputSpec output_spec_0;
    output_spec_0.set_loss_output(loss_output_spec);

    torch::_export::TensorArgument output_tensor_arg;
    output_tensor_arg.set_name("output");
    torch::_export::Argument output_arg;
    output_arg.set_as_tensor(output_tensor_arg);
    torch::_export::UserOutputSpec user_output_spec;
    user_output_spec.set_arg(output_arg);
    torch::_export::OutputSpec output_spec_1;
    output_spec_1.set_user_output(user_output_spec);

    torch::_export::GraphSignature mock_storage;
    mock_storage.set_input_specs({input_spec_0, input_spec_1});
    mock_storage.set_output_specs({output_spec_0, output_spec_1});

    // Initialize the GraphSignature object
    graph_sig = GraphSignature(mock_storage);
  }
};

// Checks that a GraphSignature built from the fixture's serialized storage
// reports the expected parameter, user input, loss output and user output.
TEST_F(GraphSignatureTest, ConstructorTest) {
  // Parameters: exactly one entry, named "param".
  const std::vector<std::string_view> want_parameters = {"param"};
  EXPECT_EQ(graph_sig.parameters().size(), 1);
  EXPECT_EQ(graph_sig.parameters(), want_parameters);

  // User inputs: exactly one entry, named "input".
  const std::vector<std::string> want_user_inputs = {"input"};
  EXPECT_EQ(graph_sig.userInputs().size(), 1);
  EXPECT_EQ(graph_sig.userInputs(), want_user_inputs);

  // The loss output carries the name given to the loss tensor argument.
  EXPECT_EQ(graph_sig.lossOutput(), "loss");

  // User outputs: a single entry holding the name "output".
  const std::vector<std::optional<std::string>> want_user_outputs = {"output"};
  EXPECT_EQ(graph_sig.userOutputs(), want_user_outputs);
}

// Checks replaceAllUses: after replacing "output" with "new_output", the
// user outputs are expected to list the new name instead of the old one.
TEST_F(GraphSignatureTest, ReplaceAllUsesTest) {
  const std::vector<std::optional<std::string>> want_user_outputs = {
      "new_output"};

  graph_sig.replaceAllUses("output", "new_output");

  EXPECT_EQ(graph_sig.userOutputs(), want_user_outputs);
}

} // namespace torch::nativert