File: itt_observer.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (69 lines) | stat: -rw-r--r-- 2,321 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <torch/csrc/profiler/standalone/itt_observer.h>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch::profiler::impl {

/// Profiler state installed for the ITT profiler.
///
/// Holds the ProfilerConfig (through ProfilerStateBase) and identifies itself
/// as ActiveProfilerType::ITT. Memory-usage reports are ignored.
struct ITTThreadLocalState : ProfilerStateBase {
  /// @param config profiler configuration; only `report_input_shapes` may be
  ///        set, every other option listed below is rejected.
  /// @throws c10::Error (via TORCH_CHECK) if an unsupported option is enabled.
  explicit ITTThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context. Give each
    // rejected option its own message so the failure tells the user exactly
    // which setting to turn off, instead of the generic TORCH_CHECK text.
    TORCH_CHECK(
        !config.profile_memory,
        "ITT profiler does not support profile_memory");
    TORCH_CHECK(
        !config.with_stack, "ITT profiler does not support with_stack");
    TORCH_CHECK(
        !config.with_flops, "ITT profiler does not support with_flops");
    TORCH_CHECK(
        !config.with_modules, "ITT profiler does not support with_modules");
  }
  ~ITTThreadLocalState() override = default;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::ITT;
  }

  // Memory events are not recorded by this profiler; this override is a no-op.
  void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
  }

  /// Returns the non-global (thread-local) profiler state downcast to
  /// ITTThreadLocalState, or nullptr when none is installed. Debug builds
  /// assert that an installed state really belongs to the ITT profiler.
  static ITTThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
    return static_cast<ITTThreadLocalState*>(tls);
  }
};

/// RecordFunction start callback for the ITT profiler.
///
/// When an ITT profiler state is installed on the current thread, opens an
/// ITT range labelled with the function's name; otherwise does nothing.
/// Always returns nullptr, since no per-call context is kept.
///
/// NOTE: the `report_input_shapes` template parameter is not referenced in
/// the body, so both instantiations currently behave identically.
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
  const bool itt_active = ITTThreadLocalState::getTLS() != nullptr;
  if (!itt_active) {
    return nullptr;
  }
  auto* stubs = torch::profiler::impl::ittStubs();
  stubs->rangePush(fn.name());
  return nullptr;
}

/// Enables the ITT profiler on the calling thread.
///
/// Pushes a new ITTThreadLocalState built from `config` as the thread's
/// PROFILER_STATE debug info. Then registers a thread-local RecordFunction
/// callback pair for `scopes` that opens an ITT range on entry and closes it
/// on exit. The returned callback handle is stored on the state.
///
/// @param config profiler configuration (validated by ITTThreadLocalState).
/// @param scopes RecordFunction scopes the callbacks are restricted to.
/// @throws c10::Error if PyTorch was built without ITT support, or if
///         `config` enables an option the ITT profiler rejects.
void pushITTCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::ittStubs()->enabled(),
      "Can't use ITT profiler - PyTorch was compiled without ITT");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<ITTThreadLocalState>(config));

  auto state_ptr = ITTThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  // Read the flag once, from the config copy owned by the installed state.
  const bool report_input_shapes = state_ptr->config().report_input_shapes;

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          report_input_shapes ? &enterITT</*report_input_shapes=*/true>
                              : &enterITT</*report_input_shapes=*/false>,
          [](const at::RecordFunction&, at::ObserverContext*) {
            torch::profiler::impl::ittStubs()->rangePop();
          })
          // enterITT only uses fn.name() and the exit callback uses nothing
          // from the RecordFunction, so neither consumes op inputs. Asking
          // RecordFunction to collect them would be pure overhead.
          .needsInputs(false)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace torch::profiler::impl