File: counters.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (279 lines) | stat: -rw-r--r-- 7,715 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
#pragma once

#include <bitset>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>

#include <torch/csrc/monitor/events.h>

namespace torch {
namespace monitor {

// NUM_AGGREGATIONS is the width of the std::bitset used to record which
// aggregations are enabled. Aggregation enumerator values are used as bit
// positions in that bitset, so this must be greater than the largest
// enumerator value.
constexpr int NUM_AGGREGATIONS = 7;

// Aggregation is the list of possible aggregations for Stats.
// Each enumerator value is used as a bit position in a
// std::bitset<NUM_AGGREGATIONS>, so a set of aggregations can be stored
// compactly. The values are sequential indices, not power-of-two masks.
enum class C10_API_ENUM Aggregation {
  // NONE means no aggregations are set.
  NONE = 0,
  // VALUE exports the most recently set value.
  VALUE = 1,
  // MEAN computes the mean of the set values within the window. Zero if no
  // values.
  MEAN = 2,
  // COUNT tracks the number of times a value is set within the window.
  COUNT = 3,
  // SUM computes the sum of the values set within the window.
  SUM = 4,
  // MAX computes the maximum of the values set within the window. Zero if no
  // values.
  MAX = 5,
  // MIN computes the minimum of the values set within the window. Zero if no
  // values.
  MIN = 6,
};

// AggregationHash is a hash functor that hashes its argument by converting it
// to std::size_t with static_cast. It is used as the hasher for
// unordered_map instances keyed by Aggregation.
struct TORCH_API AggregationHash {
  template <typename T>
  std::size_t operator()(T key) const {
    const auto hashed = static_cast<std::size_t>(key);
    return hashed;
  }
};

// aggregationName returns the human readable name corresponding to the
// aggregation.
TORCH_API const char* aggregationName(Aggregation agg);

template <typename T>
class Stat;

namespace {
template <typename T>
inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
  std::bitset<NUM_AGGREGATIONS> a;
  for (Aggregation b : list) {
    a.set(static_cast<int>(b));
  }
  return a;
}
} // namespace

namespace detail {
// registerStat is called with `this` from the Stat constructors, and
// unregisterStat is called with `this` at the end of the Stat destructor.
// Only the declarations are visible in this header; overloads exist for
// Stat<double> and Stat<int64_t>.
// NOTE(review): presumably these add/remove the stat from a registry of live
// stats - confirm against the definitions.
void TORCH_API registerStat(Stat<double>* stat);
void TORCH_API registerStat(Stat<int64_t>* stat);
void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail

// Stat is used to compute summary statistics in a performant way over fixed
// intervals. Stat logs the statistics as an Event once every `windowSize`
// duration. When the window closes the stats are logged via the event handlers
// as a `torch.monitor.Stat` event.
//
// `windowSize` should be set to something relatively high to avoid a huge
// number of events being logged. Ex: 60s. Stat uses millisecond precision.
//
// If maxSamples is set, the stat will cap the number of samples per window by
// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
// all `add` calls during the window will be included.
// This is an optional field to make aggregations more directly comparable
// across windows when the number of samples might vary.
//
// Stats support double and int64_t data types depending on what needs to be
// logged and needs to be templatized with one of them.
//
// When the Stat is destructed it will log any remaining data even if the window
// hasn't elapsed.
template <typename T>
class Stat {
 private:
  struct Values {
    T value{0};
    T sum{0};
    T min{0};
    T max{0};
    int64_t count{0};
  };

 public:
  Stat(
      std::string name,
      std::initializer_list<Aggregation> aggregations,
      std::chrono::milliseconds windowSize,
      int64_t maxSamples = std::numeric_limits<int64_t>::max())
      : name_(std::move(name)),
        aggregations_(merge(aggregations)),
        windowSize_(windowSize),
        maxSamples_(maxSamples) {
    detail::registerStat(this);
  }

  Stat(
      std::string name,
      std::vector<Aggregation> aggregations,
      std::chrono::milliseconds windowSize,
      int64_t maxSamples = std::numeric_limits<int64_t>::max())
      : name_(std::move(name)),
        aggregations_(merge(aggregations)),
        windowSize_(windowSize),
        maxSamples_(maxSamples) {
    detail::registerStat(this);
  }

  virtual ~Stat() {
    {
      // on destruction log if there's unlogged data
      std::lock_guard<std::mutex> guard(mu_);
      logLocked();
    }
    detail::unregisterStat(this);
  }

  // add adds the value v to the current window.
  void add(T v) {
    std::lock_guard<std::mutex> guard(mu_);
    maybeLogLocked();

    if (alreadyLogged()) {
      return;
    }

    if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
      current_.value = v;
    }
    if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
        aggregations_.test(static_cast<int>(Aggregation::SUM))) {
      current_.sum += v;
    }

    if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
      if (current_.max < v || current_.count == 0) {
        current_.max = v;
      }
    }
    if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
      if (current_.min > v || current_.count == 0) {
        current_.min = v;
      }
    }

    current_.count += 1;
    maybeLogLocked();
  }

  const std::string& name() const noexcept {
    return name_;
  }

  // count returns the number of items in the current open window.
  int64_t count() noexcept {
    std::lock_guard<std::mutex> guard(mu_);

    return current_.count;
  }

  std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
    std::lock_guard<std::mutex> guard(mu_);
    return getLocked();
  }

 protected:
  virtual uint64_t currentWindowId() const {
    std::chrono::milliseconds now =
        std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now().time_since_epoch());

    // always returns a currentWindowId of at least 1 to avoid 0 window issues
    return (now / windowSize_) + 1;
  }

 private:
  bool alreadyLogged() {
    return lastLoggedWindowId_ == currentWindowId();
  }

  void maybeLogLocked() {
    auto windowId = currentWindowId();
    bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
    if (shouldLog && !alreadyLogged()) {
      logLocked();
      lastLoggedWindowId_ = windowId_;
      windowId_ = windowId;
    }
  }

  void logLocked() {
    prev_ = current_;
    current_ = Values();

    // don't log event if there's no data
    if (prev_.count == 0) {
      return;
    }

    Event e;
    e.name = "torch.monitor.Stat";
    e.timestamp = std::chrono::system_clock::now();

    auto stats = getLocked();
    e.data.reserve(stats.size());
    for (auto& kv : stats) {
      std::stringstream key;
      key << name_;
      key << ".";
      key << aggregationName(kv.first);
      e.data[key.str()] = kv.second;
    }

    logEvent(e);
  }

  std::unordered_map<Aggregation, T, AggregationHash> getLocked()
      const noexcept {
    std::unordered_map<Aggregation, T, AggregationHash> out;
    out.reserve(aggregations_.count());

    if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
      out.emplace(Aggregation::VALUE, prev_.value);
    }
    if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
      if (prev_.count == 0) {
        out.emplace(Aggregation::MEAN, 0);
      } else {
        out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
      }
    }
    if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
      out.emplace(Aggregation::COUNT, prev_.count);
    }
    if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
      out.emplace(Aggregation::SUM, prev_.sum);
    }
    if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
      out.emplace(Aggregation::MAX, prev_.max);
    }
    if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
      out.emplace(Aggregation::MIN, prev_.min);
    }

    return out;
  }

  const std::string name_;
  const std::bitset<NUM_AGGREGATIONS> aggregations_;

  std::mutex mu_;
  Values current_;
  Values prev_;

  uint64_t windowId_{0};
  uint64_t lastLoggedWindowId_{0};
  const std::chrono::milliseconds windowSize_;
  const int64_t maxSamples_;
};
} // namespace monitor
} // namespace torch