1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
|
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace profiler {
namespace processglobal {
using namespace torch::autograd::profiler;
std::vector<thread_event_lists> State::results() {
std::unique_lock<std::mutex> lock(resultsMutex_);
std::vector<thread_event_lists> results;
results.swap(results_);
return results;
}
// Mutex that pushRange() and popRange() lock (through wLockType) before
// reading or replacing currentStateStackEntryPtr.
mutexType currentStateStackEntryMutex;
// Top of the StateStackEntry chain. Starts out null; pushRange() replaces it
// with a newly created entry and popRange() resets it to the popped entry's
// prevPtr_.
std::shared_ptr<StateStackEntry> currentStateStackEntryPtr = nullptr;
void StateStackEntry::pushRange(
std::shared_ptr<State> profilerProcessGlobalStatePtr) {
wLockType wlock(currentStateStackEntryMutex);
auto previousStateStackEntryPtr = currentStateStackEntryPtr;
currentStateStackEntryPtr = std::make_shared<StateStackEntry>(
previousStateStackEntryPtr, std::move(profilerProcessGlobalStatePtr));
}
std::shared_ptr<State> StateStackEntry::popRange() {
wLockType wlock(currentStateStackEntryMutex);
auto poppedStateStackEntryPtr = currentStateStackEntryPtr;
TORCH_INTERNAL_ASSERT(
poppedStateStackEntryPtr && poppedStateStackEntryPtr->statePtr_);
currentStateStackEntryPtr = poppedStateStackEntryPtr->prevPtr_;
return poppedStateStackEntryPtr->statePtr_;
}
// Starting at stateStackEntryPtr, follows the prevPtr() links until a null
// entry is reached and calls pushResult(result) on the state of every entry
// visited. Does nothing when stateStackEntryPtr is null.
void pushResultRecursive(
    std::shared_ptr<StateStackEntry> stateStackEntryPtr,
    const thread_event_lists& result) {
  for (; stateStackEntryPtr;
       stateStackEntryPtr = stateStackEntryPtr->prevPtr()) {
    // Put event_lists into the process-global profiler state.
    stateStackEntryPtr->statePtr()->pushResult(result);
  }
}
// Creates a State from new_config and pushes it as the new top entry through
// StateStackEntry::pushRange.
void enableServer(const ProfilerConfig& new_config) {
  StateStackEntry::pushRange(std::make_shared<State>(new_config));
}
// Pops the top entry through StateStackEntry::popRange and returns whatever
// results() yields for the popped state.
std::vector<thread_event_lists> disableServer() {
  return StateStackEntry::popRange()->results();
}
} // namespace processglobal
} // namespace profiler
} // namespace rpc
} // namespace distributed
} // namespace torch
|