1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
|
#include <torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h>

#include <mutex>
#include <string>
#include <utility>
namespace torch {
namespace jit {
namespace mobile {
// Registers a global RecordFunction callback restricted to the BUILD_FEATURE
// scope. For every RecordFunction in that scope, the function's name is added
// to the static feature set returned by getBuildFeatures().
// The callback handle is stored in handle_. NOTE(review): presumably it is
// removed again by the destructor declared in the header; confirm there.
BuildFeatureTracer::BuildFeatureTracer() {
  auto recorder_cb =
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    // Materialize the name before taking the lock, so the critical section
    // only covers the set insertion.
    std::string name = fn.name();
    getBuildFeatures().withLock(
        [&](BuildFeatureTracer::build_feature_type& build_features) {
          // `name` is not used after this point. Move it into the set rather
          // than copying (and re-allocating) it while the lock is held.
          build_features.insert(std::move(name));
        });
    // No per-invocation observer state is needed.
    return nullptr;
  };
  handle_ = at::addGlobalCallback(
      at::RecordFunctionCallback(recorder_cb)
          .scopes({at::RecordScope::BUILD_FEATURE}));
}
c10::Synchronized<BuildFeatureTracer::build_feature_type>& BuildFeatureTracer::
getBuildFeatures() {
static c10::Synchronized<build_feature_type> build_features;
return build_features;
}
} // namespace mobile
} // namespace jit
} // namespace torch
|