File: test_cpp_thread.cpp

package info (click to toggle)
pytorch 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 161,668 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (119 lines) | stat: -rw-r--r-- 3,753 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

#include <torch/csrc/autograd/profiler_kineto.h>  // @manual
#include <torch/torch.h>

#include <atomic>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

using namespace torch::autograd::profiler;

// Prints `text` on its own line, wrapped in ANSI escape codes: "\33[94m"
// selects a bright-blue foreground and "\33[0m" resets the attributes.
// Used for the progress messages emitted by start_threads().
void blueprint(const std::string& text) {
  const char* const message = text.c_str();
  printf("\33[94m%s\33[0m\n", message);
}

/**
 * Emulates a C++ training engine calling into Python, so that Python code can
 * control how profiling is done. Python subclasses override the virtual hooks
 * below (through ProfilerEventHandlerTrampoline); the instance installed with
 * Register() is the one the worker threads call.
 */
class ProfilerEventHandler
    : public std::enable_shared_from_this<ProfilerEventHandler> {
 public:
  // Process-wide handler; empty until Register() is called.
  static std::shared_ptr<ProfilerEventHandler> Handler;

  // Installs `handler` as the process-wide handler, replacing any previous
  // one. The assignment is not synchronized, so call this before starting
  // worker threads rather than while they run.
  static void Register(const std::shared_ptr<ProfilerEventHandler>& handler) {
    Handler = handler;
  }

  virtual ~ProfilerEventHandler() = default;

  // Hook invoked at the start of each iteration; no-op by default.
  virtual void onIterationStart(int /*iteration*/) {}

  // Hook invoked once per iteration per worker thread; no-op by default.
  virtual void emulateTraining(int /*iteration*/, int /*thread_id*/) {}
};
std::shared_ptr<ProfilerEventHandler> ProfilerEventHandler::Handler;

class ProfilerEventHandlerTrampoline : public ProfilerEventHandler {
 public:
  virtual void onIterationStart(int iteration) override {
    PYBIND11_OVERRIDE(void, ProfilerEventHandler, onIterationStart, iteration);
  }
  virtual void emulateTraining(int iteration, int thread_id) override {
    PYBIND11_OVERRIDE(
        void, ProfilerEventHandler, emulateTraining, iteration, thread_id);
  }
};

/**
 * This is the entry point for the C++ training engine.
 */
void start_threads(int thread_count, int iteration_count, bool attach) {
  blueprint("start_cpp_threads called");

  static std::atomic<int> barrier = 0;
  barrier = 0;
  static std::atomic<int> another_barrier = 0;
  another_barrier = 0;
  thread_local bool enabled_in_main_thread = false;

  std::vector<std::thread> threads;
  for (int id = 0; id < thread_count; id++) {
    blueprint("starting thread " + std::to_string(id));
    threads.emplace_back([thread_count, iteration_count, id, attach]() {
      for (int iteration = 0; iteration < iteration_count; iteration++) {
        if (id == 0) {
          ProfilerEventHandler::Handler->onIterationStart(iteration);
        }

        // this barrier makes sure all child threads will be turned on
        // with profiling when main thread is enabled
        ++barrier;
        while (barrier % thread_count) {
          std::this_thread::yield();
        }

        if (id > 0 && attach) {
          bool enabled = isProfilerEnabledInMainThread();
          if (enabled != enabled_in_main_thread) {
            if (enabled) {
              enableProfilerInChildThread();
            } else {
              disableProfilerInChildThread();
            }
            enabled_in_main_thread = enabled;
          }
        }

        ProfilerEventHandler::Handler->emulateTraining(iteration, id);

        // We need another barrier here to ensure that the main thread doesn't
        // stop the profiler while other threads are still using it. This fixes
        // https://github.com/pytorch/pytorch/issues/132331
        ++another_barrier;
        while (another_barrier % thread_count) {
          std::this_thread::yield();
        }
      }
    });
  }
  for (auto& t : threads) {
    t.join();
  }
}

// Python extension module exposing ProfilerEventHandler (subclassable from
// Python through ProfilerEventHandlerTrampoline) and start_threads().
PYBIND11_MODULE(profiler_test_cpp_thread_lib, m) {
  // std::shared_ptr holder, matching the shared_ptr that Register() stores.
  py::class_<
      ProfilerEventHandler,
      ProfilerEventHandlerTrampoline,
      std::shared_ptr<ProfilerEventHandler>>
      handler_class(m, "ProfilerEventHandler");

  handler_class.def(py::init<>());
  handler_class.def_static("Register", &ProfilerEventHandler::Register);

  // Both hooks are bound with the GIL released for the duration of the call.
  handler_class.def(
      "onIterationStart",
      &ProfilerEventHandler::onIterationStart,
      py::call_guard<py::gil_scoped_release>());
  handler_class.def(
      "emulateTraining",
      &ProfilerEventHandler::emulateTraining,
      py::call_guard<py::gil_scoped_release>());

  // The GIL is released while start_threads() runs. Its worker threads can
  // then acquire it to call Python overrides while the calling thread is
  // blocked joining them.
  m.def(
      "start_threads",
      &start_threads,
      py::call_guard<py::gil_scoped_release>());
}