#pragma once
#include <cstdint>
#include <functional>
#include <ATen/core/Dict.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
namespace torch {
namespace distributed {
namespace autograd {
class RecvRpcBackward;
// DistAutogradContext stores information for a single distributed autograd
// pass on a worker.
class TORCH_API DistAutogradContext {
public:
using GradCallback = std::function<bool(torch::Tensor&)>;
explicit DistAutogradContext(int64_t contextId);
// Retrieves the autograd context id for this context.
int64_t contextId() const;
// Records a 'send' autograd function for this context with the provided
// message id.
void addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id);
// Records a 'recv' autograd function for this context with the provided
// message id.
void addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id);
// Given an autograd_message_id, retrieve the appropriate send function.
std::shared_ptr<SendRpcBackward> retrieveSendFunction(
    int64_t autograd_message_id);
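// Illustrative sketch (not part of this header; `ctx` and `autogradMessageId`
// are placeholder names): when an outgoing RPC carries autograd metadata, a
// SendRpcBackward node is recorded under that RPC's autograd message id and
// can be looked up again when the backward pass reaches this worker.
//
//   auto sendFn = std::make_shared<SendRpcBackward>();
//   ctx->addSendFunction(sendFn, autogradMessageId);
//   // ... later, when the backward pass for this RPC is triggered:
//   std::shared_ptr<SendRpcBackward> fn =
//       ctx->retrieveSendFunction(autogradMessageId);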
// Return all send functions for this context.
std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>> sendFunctions()
    const;
// Return all recv functions for this context.
std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>> recvFunctions()
    const;
// Adds a future that records an outstanding RPC for this context.
void addOutstandingRpc(const c10::intrusive_ptr<rpc::JitFuture>& jitFuture);
// Returns all gradients.
const c10::Dict<torch::Tensor, torch::Tensor> getGradients() const;
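// Illustrative sketch (not part of this header; `ctx` is a placeholder
// ContextPtr): once a distributed backward pass has populated this context,
// the accumulated gradients can be read back as a tensor-to-tensor dict.
//
//   const auto grads = ctx->getGradients();
//   for (const auto& entry : grads) {
//     const torch::Tensor& variable = entry.key();
//     const torch::Tensor& grad = entry.value();
//     // e.g. hand `grad` to an optimizer update for `variable`.
//   }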
// This function gives a mutable grad reference to the callback.
// If the callback returns true, it means the grad in the context
// needs to be updated.
void runGradCallbackForVariable(
    const torch::autograd::Variable& variable,
    GradCallback&& cb);
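// Illustrative sketch (not part of this header; `ctx` and `variable` are
// placeholders): the callback receives a mutable reference to the gradient
// currently stored for `variable` and returns true if the (possibly modified)
// tensor should be kept by the context.
//
//   ctx->runGradCallbackForVariable(variable, [](torch::Tensor& grad) {
//     grad.div_(2);  // adjust the stored gradient in place
//     return true;   // true => the context keeps the updated gradient
//   });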
DistAutogradContext(const DistAutogradContext&) = delete;
DistAutogradContext& operator=(const DistAutogradContext&) = delete;
DistAutogradContext(DistAutogradContext&&) = delete;
DistAutogradContext& operator=(DistAutogradContext&&) = delete;
// Records the workerID of a node that we sent an RPC to.
// workerIDs are added here when we attach a send function to this autograd
// context.
void addKnownWorkerId(const rpc::worker_id_t workerId);
// Retrieves a set containing the known workerIds for this context.
// These are the different workers that this context has sent RPCs to.
std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;
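// Illustrative sketch (not part of this header; `ctx` and `dstWorkerId` are
// placeholders): every destination worker is remembered so that, when this
// context is released, cleanup can be propagated to all workers that
// participated in the pass.
//
//   ctx->addKnownWorkerId(dstWorkerId);
//   // ... during cleanup:
//   for (rpc::worker_id_t workerId : ctx->getKnownWorkerIds()) {
//     // notify `workerId` to release its copy of this context (details
//     // omitted).
//   }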
private:
friend class BackwardPassCleanupGuard;
friend class DistEngine;
friend class RecvRpcBackward;
friend class DistAccumulateGradCaptureHook;
// Record that we would like to accumulate the provided gradient on the given
// variable.
void accumulateGrad(
    const torch::autograd::Variable& variable,
    const torch::Tensor& grad,
    size_t num_expected_refs);
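// Rough sketch of the accumulation semantics (a simplification, not the
// actual implementation): if no gradient is stored yet for `variable`, the
// incoming `grad` becomes the stored value; otherwise the two are summed.
//
//   torch::Tensor updated = accumulatedGrads_.contains(variable)
//       ? accumulatedGrads_.at(variable) + grad
//       : grad;
//   accumulatedGrads_.insert_or_assign(variable, updated);
//
// `num_expected_refs` is a hint about how many references to `grad` are
// expected, which can allow reusing the incoming tensor in place instead of
// copying it when it is not shared elsewhere (an assumption about intent).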
// Retrieve the GraphTask.
std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();
// Set the appropriate graph task for the backward pass. Can be called only
// once.
void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);
// Resets the graph task to ensure we can run another distributed backward
// pass for the same autograd context.
void resetGraphTask();
// Waits for all outstanding RPCs for this context to finish and then clears
// them from the context. This should be called only once.
c10::intrusive_ptr<c10::ivalue::Future> clearAndWaitForOutstandingRpcsAsync();
void clearOutstandingRpcs();
// Records an event to mark the completion of gradient computation. These
// events later help to properly synchronize gradient consumption in
// getGradients(). We need these events because backward and optimizer.step
// are separate RPC calls and will occur on different CUDA streams. Without
// synchronization, it is possible that gradients are consumed before they
// are ready.
void recordGradEvent(c10::Device device);
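// Rough sketch of the event-based synchronization described above (an
// assumption about the flow, not a verbatim copy of the implementation):
//
//   // Producer side: after a gradient has been written on `device`.
//   c10::Event event(device.type());
//   event.record(impl_.getStream(device));  // mark "gradient is ready"
//   gradReadyEvents_.emplace(device, std::move(event));
//
//   // Consumer side, e.g. in getGradients(): wait on the recorded events so
//   // gradients are not read before the producer streams finish.
//   for (auto& entry : gradReadyEvents_) {
//     entry.second.block(impl_.getStream(entry.first));
//   }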
const int64_t contextId_;
// Set containing known worker IDs, used in cleaning up autograd context.
// Whenever a sendRpcBackward is attached to the autograd graph for this
// context, the destination is added here.
std::unordered_set<rpc::worker_id_t> knownWorkerIds_;
// Map from autograd_message_id to appropriate 'send' autograd function.
std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
    sendAutogradFunctions_;
// Map from autograd_message_id to appropriate 'recv' autograd function.
std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
    recvAutogradFunctions_;
// Gradients accumulated in this context so far. The key is the variable on
// which the gradient needs to be accumulated and the value is the gradient
// that needs to be accumulated on that variable.
c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;
// See the comment for recordGradEvent(c10::Device device).
std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
const c10::impl::VirtualGuardImpl impl_;
// The autograd GraphTask for the backward pass on this node for this context.
std::shared_ptr<torch::autograd::GraphTask> graphTask_;
// List of futures for RPCs initiated by this node to propagate gradients to
// other nodes. The distributed autograd engine on this node can return
// successfully only if all these futures are done and are successful.
std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;
// Lock to protect concurrent modification of the context.
mutable std::mutex lock_;
};
using ContextPtr = std::shared_ptr<DistAutogradContext>;
// This class stores a shared_ptr to a DistAutogradContext instance in a
// thread-local variable. The instance is supplied by the call site; the class
// does not create or select the current context itself, it is merely a
// utility for carrying one on the current thread.
class TORCH_API ThreadLocalDistAutogradContext {
public:
// Stores 'new_context' in the thread-local variable maintained by this class.
explicit ThreadLocalDistAutogradContext(ContextPtr&& new_context);
~ThreadLocalDistAutogradContext();
// Retrieve the stored DistAutogradContext instance.
static ContextPtr getContextPtr();
private:
ContextPtr prev_context_ptr_;
};
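// Illustrative usage sketch (not part of this header; `ctx` is a placeholder
// ContextPtr): the class acts like an RAII guard that installs a context for
// the current thread and restores the previous one when it is destroyed.
//
//   {
//     ThreadLocalDistAutogradContext guard(std::move(ctx));
//     // On this thread, code running inside this scope can recover the
//     // context via ThreadLocalDistAutogradContext::getContextPtr().
//   }
//   // The previously installed context (possibly nullptr) is active again.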
} // namespace autograd
} // namespace distributed
} // namespace torch