1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
#include <test/cpp/jit/test_utils.h>
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <stack>
#include <unordered_set>
// Tests go in torch::jit
namespace torch {
namespace jit {
namespace {
// Compares the debug info recorded for a debug handle before serialization
// with the debug info recovered for the same handle after deserialization.
//
// The two tuples are considered equal when:
//   * their source ranges compare equal, and
//   * their inlined callstacks are both undefined, or are both defined, have
//     the same number of entries, and every frame agrees on source range,
//     function name and module instance (class type name + instance name).
//
// Frames for which neither side carries module instance info are treated as
// matching on that field.
//
// Returns true on a match, false otherwise.
bool validate_debug_info(
    const DebugInfoTuple& pre_serialize,
    const DebugInfoTuple& post_serialize) {
  const auto& pre_range =
      std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
  const auto& post_range =
      std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
  if (pre_range != post_range) {
    return false;
  }

  // Local copies of the callstack pointers: they are advanced frame by frame.
  auto pre_cs = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
  auto post_cs = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
  if (!pre_cs.defined()) {
    return !post_cs.defined();
  }
  if (!post_cs.defined()) {
    return false;
  }
  if (pre_cs->vec().size() != post_cs->vec().size()) {
    return false;
  }

  // A frame either holds a Function pointer or only a function name string.
  const auto frame_fn_name =
      [](const c10::intrusive_ptr<InlinedCallStack>& frame) -> std::string {
    if (frame->function()) {
      return frame->function()->name();
    }
    return frame->function_name();
  };

  while (pre_cs.defined()) {
    if (!post_cs.defined()) {
      return false;
    }

    if (frame_fn_name(pre_cs) != frame_fn_name(post_cs)) {
      return false;
    }
    if (pre_cs->source_range() != post_cs->source_range()) {
      return false;
    }

    auto pre_module = pre_cs->module_instance();
    auto post_module = post_cs->module_instance();
    if (pre_module.has_value() != post_module.has_value()) {
      return false;
    }
    // Only compare the module details when both sides actually carry them.
    if (pre_module.has_value()) {
      if (!(pre_module.value().class_type()->name().value() ==
            post_module.value().class_type()->name().value())) {
        return false;
      }
      if (!(pre_module.value().instance_name() ==
            post_module.value().instance_name())) {
        return false;
      }
    }

    // Step both chains together; a length mismatch is a validation failure
    // rather than an access through an empty optional.
    auto pre_callee = pre_cs->callee();
    auto post_callee = post_cs->callee();
    if (pre_callee.has_value() != post_callee.has_value()) {
      return false;
    }
    if (!pre_callee.has_value()) {
      break;
    }
    pre_cs = pre_callee.value();
    post_cs = post_callee.value();
  }
  return true;
}
// Round-trips callstack debug info through CallStackDebugInfoPickler /
// CallStackDebugInfoUnpickler for a module C that calls two submodules
// (A0 and B0), and checks that every recorded debug handle is recovered with
// equivalent debug info.
TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
  std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
  Module a("A", cu);
  a.define(R"JIT(
    def forward(self, x):
      return x + 1
  )JIT");
  Module b("B", cu);
  b.define(R"JIT(
    def forward(self, x):
      return x + 2
  )JIT");
  Module c("C", cu);
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"JIT(
    def forward(self, x):
      return self.A0.forward(x) + self.B0.forward(x)
  )JIT");

  BackendDebugInfoRecorder debug_info_recorder;
  auto graph = c.get_method("forward").graph();
  // Inline the submodule calls so that nodes carry inlined callstacks.
  Inline(*graph);

  // Maps from source range to debug handle
  SourceRangeTagMap source_range_tags;
  // Maps from debug handle to source range
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  int64_t source_range_tag{0};
  // Registers `sr` under the next free tag in both directions.
  auto tag_source_range = [&](const SourceRange& sr) {
    source_range_tags[sr] = source_range_tag;
    source_range_map[source_range_tag] = sr;
    ++source_range_tag;
  };

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* current_block = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : current_block->nodes()) {
      tag_source_range(n->sourceRange());
      debug_info_recorder.getNextDebugHandle(n);
      if (n->callstack().has_value()) {
        for (const auto& e : n->callstack().value()->vec()) {
          // Element 1 of a callstack entry is its source range.
          tag_source_range(std::get<1>(e));
        }
      }
      // Queue nested blocks so nodes inside control flow are visited too.
      for (Block* sub_block : n->blocks()) {
        blocks_to_visit.push(sub_block);
      }
    }
  }

  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  CallStackDebugInfoPickler cs_debug_info_pickler;
  auto cs_data =
      cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
  // Wraps the pickled bytes for the unpickler; cs_data stays alive in this
  // scope for the rest of the test.
  at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
  CallStackDebugInfoUnpickler unpickler;
  auto deserialized_cs_map = unpickler.unpickle(
      std::move(data_ptr), cs_data.size(), source_range_map, cu);

  for (const auto& it : debug_handle_cs_ptr_map) {
    const auto handle = it.first;
    const auto& debug_info_one = it.second;
    // Single lookup; avoids operator[] default-inserting a missing handle.
    const auto found = deserialized_cs_map.find(handle);
    ASSERT_TRUE(found != deserialized_cs_map.end())
        << "Serialized debug handle must be in deserialized map.";
    ASSERT_TRUE(validate_debug_info(debug_info_one, found->second));
  }
}
} // namespace
} // namespace jit
} // namespace torch
|