1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
namespace torch {
namespace jit {
// Process-wide holder for the JIT logging configuration.
//
// Keeps three pieces of state:
//   * the raw logging-level specification string,
//   * the filename -> level table that parse() derives from that string,
//   * a pointer to the stream that log output is directed to.
//
// The only way to obtain an object is getInstance(); the constructor is
// private and copying is deleted.
class JitLoggingConfig {
 public:
  // Returns the single shared instance, constructing it on first call.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }

  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

  // Returns a copy of the raw specification string currently in effect.
  std::string getLoggingLevels() const {
    return this->logging_levels;
  }

  // Replaces the specification string and rebuilds the per-file table.
  void setLoggingLevels(std::string levels) {
    this->logging_levels = std::move(levels);
    parse();
  }

  // Read-only view of the filename -> level table built by parse().
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return this->files_to_levels;
  }

  // Redirects log output. Only the address of `out_stream` is stored, so the
  // stream has to stay alive for as long as it may be returned by
  // getOutputStream().
  void setOutputStream(std::ostream& out_stream) {
    this->out = &out_stream;
  }

  // Returns the stream that log output is currently directed to.
  std::ostream& getOutputStream() {
    return *(this->out);
  }

 private:
  // Seeds the specification from the PYTORCH_JIT_LOG_LEVEL environment
  // variable (an unset variable is treated as an empty specification) and
  // builds the per-file table from it.
  JitLoggingConfig() {
    const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    if (jit_log_level != nullptr) {
      logging_levels.assign(jit_log_level);
    }
    parse();
  }

  // Rebuilds files_to_levels from logging_levels (defined out of line).
  void parse();

  std::string logging_levels;
  std::unordered_map<std::string, size_t> files_to_levels;
  // Initialised here rather than in the constructor body so the pointer is
  // never observable in an indeterminate state.
  std::ostream* out = &std::cerr;
};
// Returns a copy of the logging-level specification string currently held
// by the JitLoggingConfig singleton.
std::string get_jit_logging_levels() {
  const auto& config = JitLoggingConfig::getInstance();
  return config.getLoggingLevels();
}
// Hands a new logging-level specification string to the JitLoggingConfig
// singleton. `level` is taken by value and moved into the singleton.
void set_jit_logging_levels(std::string level) {
  auto& config = JitLoggingConfig::getInstance();
  config.setLoggingLevels(std::move(level));
}
// Makes `stream` the destination reported by get_jit_logging_output_stream().
// The stream is passed on by reference to the JitLoggingConfig singleton.
void set_jit_logging_output_stream(std::ostream& stream) {
  auto& config = JitLoggingConfig::getInstance();
  config.setOutputStream(stream);
}
// Returns the stream the JitLoggingConfig singleton currently directs log
// output to.
std::ostream& get_jit_logging_output_stream() {
  auto& config = JitLoggingConfig::getInstance();
  return config.getOutputStream();
}
// Builds a string representation of a node header (e.g. its outputs and its
// node kind).
// NOTE(review): the earlier wording listed "outputs" twice; presumably the
// second one was meant to be "inputs" -- confirm against Node::print.
std::string getHeader(const Node* node) {
  std::ostringstream header;
  // Print into the local buffer with 0 as the second argument, a
  // default-constructed third argument, and all four boolean options off.
  node->print(header, 0, {}, false, false, false, false);
  return header.str();
}
// Rebuilds `files_to_levels` from `logging_levels`.
//
// The specification is split on ':'; empty pieces are skipped. For every
// remaining piece:
//   * the level is the position just past the last '>' in the piece (0 when
//     there is no '>'), so ">>name" yields level 2;
//   * the key is the text between that position and the last '.' in the
//     piece (or the end of the piece when there is no '.').
// An entry "function" is always processed first because "function:" is
// prepended to the specification. A key that is already in the table is left
// untouched, so the first occurrence of a name wins.
void JitLoggingConfig::parse() {
  std::istringstream spec_stream("function:" + this->logging_levels);

  files_to_levels.clear();
  for (std::string entry; std::getline(spec_stream, entry, ':');) {
    if (entry.empty()) {
      continue;
    }

    const auto last_marker = entry.find_last_of('>');
    // Both the start of the name and the level come from the same position.
    const size_t name_begin =
        last_marker == std::string::npos ? 0 : last_marker + 1;
    const size_t entry_level = name_begin;

    const auto last_dot = entry.find_last_of('.');
    const auto name_end =
        last_dot == std::string::npos ? entry.size() : last_dot;

    // If the last '.' lies before name_begin, the unsigned subtraction wraps
    // to a huge count and substr clamps it, keeping the rest of the piece.
    files_to_levels.emplace(
        entry.substr(name_begin, name_end - name_begin), entry_level);
  }
}
// Reports whether logging at `level` is switched on for the file `cfname`.
//
// `cfname` is passed through c10::detail::StripBasename (presumably leaving
// just the file name without its directory -- confirm against c10), and
// everything from the last '.' onwards is dropped. The remainder is looked
// up in the table held by JitLoggingConfig. A name that is not in the table
// is never enabled; otherwise the result is `level <= configured level`.
bool is_enabled(const char* cfname, JitLoggingLevels level) {
  const auto& level_table = JitLoggingConfig::getInstance().getFilesToLevels();

  const std::string base_name =
      c10::detail::StripBasename(std::string{cfname});
  // substr clamps its count to the string length, so an npos result from
  // find_last_of (no '.') keeps the whole name.
  const std::string stem = base_name.substr(0, base_name.find_last_of('.'));

  const auto entry = level_table.find(stem);
  return entry != level_table.end() &&
      level <= static_cast<JitLoggingLevels>(entry->second);
}
// Renders `graph` as source text via PythonPrint.
//
// `log_function` is invoked from `GraphExecutor`, where the original
// function is not available, so a placeholder GraphFunction named
// "source_dump" is wrapped around the graph purely to have something to hand
// to PythonPrint.
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
  // The printer is constructed from these two tables, so they are declared
  // first and therefore outlive it.
  std::vector<at::IValue> constant_table;
  PrintDepsTable dependency_table;

