#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {
namespace tensorexpr {
/*
ExecutionTrigger and ExecutionCounter build instrumentation counters so that
the underlying functionality can be checked.

In the code to be instrumented:

  // worker.cpp
  DEFINE_TRIGGER(useful_work_done);  // defines a trigger named
                                     // "useful_work_done"
  void run() {
    USE_TRIGGER(useful_work_done);   // increments the counter of the
                                     // "useful_work_done" trigger
  }

In a C++ client:

  // client.cpp
  DECLARE_TRIGGER(useful_work_done);  // optional: declares a trigger that is
                                      // defined elsewhere
  ExecutionCounter counter(useful_work_done);  // starts counting from the
                                               // trigger's current value
  ... call run() ...
  counter.elapsed_value();  // returns how many times the trigger fired since
                            // the counter was created

In a Python client:

  # client.py
  counter = ExecutionCounter("useful_work_done")  # starts counting from the
                                                  # trigger's current value
  ... call C++ run() ...
  counter.elapsed_value()  # returns how many times the trigger fired since
                           # the counter was created
*/
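
/*
Worked example of the elapsed_value() arithmetic (illustrative only; the
trigger name `demo_trigger` and the function `demo()` are hypothetical, and
the values assume demo() runs once and nothing else fires demo_trigger).
The counter snapshots the trigger's value when it is constructed, so firings
that happen before the counter exists are not included:

  DEFINE_TRIGGER(demo_trigger);  // file scope; demo_trigger.value() == 0

  void demo() {
    USE_TRIGGER(demo_trigger);               // value() == 1
    ExecutionCounter counter(demo_trigger);  // snapshot: start value 1
    USE_TRIGGER(demo_trigger);               // value() == 2
    USE_TRIGGER(demo_trigger);               // value() == 3
    int n = counter.elapsed_value();         // 3 - 1 == 2
  }
*/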

class ExecutionTrigger;

// Process-wide registry that maps trigger names to ExecutionTrigger objects.
// Each ExecutionTrigger registers itself here when it is constructed.
class ExecutionTriggerList {
 public:
  TORCH_API static ExecutionTriggerList& GetInstance() {
    static ExecutionTriggerList instance;
    return instance;
  }

  // Returns the trigger registered under `name`; throws std::runtime_error
  // if no trigger with that name exists.
  ExecutionTrigger* FindByName(const std::string& name) const {
    auto iter = trigger_list_.find(name);
    if (iter == trigger_list_.end()) {
      throw std::runtime_error("Invalid trigger name: " + name);
    }
    return iter->second;
  }

 private:
  friend class ExecutionTrigger;

  ExecutionTriggerList() {}
  ExecutionTriggerList(const ExecutionTriggerList&) = delete;
  ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete;

  // Called from ExecutionTrigger's constructor. A duplicate name only prints
  // a warning; the first registration is kept.
  void AddTrigger(const std::string& name, ExecutionTrigger* trigger) {
    auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger));
    if (!insert_ret.second) {
      std::cerr << "Warning: duplicated trigger name: " << name << "\n";
    }
  }

  std::unordered_map<std::string, ExecutionTrigger*> trigger_list_;
};

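// A named counter that is incremented each time the instrumented code path
// runs (see USE_TRIGGER). It registers itself in ExecutionTriggerList on
// construction. value_ is a plain int, so concurrent trigger() calls are not
// synchronized.
class ExecutionTrigger {
 public:
  explicit ExecutionTrigger(const std::string& name) : name_(name) {
    ExecutionTriggerList::GetInstance().AddTrigger(name, this);
  }

  // Number of times trigger() has been called so far.
  int value() const {
    return value_;
  }

  void trigger() {
    value_++;
  }

 private:
  ExecutionTrigger(const ExecutionTrigger&) = delete;
  ExecutionTrigger& operator=(const ExecutionTrigger&) = delete;

  int value_ = 0;
  const std::string name_;
};
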
// Records a trigger's value at construction; elapsed_value() returns how many
// times the trigger has fired since then.
class ExecutionCounter {
 public:
  explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) {
    start_value_ = trigger_.value();
  }

  int elapsed_value() const {
    return trigger_.value() - start_value_;
  }

 private:
  ExecutionTrigger& trigger_;
  int start_value_ = 0;
};

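// Macros ignore namespaces, so the type is fully qualified to let
// DEFINE_TRIGGER and DECLARE_TRIGGER be used outside
// torch::jit::tensorexpr (e.g. at file scope in worker.cpp/client.cpp).
#define DEFINE_TRIGGER(name) \
  ::torch::jit::tensorexpr::ExecutionTrigger name(#name)
#define DECLARE_TRIGGER(name) \
  TORCH_API extern ::torch::jit::tensorexpr::ExecutionTrigger name
#define USE_TRIGGER(name) (name).trigger()
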
} // namespace tensorexpr
} // namespace jit
} // namespace torch