#pragma once
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <c10/util/CallOnce.h>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
struct ReadyQueue;
} // namespace autograd
} // namespace torch
namespace torch {
namespace autograd {
// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a
// program holds more than 65 locks in one thread at once. Since we hold
// a mutex in every custom C++ autograd Node, we would like to avoid TSAN
// complaints about this when doing reentrant backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
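// (A reentrant backward occurs when a Node's backward itself invokes
// backward(), e.g. from a custom autograd function or gradient
// checkpointing; each nested call increases the reentrant depth of the
// current GraphTask.)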
static constexpr int MAX_DEPTH = 60;
void set_device(int device);
void validate_outputs(
const edge_list& edges,
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);
struct NodeTask {
std::weak_ptr<GraphTask> base_;
std::shared_ptr<Node> fn_;
// This buffer serves as an implicit "addition" node for all of the
// gradients flowing here. Once all the dependencies are finished, we
// use the contents of this buffer to run the function.
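// For example, if fn_ is the grad_fn of a tensor that was consumed by two
// downstream operations in the forward pass, both incoming gradients are
// summed into this buffer before fn_ is invoked.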
InputBuffer inputs_;
// When a worker receives a task with isShutdownTask = true, it will
// immediately exit. The engine sends a shutdown task to every queue upon
// its destruction.
bool isShutdownTask_;
int getReentrantDepth() const;
NodeTask(
std::weak_ptr<GraphTask> base,
std::shared_ptr<Node> fn,
InputBuffer inputs,
bool isShutdownTask = false)
: base_(std::move(base)),
fn_(std::move(fn)),
inputs_(std::move(inputs)),
isShutdownTask_(isShutdownTask) {}
};
// Guard that sets and restores checkpoint_valid
class CheckpointValidGuard {
public:
explicit CheckpointValidGuard(
const std::shared_ptr<const GraphTask>& graph_task);
~CheckpointValidGuard();
private:
bool prev_checkpoint_valid_state;
};
struct ReadyQueue {
private:
// Returns true when t2 should be (weakly) BEFORE t1 in the queue.
// Shutdown tasks come first, and NodeTasks with an empty fn_ come next.
struct CompareNodeTaskTime {
bool operator()(NodeTask const& t1, NodeTask const& t2) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (t2.isShutdownTask_) {
return true;
} else if (!t1.fn_ || t1.isShutdownTask_) {
return false;
} else if (!t2.fn_) {
return true;
} else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
} else {
return t1.getReentrantDepth() < t2.getReentrantDepth();
}
}
};
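// Resulting pop order: tasks from deeper reentrant calls are popped first;
// among tasks at the same reentrant depth, the NodeTask whose fn_ has the
// larger sequence_nr() (i.e. the Node created later in the forward pass) is
// popped first.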
// To notify threads waiting on the ReadyQueue of available tasks on the heap_
std::condition_variable not_empty_;
// To protect read and writes to heap_
mutable std::mutex mutex_;
std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
heap_;
public:
// incrementOutstandingTasks indicates whether or not we should increment
// 'outstanding_tasks_' for the associated GraphTask. This should almost
// always be true and is only set to false in certain cases (see the docs
// for DistEngine.execute_graph_task_until_ready_queue_empty).
void push(NodeTask item, bool incrementOutstandingTasks = true);
void pushShutdownTask();
NodeTask pop();
bool empty() const;
size_t size() const;
};
// A single instance of this struct should exist for the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely
// on this.
struct TORCH_API Engine {
/// Returns a reference to a static `Engine` instance.
static Engine& get_default_engine();
static Engine& get_base_engine();
Engine(const Engine&) = delete;
Engine(Engine&&) = delete;
virtual ~Engine();
// Given a list of (Node, input number) pairs, computes the value of the
// graph by following next_edge references.
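//
// A minimal usage sketch (hypothetical; assumes `loss` is a scalar Variable
// produced by a differentiable computation, and uses impl::gradient_edge
// from variable.h to obtain its (Node, input number) edge):
//   auto& engine = Engine::get_default_engine();
//   edge_list roots{impl::gradient_edge(loss)};
//   variable_list grads{torch::ones_like(loss)};
//   engine.execute(
//       roots,
//       grads,
//       /*keep_graph=*/false,
//       /*create_graph=*/false,
//       /*accumulate_grad=*/true);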
virtual variable_list execute(
const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs = {});
// Given a pre-populated GraphTask and GraphRoot, computes the backward pass
// for the graph.
//
// NB: This API should only be used by internal autograd-specific
// machinery and shouldn't be exposed to users in any way.
virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
InputBuffer&& input_buffer);
virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
return std::make_unique<AnomalyMetadata>();
}
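// Subclasses (e.g. the Python engine) may override the function below to
// supply default pack/unpack hooks for saved variables. A hypothetical
// sketch (MySavedVariableHooks is an assumed type, not part of this header):
//   struct MyEngine : Engine {
//     std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks()
//         override {
//       return std::make_unique<MySavedVariableHooks>();
//     }
//   };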
virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
return nullptr;
}
// We pass cpu_ready_queue to evaluate_function so that it knows the
// correct ready queue to push to after a NodeTask is ready.
void evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue);
void initialize_device_threads_pool();
virtual void thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e);
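// Queues a callback that the engine runs once the backward pass completes
// (see final_callbacks_ below). A hypothetical usage:
//   engine.queue_callback([] { /* post-backward work */ });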
void queue_callback(std::function<void()> callback);
bool is_checkpoint_valid();
// Should be called after fork to notify that worker threads are gone
void release_workers();
// Must be called by subclass before destructing to avoid a data-race-on-vptr.
void stop();
// Initializes a device thread for the autograd engine.
virtual void thread_init(
int device,
const std::shared_ptr<ReadyQueue>& ready_queue,
bool should_increment = true);
protected:
Engine();
void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);
// Initialize the thread-local ready queue with a ready queue that was
// created elsewhere (e.g. in thread_init, Engine::execute, etc.), or create
// a new ready queue if ready_queue is not provided.
void init_local_ready_queue(
std::shared_ptr<ReadyQueue> ready_queue = nullptr);
std::shared_ptr<ReadyQueue> ready_queue(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
at::Device device);
std::shared_ptr<ReadyQueue> ready_queue_by_index(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
int device_index);
// Start device threads (CUDA, XLA, etc.) in Engine;
// note that this does NOT start the CPU thread.
void start_device_threads();
void increment_non_reentrant_thread_count();
void decrement_non_reentrant_thread_count();
virtual void thread_main(const std::shared_ptr<GraphTask>& task);
void reentrant_thread_init();
void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);
// Ensures device_ready_queues_ are initialized only once
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
c10::once_flag start_device_threads_flag_;
// Safe to read device_ready_queues_ without synchronization after
// initialization
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::mutex post_callbacks_lock_;
// How many nested reentrant calls are allowed until a new thread is used
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
int max_recursion_depth_;
struct ThreadPoolShared {
// Data structures used by the threads for executing reentrant backwards
// tasks. See Note [Reentrant backwards]
// Number of available threads for processing new GraphTasks.
unsigned int num_workers_;
// The threads will wait on work_ to be notified of GraphTasks
std::condition_variable work_;
// To protect reads and writes to graphtask_queue_ and num_workers_
// and for synchronizing creating new threads when needed
std::mutex mutex_;
// Workers will process the GraphTasks added to this queue. A GraphTask is
// allocated inside Engine::execute and lives for the duration of execute
std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ThreadPoolShared() : num_workers_(0) {}
};
// Temporary workaround until shutting down threads is done
// We need shared ownership of all these objects because the threads are
// leaked when Engine shuts down, so there may be threads waiting on work_ for
// the graphtasks_queue_ to be nonempty.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
private:
// Number of non-reentrant threads
std::atomic<uint32_t> non_reentrant_device_thread_count_;
// Destructor will wait for non-reentrant threads to finish
std::condition_variable non_reentrant_device_thread_condvar_;
std::mutex non_reentrant_device_thread_mutex_;
// stop() must be called before the destruction path goes down to the base
// class, in order to avoid a data-race-on-vptr. Use this boolean to guard
// whether stop() has already been called, so we can call this in every
// destructor of the class hierarchy.
bool stopped_{false};
};
// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
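// A hypothetical registration sketch: a derived engine exposes a factory
// with the EngineStub signature and installs it during library
// initialization (MyEngine is an assumed subclass of Engine):
//   Engine& get_my_engine() {
//     static MyEngine engine;
//     return engine;
//   }
//   // called once at module load:
//   set_default_engine_stub(get_my_engine);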
} // namespace autograd
} // namespace torch