  torch::jit::GraphFunction placeholder("source_dump", graph, nullptr);

  PythonPrint printer(constant_table, dependency_table);
  printer.printFunction(placeholder);
  return printer.str();
}
// Prepends `prefix` to every line of `in_str`.
//
// `in_str` is split on '\n'. Each resulting line is emitted as
// prefix + line + '\n', so the result always ends with a newline when there
// is at least one line, whether or not `in_str` did. An empty `in_str`
// yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);
  std::string prefixed;
  for (std::string line; std::getline(in_ss, line);) {
    prefixed += prefix;
    prefixed += line;
    // A plain newline: std::endl would add a pointless flush per line.
    prefixed += '\n';
  }
  return prefixed;
}
// Prefixes every line of `in_str` with a tag of the form
// "[<level> <file>:<line>] ".
//
// <level> is written with the stream operator for JitLoggingLevels, <file>
// is `fn` passed through c10::detail::StripBasename, and <line> is `l`
// zero-padded to a minimum width of 3. The tagging itself is delegated to
// the (prefix, in_str) overload.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::stringstream tag;
  tag << "[" << level << " "
      << c10::detail::StripBasename(std::string(fn)) << ":"
      << std::setfill('0') << std::setw(3) << l << "] ";
  return jit_log_prefix(tag.str(), in_str);
}
// Writes the short textual name of `level` to `out` and returns `out`.
//
// GRAPH_DUMP -> "DUMP", GRAPH_UPDATE -> "UPDATE", GRAPH_DEBUG -> "DEBUG".
// Any other value trips TORCH_INTERNAL_ASSERT with "Invalid level".
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
  switch (level) {
    case JitLoggingLevels::GRAPH_DUMP:
      return out << "DUMP";
    case JitLoggingLevels::GRAPH_UPDATE:
      return out << "UPDATE";
    case JitLoggingLevels::GRAPH_DEBUG:
      return out << "DEBUG";
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid level");
  }
  return out;
}
} // namespace jit
} // namespace torch
|