1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
// Resets the type of a single value when it is tensor-typed.
//
// - A value whose type kind is TensorType is re-typed to TensorType::get().
// - A value whose type is a ListType with TensorType elements is re-typed to
//   ListType::create(TensorType::get()).
// - Every other value is left untouched.
//
// NOTE(review): presumably TensorType::get() is the unrefined tensor type, so
// this drops any undefinedness information carried by the previous type --
// confirm against the TensorType documentation.
void clearUndefinedness(Value* o) {
  const auto kind = o->type()->kind();

  if (kind == TensorType::Kind) {
    o->setType(TensorType::get());
    return;
  }

  if (kind != ListType::Kind) {
    return;
  }

  // Evaluate the element check before calling setType so nothing derived from
  // the old type is read after it has been replaced.
  const bool holds_tensors =
      o->type()->expectRef<ListType>().getElementType()->kind() ==
      TensorType::Kind;
  if (holds_tensors) {
    o->setType(ListType::create(TensorType::get()));
  }
}
// Applies the Value* overload of clearUndefinedness to every output of every
// node in `block`, then recurses into each of that node's nested blocks.
//
// Only node outputs are visited here; the inputs (parameters) of `block` and
// of nested blocks are not touched by this function.
void clearUndefinedness(Block* block) {
  for (auto* node : block->nodes()) {
    // Outputs of the node itself are handled first...
    for (auto* output : node->outputs()) {
      clearUndefinedness(output);
    }
    // ...followed by everything produced inside its sub-blocks.
    for (auto* nested : node->blocks()) {
      clearUndefinedness(nested);
    }
  }
}
// Entry point of the pass.
//
// Runs clearUndefinedness on every input of `graph`, then on the graph's
// top-level block (which walks all node outputs, including those in nested
// blocks), and finally emits the resulting graph through GRAPH_DUMP.
//
// @param graph  Graph whose value types are rewritten in place.
void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
  for (auto* input : graph->inputs()) {
    clearUndefinedness(input);
  }
  clearUndefinedness(graph->block());
  // The dump label names this pass (ClearUndefinedness) so the log entry can
  // be matched to the function that produced it.
  GRAPH_DUMP("After ClearUndefinedness: ", graph);
}
} // namespace jit
} // namespace torch
|