1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
#include <gtest/gtest.h>
#include <torch/nativert/graph/GraphSignature.h>
namespace torch::nativert {
class GraphSignatureTest : public ::testing::Test {
protected:
// Member to hold the GraphSignature object
GraphSignature graph_sig;
void SetUp() override {
torch::_export::TensorArgument param_tensor_arg;
param_tensor_arg.set_name("param");
torch::_export::InputToParameterSpec param_input_spec;
param_input_spec.set_arg(param_tensor_arg);
param_input_spec.set_parameter_name("param");
torch::_export::InputSpec input_spec_0;
input_spec_0.set_parameter(param_input_spec);
torch::_export::TensorArgument input_tensor_arg;
input_tensor_arg.set_name("input");
torch::_export::Argument input_arg;
input_arg.set_as_tensor(input_tensor_arg);
torch::_export::UserInputSpec user_input_spec;
user_input_spec.set_arg(input_arg);
torch::_export::InputSpec input_spec_1;
input_spec_1.set_user_input(user_input_spec);
torch::_export::TensorArgument loss_tensor_arg;
loss_tensor_arg.set_name("loss");
torch::_export::LossOutputSpec loss_output_spec;
loss_output_spec.set_arg(loss_tensor_arg);
torch::_export::OutputSpec output_spec_0;
output_spec_0.set_loss_output(loss_output_spec);
torch::_export::TensorArgument output_tensor_arg;
output_tensor_arg.set_name("output");
torch::_export::Argument output_arg;
output_arg.set_as_tensor(output_tensor_arg);
torch::_export::UserOutputSpec user_output_spec;
user_output_spec.set_arg(output_arg);
torch::_export::OutputSpec output_spec_1;
output_spec_1.set_user_output(user_output_spec);
torch::_export::GraphSignature mock_storage;
mock_storage.set_input_specs({input_spec_0, input_spec_1});
mock_storage.set_output_specs({output_spec_0, output_spec_1});
// Initialize the GraphSignature object
graph_sig = GraphSignature(mock_storage);
}
};
// Test the constructor with a simple GraphSignature
// Verifies that constructing a GraphSignature from the fixture's schema routes
// each input/output spec to the matching accessor.
TEST_F(GraphSignatureTest, ConstructorTest) {
  // The InputToParameterSpec named "param" is reported as a parameter.
  const std::vector<std::string_view> expected_params = {"param"};
  EXPECT_EQ(graph_sig.parameters(), expected_params);
  // Unsigned literal: size() is size_t, and comparing it to an int inside
  // gtest's templated EXPECT_EQ helper triggers -Wsign-compare.
  EXPECT_EQ(graph_sig.parameters().size(), 1u);

  // The UserInputSpec named "input" is reported as a user input.
  const std::vector<std::string> expected_inputs = {"input"};
  EXPECT_EQ(graph_sig.userInputs(), expected_inputs);
  EXPECT_EQ(graph_sig.userInputs().size(), 1u);

  // The LossOutputSpec named "loss" is reported as the loss output.
  EXPECT_EQ(graph_sig.lossOutput(), "loss");

  // The UserOutputSpec named "output" is reported as a user output.
  const std::vector<std::optional<std::string>> expected_outputs = {"output"};
  EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);
}
// Test the replaceAllUses method
// Verifies that replaceAllUses() renames the user output whose name matches
// the old value, and leaves entries with other names untouched.
TEST_F(GraphSignatureTest, ReplaceAllUsesTest) {
  graph_sig.replaceAllUses("output", "new_output");

  const std::vector<std::optional<std::string>> expected_outputs = {
      "new_output"};
  EXPECT_EQ(graph_sig.userOutputs(), expected_outputs);

  // Names that differ from "output" must not be rewritten. Without these
  // checks, an over-eager replacement would still pass this test.
  EXPECT_EQ(graph_sig.lossOutput(), "loss");
  const std::vector<std::string> expected_inputs = {"input"};
  EXPECT_EQ(graph_sig.userInputs(), expected_inputs);
}
} // namespace torch::nativert
|