1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
|
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <mutex>
#include <string>
#include <utility>
namespace torch {
namespace jit {
namespace mobile {
/// Registers a global RecordFunction callback restricted to the
/// CUSTOM_CLASS scope. Each time the callback fires it records the name of
/// the RecordFunction (`fn.name()`) into the shared collection returned by
/// getLoadedClasses(), accessed under that object's lock via withLock().
///
/// The handle returned by at::addGlobalCallback is stored in handle_.
/// NOTE(review): removal of the callback is not visible in this file —
/// presumably done in the destructor; confirm against the header.
CustomClassTracer::CustomClassTracer() {
  auto recorder_cb =
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    // Materialize the string before taking the lock so the critical section
    // only covers the insertion itself.
    std::string name = fn.name();
    getLoadedClasses().withLock(
        [&name](CustomClassTracer::custom_classes_type& custom_classes) {
          // `name` is not used after this point: move it in instead of
          // copying it while the lock is held.
          custom_classes.insert(std::move(name));
        });
    // No per-call ObserverContext is needed, so none is created.
    return nullptr;
  };
  handle_ = at::addGlobalCallback(
      at::RecordFunctionCallback(recorder_cb)
          .scopes({at::RecordScope::CUSTOM_CLASS}));
}
/// Returns the one shared, lock-wrapped collection of custom class names
/// that the tracer's RecordFunction callback appends to.
///
/// The object is a function-local static. It is constructed on the first
/// call (static-local initialization is thread-safe since C++11), and every
/// later call returns a reference to that same instance. Access its contents
/// through the c10::Synchronized interface (e.g. withLock).
c10::Synchronized<CustomClassTracer::custom_classes_type>&
CustomClassTracer::getLoadedClasses() {
  static c10::Synchronized<custom_classes_type> registry;
  return registry;
}
} // namespace mobile
} // namespace jit
} // namespace torch
|