File: itt_observer.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (72 lines) | stat: -rw-r--r-- 2,321 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <torch/csrc/profiler/itt_observer.h>

#include <torch/csrc/profiler/util.h>

namespace torch {
namespace profiler {
namespace impl {

// Per-thread profiler state that pushITTCallbacks installs as the
// PROFILER_STATE debug info while the ITT profiler is active.
//
// The ITT observer only emits named ranges, so every ProfilerConfig option
// other than `report_input_shapes` is rejected when the state is constructed.
struct ITTThreadLocalState : ProfilerStateBase {
  explicit ITTThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context. Give each
    // rejection an explicit reason instead of the generic TORCH_CHECK text.
    TORCH_CHECK(
        !config.profile_memory,
        "ITT profiler does not support profile_memory");
    TORCH_CHECK(
        !config.with_stack, "ITT profiler does not support with_stack");
    TORCH_CHECK(
        !config.with_flops, "ITT profiler does not support with_flops");
    TORCH_CHECK(
        !config.with_modules, "ITT profiler does not support with_modules");
  }
  ~ITTThreadLocalState() override = default;

  // Tags this state as belonging to the ITT profiler; getTLS() checks this
  // tag (in debug builds) before downcasting.
  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::ITT;
  }

  // Intentionally a no-op: memory events are not forwarded to ITT, and
  // profile_memory is rejected by the constructor.
  void reportMemoryUsage(void*, int64_t, int64_t, int64_t, c10::Device)
      override {}

  // Returns the calling thread's (non-global) profiler state as an
  // ITTThreadLocalState, or nullptr when no such state is installed.
  static ITTThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
    return static_cast<ITTThreadLocalState*>(tls);
  }
};

// RecordFunction start callback for the ITT profiler. When the calling
// thread has ITT profiler state, it opens an ITT range labelled with the
// recorded function's name. Otherwise it does nothing.
//
// The `report_input_shapes` template parameter is not used by this body; the
// range label is the function name only. Always returns nullptr: no observer
// context is produced, and the exit callback registered in pushITTCallbacks
// does not read one.
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
  auto* const itt_state = ITTThreadLocalState::getTLS();
  if (itt_state == nullptr) {
    // No ITT profiler state on this thread: leave ITT untouched.
    return nullptr;
  }
  torch::profiler::impl::ittStubs()->rangePush(fn.name());
  return nullptr;
}

// Enables ITT profiling on the calling thread.
//
// Steps:
//  1. Pushes a fresh ITTThreadLocalState built from `config` as the
//     PROFILER_STATE debug info.
//  2. Registers a thread-local RecordFunction callback, restricted to
//     `scopes`, that opens an ITT range on function entry and closes it on
//     exit.
//  3. Stores the callback handle in the state object.
//
// Throws (via TORCH_CHECK) if PyTorch was built without ITT support. Also
// throws if `config` enables an option the ITT state rejects (profile_memory,
// with_stack, with_flops, with_modules).
void pushITTCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::ittStubs()->enabled(),
      "Can't use ITT profiler - PyTorch was compiled without ITT");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<ITTThreadLocalState>(config));

  auto state_ptr = ITTThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          // The state was constructed from `config`, so read the flag from
          // the same source used for needsInputs() below.
          config.report_input_shapes
              ? &enterITT</*report_input_shapes=*/true>
              : &enterITT</*report_input_shapes=*/false>,
          [](const at::RecordFunction&, at::ObserverContext*) {
            // enterITT pushes a range only when ITT profiler state is
            // visible on this thread. Apply the same condition before
            // popping so pops stay balanced with pushes.
            if (ITTThreadLocalState::getTLS() != nullptr) {
              torch::profiler::impl::ittStubs()->rangePop();
            }
          })
          .needsInputs(config.report_input_shapes)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace impl
} // namespace profiler
} // namespace torch