#pragma once
#include <shared_mutex>
#include <utility>
#include <torch/csrc/autograd/profiler.h>
namespace torch::distributed::rpc::profiler::processglobal {
using namespace torch::autograd::profiler;
// Process-global profiler state.
//
// This class holds the information about one profiling range, from "enable"
// to "disable". An instance of this ``State`` is pushed onto a global stack,
// so nested profiling ranges are supported.
//
// It has two members.
// One is an ``autograd::profiler::ProfilerConfig``. It is set by the user and
// copied to the thread-local profiler state of each RPC thread.
// The other is a container that aggregates the recorded
// ``autograd::profiler::Event``s from all thread-local profilers on RPC
// threads.
// A usage sketch follows the class definition.
class State {
public:
explicit State(ProfilerConfig config) : config_(std::move(config)) {}
~State() = default;
const ProfilerConfig& config() const {
return config_;
}
void pushResult(thread_event_lists result) {
std::unique_lock<std::mutex> lock(resultsMutex_);
// NB: By the time a thread pushes an entry into this container, the main
// control logic might have already exited the process-global profile range.
results_.emplace_back(std::move(result));
}
std::vector<thread_event_lists> results();
private:
// Each result comes from a profile range. In each profile range, there is a
// "__profiler_start" marker event, relative to which all following events
// compute their times, so parse_cpu_trace(result) must be called on the
// results of every profile range.
std::mutex resultsMutex_;
std::vector<thread_event_lists> results_;
const ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
};
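// A minimal usage sketch (hypothetical; in practice RPC threads push results
// via pushResultRecursive(), declared below):
//
//   auto statePtr =
//       std::make_shared<State>(ProfilerConfig(ProfilerState::CPU));
//   // An RPC thread appends its thread-local events under the internal mutex.
//   statePtr->pushResult(thread_event_lists{});
//   // The main control thread later drains everything aggregated so far.
//   std::vector<thread_event_lists> collected = statePtr->results();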
class StateStackEntry;
#if defined(__MACH__)
// Compiler error: 'shared_timed_mutex' is unavailable: introduced in
// macOS 10.12
using mutexType = std::mutex;
// Compiler error: 'shared_lock' is unavailable: introduced in
// macOS 10.12
using rLockType = std::unique_lock<std::mutex>;
using wLockType = std::unique_lock<std::mutex>;
#else
using mutexType = std::shared_timed_mutex;
using rLockType = std::shared_lock<std::shared_timed_mutex>;
using wLockType = std::unique_lock<std::shared_timed_mutex>;
#endif
// This is the global stack of ``State``s.
TORCH_API extern std::shared_ptr<StateStackEntry> currentStateStackEntryPtr;
TORCH_API extern mutexType currentStateStackEntryMutex;
// This class is used to implement a stack of ``State``s.
// It has two members.
// One is ``prevPtr``, a shared_ptr pointing to the previous element in the
// stack.
// The other is ``statePtr``, a shared_ptr pointing to this entry's ``State``.
// A usage sketch follows the class definition.
class StateStackEntry {
public:
StateStackEntry(
std::shared_ptr<StateStackEntry> prevPtr,
std::shared_ptr<State> statePtr)
: prevPtr_(std::move(prevPtr)), statePtr_(std::move(statePtr)) {}
static void pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr);
static std::shared_ptr<State> popRange();
static std::shared_ptr<StateStackEntry> current() {
rLockType rlock(currentStateStackEntryMutex);
return currentStateStackEntryPtr;
}
std::shared_ptr<StateStackEntry> prevPtr() const {
return prevPtr_;
}
std::shared_ptr<State> statePtr() const {
return statePtr_;
}
private:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::shared_ptr<StateStackEntry> prevPtr_{nullptr};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::shared_ptr<State> statePtr_{nullptr};
};
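// A sketch of how nested ranges stack up (normally driven by enableServer()
// and disableServer() below rather than called directly):
//
//   auto outer = std::make_shared<State>(ProfilerConfig(ProfilerState::CPU));
//   auto inner = std::make_shared<State>(ProfilerConfig(ProfilerState::CPU));
//   StateStackEntry::pushRange(outer); // stack: [outer]
//   StateStackEntry::pushRange(inner); // stack: [outer, inner]
//   auto top = StateStackEntry::current(); // entry holding ``inner``
//   StateStackEntry::popRange(); // returns ``inner``; stack: [outer]
//   StateStackEntry::popRange(); // returns ``outer``; stack: []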
// Push the result to the ``State`` of the current profile range and,
// recursively, to the ``State``s of all outer profile ranges.
TORCH_API void pushResultRecursive(
std::shared_ptr<StateStackEntry> stateStackEntryPtr,
const thread_event_lists& result);
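// Sketch of the expected call site on an RPC thread (how the thread-local
// events are collected is an assumption and depends on the profiler version):
//
//   thread_event_lists events = /* this thread's local profiler results */;
//   pushResultRecursive(StateStackEntry::current(), events);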
// User-facing API.
//
// Enter a server-side process-global profiling range.
// Profiling ranges can be nested, so it's OK to call this API multiple times.
// This enables profiling on all RPC threads running server-side request
// callbacks.
TORCH_API void enableServer(const ProfilerConfig& new_config);
//
// Exit a server-side process-global profiling range.
// Profiling ranges can be nested, so the profiler may still be on after
// calling this API.
// This disables profiling on all RPC threads running server-side request
// callbacks.
TORCH_API std::vector<thread_event_lists> disableServer();
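// An end-to-end sketch (a hypothetical server-side snippet; assumes the RPC
// agent is already serving requests):
//
//   enableServer(ProfilerConfig(ProfilerState::CPU));
//   // ... server-side request callbacks run and are profiled ...
//   std::vector<thread_event_lists> results = disableServer();
//   for (const auto& threadLists : results) {
//     // Each element is one result pushed by an RPC thread's local profiler.
//     (void)threadLists;
//   }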
} // namespace torch::distributed::rpc::profiler::processglobal