File: throughput_benchmark-inl.h

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)
#pragma once

#include <random>
#include <thread>

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/Parallel.h>
#include <c10/util/irange.h>

namespace torch {
namespace throughput_benchmark {
namespace detail {

template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
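  // Overall flow: pre-generate per-thread inputs, spawn
  // config.num_calling_threads caller threads, hold them at a
  // condition-variable barrier until every thread has finished its warmup
  // iterations, then start the clock and let them drain config.num_iters
  // measured iterations between them.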
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");

  LOG(INFO) << at::get_parallel_info();

  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each thread independently and thus avoid
  // overhead from the benchmark runner itself.
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inputs. "
        "Did you forget to call add_input()?");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);

    for (const auto thread_id : c10::irange(config.num_calling_threads)) {
      // Just in case, we generate num_iters (plus warmup) inputs for each of
      // the threads. This way, even if one thread ends up doing all the work,
      // we will not run out of inputs.
      for (const auto i :
           c10::irange(config.num_iters + config.num_warmup_iters)) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }

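  // Shared state between the main thread and the callers: `initialized` and
  // `finished` count threads at the start/end barriers (protected by `m`),
  // `start` releases the callers, and `num_attempted_iters` hands out global
  // iteration tickets so the measured work sums to exactly config.num_iters.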
  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_attempted_iters{0};
  std::vector<std::thread> callers;

  callers.reserve(config.num_calling_threads);
  for (const auto thread_id : c10::irange(config.num_calling_threads)) {
    callers.emplace_back([&, thread_id]() {
      // We use a condition variable as a barrier to make sure each thread
      // performs the required warmup iterations before we start measuring.
      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        // NOLINTNEXTLINE(bugprone-infinite-loop)
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
      while (num_attempted_iters.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }

      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
    });
  }

  using Clock = std::chrono::high_resolution_clock;
  using RecordProfile = torch::autograd::profiler::RecordProfile;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;

  std::unique_ptr<RecordProfile> profiler_guard;
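  // Wait for all callers to reach the barrier, optionally enable the autograd
  // profiler, then set `start` and take the start timestamp while still
  // holding the lock so no caller can begin before the clock starts.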
  {
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    if (!config.profiler_output_path.empty()) {
      LOG(INFO) << "Using Autograd profiler. Trace will be saved to "
                << config.profiler_output_path;
      profiler_guard =
          std::make_unique<RecordProfile>(config.profiler_output_path);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }

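  // Release the callers, then block until every thread has reported that it
  // finished its share of the measured iterations.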
  main_worker_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
  auto end_time = Clock::now();
  profiler_guard.reset();
  LOG(INFO) << "Finished benchmark";

  BenchmarkExecutionStats stats;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  // We use config.num_iters instead of num_attempted_iters as it is
  // representative of the real work done. The last attempted iteration on
  // each calling thread doesn't represent real work (i.e. running the model).
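  // Average latency per call: each calling thread is busy for roughly the
  // whole measured window, so the total busy time is total_time_ms *
  // num_calling_threads, spread over num_iters calls.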
  stats.latency_avg_ms =
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      total_time_ms * config.num_calling_threads / config.num_iters;
  stats.num_iters = config.num_iters;

  for (auto& t : callers) {
    t.join();
  }
  return stats;
}

} // namespace detail
} // namespace throughput_benchmark
} // namespace torch