File: parallel_benchmark.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (88 lines) | stat: -rw-r--r-- 2,138 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <vector>

#include <torch/torch.h>

// One-shot cross-thread signal: wait() blocks until post() has been called at
// least once. There is no reset -- once posted, every later wait() returns
// immediately.
class Baton {
 public:
  // Marks the baton as signalled and wakes every thread blocked in wait().
  void post() {
    std::lock_guard<std::mutex> guard(lock_);
    done_ = true;
    // Notify while the mutex is still held: a waiter cannot return from
    // wait() -- and so its owner cannot destroy this Baton, as the benchmarks
    // in this file do right after wait() -- until the lock is released. That
    // keeps cv_ alive for the duration of notify_all().
    cv_.notify_all();
  }

  // Blocks the calling thread until post() has been called. The predicate
  // form re-checks done_ after every wake-up, so spurious wake-ups are
  // harmless.
  void wait() {
    std::unique_lock<std::mutex> guard(lock_);
    cv_.wait(guard, [this] { return done_; });
  }

 private:
  std::mutex lock_;
  std::condition_variable cv_;
  bool done_{false};
};

// Benchmarks the per-hop latency of at::launch with an empty task.
//
// Builds a chain of `numIters` tasks. Each task increments a counter and,
// until the counter reaches the limit, schedules the next task with
// at::launch. The first task runs on the calling thread, which then blocks on
// a Baton until the last task posts it. Finally it prints the average elapsed
// time per task in microseconds.
//
// numIters: length of the chain. It must be positive. With a non-positive
//           value the termination check (`++val_ == limit_`) can never fire,
//           so the chain would never end. Such values are rejected with a
//           message on stderr.
void AtLaunch_Base(int32_t numIters) {
  if (numIters <= 0) {
    std::cerr << "AtLaunch_Base: numIters must be positive, got " << numIters
              << '\n';
    return;
  }

  struct Helper {
    explicit Helper(int32_t lim) : limit_(lim) {}
    void operator()() {
      // Each task is launched by its predecessor after that predecessor's
      // increment, so the increments are sequential. NOTE(review): this relies
      // on at::launch's task hand-off establishing happens-before -- confirm.
      if (++val_ == limit_) {
        done.post();
      } else {
        at::launch([this]() { (*this)(); });
      }
    }
    int val_{0};
    int limit_;
    Baton done;
  };

  Helper h(numIters);
  // steady_clock is monotonic. system_clock can jump (NTP, manual clock
  // changes) and is therefore unsuitable for measuring elapsed intervals.
  const auto start = std::chrono::steady_clock::now();
  h();
  h.done.wait();
  const auto elapsedUs =
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::steady_clock::now() - start)
          .count();
  std::cout << "NoData "
            << static_cast<double>(elapsedUs) / static_cast<double>(numIters)
            << " usec/each\n";
}

// Benchmarks the per-hop latency of at::launch when every task carries a
// std::vector<int32_t> payload of `vecSize` zero-initialised elements.
//
// Builds a chain of `numIters` tasks. Each task increments a counter and,
// until the counter reaches the limit, launches the next task with the
// payload captured in the closure. The first task runs on the calling thread,
// which then blocks on a Baton until the last task posts it. Finally it prints
// the average elapsed time per task in microseconds, tagged with the payload
// size.
//
// numIters: length of the chain. It must be positive, because a non-positive
//           limit is never reached and the chain would never terminate.
// vecSize:  number of payload elements. It must be non-negative, because a
//           negative value would wrap to a huge size_t in the vector
//           constructor. Invalid arguments are reported on stderr and the
//           benchmark is skipped.
void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
  if (numIters <= 0 || vecSize < 0) {
    std::cerr << "AtLaunch_WithData: need numIters > 0 and vecSize >= 0, got "
              << numIters << ", " << vecSize << '\n';
    return;
  }

  struct Helper {
    explicit Helper(int32_t lim) : limit_(lim) {}
    void operator()(std::vector<int32_t> v) {
      if (++val_ == limit_) {
        done.post();
      } else {
        // The closure is not `mutable`, so the captured `v` is const inside
        // it, and `(*this)(v)` copy-constructs the by-value parameter. Every
        // hop therefore pays at least one vector copy, and that cost is part
        // of what this benchmark measures.
        at::launch([this, v = std::move(v)]() { (*this)(v); });
      }
    }
    int val_{0};
    int limit_;
    Baton done;
  };

  Helper h(numIters);
  std::vector<int32_t> v(vecSize, 0);
  // steady_clock is monotonic. system_clock can jump (NTP, manual clock
  // changes) and is therefore unsuitable for measuring elapsed intervals.
  const auto start = std::chrono::steady_clock::now();
  h(v);
  h.done.wait();
  const auto elapsedUs =
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::steady_clock::now() - start)
          .count();
  std::cout << "WithData(" << vecSize << "): "
            << static_cast<double>(elapsedUs) / static_cast<double>(numIters)
            << " usec/each\n";
}

// Entry point. Runs the at::launch chain benchmarks, each for one million
// hops: first with no payload, then with vector payloads of 0, 4 and 256
// int32 elements. Command-line arguments are ignored.
int main(int argc, char** argv) {
  static_cast<void>(argc);
  static_cast<void>(argv);

  constexpr int32_t kNumIters = 1000000;
  constexpr int32_t kPayloadSizes[] = {0, 4, 256};

  AtLaunch_Base(kNumIters);
  for (const int32_t payloadSize : kPayloadSizes) {
    AtLaunch_WithData(kNumIters, payloadSize);
  }
  return 0;
}