1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
|
#pragma once
#include <bitset>
#include <chrono>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <c10/macros/Macros.h>
#include <torch/csrc/monitor/events.h>
namespace torch {
namespace monitor {
// NUM_AGGREGATIONS is the number of enumerators in Aggregation (NONE
// included). It is the width of the std::bitset used to record which
// aggregations are enabled, so it must stay greater than the largest
// enumerator value.
constexpr int NUM_AGGREGATIONS = 7;
// Aggregation is the list of possible aggregations for Stats.
// Each enumerator's value is a bit index (not a bit mask): merge() sets bit
// `static_cast<int>(agg)` in a std::bitset<NUM_AGGREGATIONS>, which lets a set
// of aggregations be stored compactly.
enum class C10_API_ENUM Aggregation {
// NONE means no aggregations are set.
// NOTE(review): as a bit index NONE occupies bit 0; none of the code visible
// in this header tests that bit.
NONE = 0,
// VALUE exports the most recently set value.
VALUE = 1,
// MEAN computes the mean of the set values within the window. Zero if no
// values. The result has the Stat's value type, so for an integral Stat the
// division truncates.
MEAN = 2,
// COUNT tracks the number of times a value is set within the window.
COUNT = 3,
// SUM computes the sum of the values set within the window.
SUM = 4,
// MAX computes the maximum of the values set within the window. Zero if no
// values.
MAX = 5,
// MIN computes the minimum of the values set within the window. Zero if no
// values.
MIN = 6,
};
// AggregationHash is a hash functor for containers keyed by Aggregation (it is
// the hasher of the unordered_map returned by Stat::get). The hash of a key is
// simply the key converted to std::size_t.
struct TORCH_API AggregationHash {
  template <typename Key>
  std::size_t operator()(Key key) const {
    const auto hashed = static_cast<std::size_t>(key);
    return hashed;
  }
};
// aggregationName returns the human readable name corresponding to the
// aggregation. Stat appends it to the stat's own name to form the keys of the
// logged event, i.e. `<stat name>.<aggregation name>`.
// NOTE(review): defined outside this header; presumably the returned pointer
// refers to a string with static lifetime -- confirm against the definition.
TORCH_API const char* aggregationName(Aggregation agg);
// Forward declaration of Stat so that the detail:: registration functions
// declared further down can take Stat<T>* before the class is defined.
template <typename T>
class Stat;
namespace {
template <typename T>
inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
std::bitset<NUM_AGGREGATIONS> a;
for (Aggregation b : list) {
a.set(static_cast<int>(b));
}
return a;
}
} // namespace
namespace detail {
// registerStat is called with `this` at the end of every Stat constructor and
// unregisterStat at the end of Stat's destructor. Overloads exist only for
// Stat<double> and Stat<int64_t>, so constructing a Stat of any other value
// type does not compile.
// NOTE(review): the definitions live outside this header; presumably they
// maintain a registry of the currently live stats -- confirm.
void TORCH_API registerStat(Stat<double>* stat);
void TORCH_API registerStat(Stat<int64_t>* stat);
void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail
// Stat is used to compute summary statistics in a performant way over fixed
// intervals. Stat logs the statistics as an Event once every `windowSize`
// duration. When the window closes the stats are logged via the event handlers
// as a `torch.monitor.Stat` event.
//
// `windowSize` should be set to something relatively high to avoid a huge
// number of events being logged. Ex: 60s. Stat uses millisecond precision.
//
// If maxSamples is set, the stat will cap the number of samples per window by
// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
// all `add` calls during the window will be included.
// This is an optional field to make aggregations more directly comparable
// across windows when the number of samples might vary.
//
// Stats support double and int64_t data types depending on what needs to be
// logged and needs to be templatized with one of them.
//
// When the Stat is destructed it will log any remaining data even if the window
// hasn't elapsed.
template <typename T>
class Stat {
private:
struct Values {
T value{0};
T sum{0};
T min{0};
T max{0};
int64_t count{0};
};
public:
Stat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
Stat(
std::string name,
std::vector<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
virtual ~Stat() {
{
// on destruction log if there's unlogged data
std::lock_guard<std::mutex> guard(mu_);
logLocked();
}
detail::unregisterStat(this);
}
// add adds the value v to the current window.
void add(T v) {
std::lock_guard<std::mutex> guard(mu_);
maybeLogLocked();
if (alreadyLogged()) {
return;
}
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
current_.value = v;
}
if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
aggregations_.test(static_cast<int>(Aggregation::SUM))) {
current_.sum += v;
}
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
if (current_.max < v || current_.count == 0) {
current_.max = v;
}
}
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
if (current_.min > v || current_.count == 0) {
current_.min = v;
}
}
current_.count += 1;
maybeLogLocked();
}
const std::string& name() const noexcept {
return name_;
}
// count returns the number of items in the current open window.
int64_t count() noexcept {
std::lock_guard<std::mutex> guard(mu_);
return current_.count;
}
std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
std::lock_guard<std::mutex> guard(mu_);
return getLocked();
}
protected:
virtual uint64_t currentWindowId() const {
std::chrono::milliseconds now =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch());
// always returns a currentWindowId of at least 1 to avoid 0 window issues
return (now / windowSize_) + 1;
}
private:
bool alreadyLogged() {
return lastLoggedWindowId_ == currentWindowId();
}
void maybeLogLocked() {
auto windowId = currentWindowId();
bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
if (shouldLog && !alreadyLogged()) {
logLocked();
lastLoggedWindowId_ = windowId_;
windowId_ = windowId;
}
}
void logLocked() {
prev_ = current_;
current_ = Values();
// don't log event if there's no data
if (prev_.count == 0) {
return;
}
Event e;
e.name = "torch.monitor.Stat";
e.timestamp = std::chrono::system_clock::now();
auto stats = getLocked();
e.data.reserve(stats.size());
for (auto& kv : stats) {
std::stringstream key;
key << name_;
key << ".";
key << aggregationName(kv.first);
e.data[key.str()] = kv.second;
}
logEvent(e);
}
std::unordered_map<Aggregation, T, AggregationHash> getLocked()
const noexcept {
std::unordered_map<Aggregation, T, AggregationHash> out;
out.reserve(aggregations_.count());
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
out.emplace(Aggregation::VALUE, prev_.value);
}
if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
if (prev_.count == 0) {
out.emplace(Aggregation::MEAN, 0);
} else {
out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
}
}
if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
out.emplace(Aggregation::COUNT, prev_.count);
}
if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
out.emplace(Aggregation::SUM, prev_.sum);
}
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
out.emplace(Aggregation::MAX, prev_.max);
}
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
out.emplace(Aggregation::MIN, prev_.min);
}
return out;
}
const std::string name_;
const std::bitset<NUM_AGGREGATIONS> aggregations_;
std::mutex mu_;
Values current_;
Values prev_;
uint64_t windowId_{0};
uint64_t lastLoggedWindowId_{0};
const std::chrono::milliseconds windowSize_;
const int64_t maxSamples_;
};
} // namespace monitor
} // namespace torch
|