1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <vector>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
namespace torch {
namespace jit {
// gets a string representation of a node header
// (e.g. outputs, a node kind and outputs)
std::string getHeader(const Node* node) {
std::stringstream ss;
node->print(ss, 0, {}, false, false, false, false);
return ss.str();
}
static std::unordered_map<std::string, size_t> parseJITLogOption(
const char* option) {
std::stringstream in_ss;
in_ss << "function:";
if (option) {
in_ss << option;
}
std::unordered_map<std::string, size_t> files_to_levels;
std::string line;
while (std::getline(in_ss, line, ':')) {
if (line.size() == 0) {
continue;
}
auto index_at = line.find_last_of('>');
auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
auto end_index = line.find_last_of('.') == std::string::npos
? line.size()
: line.find_last_of('.');
auto filename = line.substr(begin_index, end_index - begin_index);
files_to_levels.insert({filename, logging_level});
}
return files_to_levels;
}
bool is_enabled(const char* cfname, JitLoggingLevels level) {
static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
static const std::unordered_map<std::string, size_t> files_to_levels =
parseJITLogOption(c_log_level);
std::string fname{cfname};
fname = c10::detail::StripBasename(fname);
auto end_index = fname.find_last_of('.') == std::string::npos
? fname.size()
: fname.find_last_of('.');
auto fname_no_ext = fname.substr(0, end_index);
auto it = files_to_levels.find(fname_no_ext);
if (it == files_to_levels.end()) {
return false;
}
return level <= static_cast<JitLoggingLevels>(it->second);
}
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
// we won't have access to an original function, so we have to construct
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();
}
std::string jit_log_prefix(
const std::string& prefix,
const std::string& in_str) {
std::stringstream in_ss(in_str);
std::stringstream out_ss;
std::string line;
while (std::getline(in_ss, line)) {
out_ss << prefix << line << std::endl;
}
return out_ss.str();
}
std::string jit_log_prefix(
JitLoggingLevels level,
const char* fn,
int l,
const std::string& in_str) {
std::stringstream prefix_ss;
prefix_ss << "[";
prefix_ss << level << " ";
prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
prefix_ss << std::setfill('0') << std::setw(3) << l;
prefix_ss << "] ";
return jit_log_prefix(prefix_ss.str(), in_str);
}
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
switch (level) {
case JitLoggingLevels::GRAPH_DUMP:
out << "DUMP";
break;
case JitLoggingLevels::GRAPH_UPDATE:
out << "UPDATE";
break;
case JitLoggingLevels::GRAPH_DEBUG:
out << "DEBUG";
break;
default:
TORCH_INTERNAL_ASSERT(false, "Invalid level");
}
return out;
}
} // namespace jit
} // namespace torch
|