1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
#include <torch/csrc/profiler/itt_observer.h>
#include <torch/csrc/profiler/util.h>
namespace torch {
namespace profiler {
namespace impl {
/// Profiler state installed while the ITT profiler is active.
///
/// It carries the `ProfilerConfig` (via `ProfilerStateBase`) and identifies
/// itself as `ActiveProfilerType::ITT`.
struct ITTThreadLocalState : ProfilerStateBase {
  /// Builds the state from `config`.
  ///
  /// Fails with a `TORCH_CHECK` error if `config` enables `profile_memory`,
  /// `with_stack`, `with_flops`, or `with_modules`; each failure names the
  /// offending option so the caller knows which flag to turn off.
  explicit ITTThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context.
    TORCH_CHECK(
        !config.profile_memory,
        "The ITT profiler does not support `profile_memory`");
    TORCH_CHECK(
        !config.with_stack, "The ITT profiler does not support `with_stack`");
    TORCH_CHECK(
        !config.with_flops, "The ITT profiler does not support `with_flops`");
    TORCH_CHECK(
        !config.with_modules,
        "The ITT profiler does not support `with_modules`");
  }
  ~ITTThreadLocalState() override = default;

  /// Identifies this state as belonging to the ITT profiler.
  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::ITT;
  }

  /// Deliberately empty: `profile_memory` is rejected by the constructor, so
  /// there is nothing to record here.
  void reportMemoryUsage(void*, int64_t, int64_t, int64_t, c10::Device)
      override {}

  /// Returns the state obtained from `ProfilerStateBase::get(/*global=*/false)`
  /// viewed as an `ITTThreadLocalState`, or `nullptr` when none is set.
  ///
  /// In debug builds this asserts that a non-null state really is an ITT
  /// state before the unchecked downcast.
  static ITTThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
    return static_cast<ITTThreadLocalState*>(tls);
  }
};
/// RecordFunction start callback used by the ITT profiler.
///
/// When `ITTThreadLocalState::getTLS()` yields a state, a range named
/// `fn.name()` is pushed through the ITT stubs; otherwise nothing happens.
/// No observer context is created, so the result is always null.
///
/// NOTE(review): `report_input_shapes` is not read in this body, so both
/// instantiations currently behave the same - confirm whether that is intended.
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
  const bool has_itt_state = ITTThreadLocalState::getTLS() != nullptr;
  if (!has_itt_state) {
    return nullptr;
  }

  const auto* stubs = torch::profiler::impl::ittStubs();
  stubs->rangePush(fn.name());
  return nullptr;
}
void pushITTCallbacks(
const ProfilerConfig& config,
const std::unordered_set<at::RecordScope>& scopes) {
TORCH_CHECK(
torch::profiler::impl::ittStubs()->enabled(),
"Can't use ITT profiler - PyTorch was compiled without ITT");
c10::ThreadLocalDebugInfo::_push(
c10::DebugInfoKind::PROFILER_STATE,
std::make_shared<ITTThreadLocalState>(config));
auto state_ptr = ITTThreadLocalState::getTLS();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(
at::RecordFunctionCallback(
state_ptr->config().report_input_shapes
? &enterITT</*report_input_shapes=*/true>
: &enterITT</*report_input_shapes=*/false>,
[](const at::RecordFunction&, at::ObserverContext*) {
torch::profiler::impl::ittStubs()->rangePop();
})
.needsInputs(config.report_input_shapes)
.scopes(scopes));
state_ptr->setCallbackHandle(handle);
}
} // namespace impl
} // namespace profiler
} // namespace torch
|