#pragma once
#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
namespace torch {
namespace distributed {
namespace autograd {
// Represents an RPC that includes autograd information. This class wraps
// another `RpcCommandBase` object which represents the actual RPC, and
// carries additional autograd information associated with that RPC.
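//
// Illustrative send-side sketch (not part of this header): the worker id,
// context id, and message id below are made-up values, and `wrappedMessage`
// is assumed to be the already-serialized message of the RPC being wrapped.
// toMessage() is inherited from RpcCommandBase and invokes toMessageImpl().
//
//   AutogradMetadata autogradMetadata(
//       /*autogradContextId=*/1, /*autogradMessageId=*/2);
//   RpcWithAutograd rpcWithAutograd(
//       /*fromWorkerId=*/0,
//       rpc::MessageType::FORWARD_AUTOGRAD_REQ,
//       autogradMetadata,
//       std::move(wrappedMessage));
//   c10::intrusive_ptr<rpc::Message> outgoing =
//       std::move(rpcWithAutograd).toMessage();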
class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
 public:
  // Used when we are sending an RPC over the wire.
  RpcWithAutograd(
      rpc::worker_id_t fromWorkerId,
      rpc::MessageType messageType,
      const AutogradMetadata& autogradMetadata,
      c10::intrusive_ptr<rpc::Message> wrappedMessage,
      rpc::DeviceMap deviceMap = {});

  // Used when receiving an RPC over the wire.
  RpcWithAutograd(
      rpc::worker_id_t fromWorkerId,
      rpc::MessageType messageType,
      const AutogradMetadata& autogradMetadata,
      std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
      rpc::MessageType wrappedMessageType,
      std::vector<torch::Tensor> tensors,
      rpc::DeviceMap deviceMap = {});

  c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;

  static std::unique_ptr<RpcWithAutograd> fromMessage(
      const rpc::Message& message);
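
  // Illustrative receive-side sketch (not part of this header; `incoming` is
  // assumed to be a FORWARD_AUTOGRAD_REQ message received from a peer):
  //
  //   std::unique_ptr<RpcWithAutograd> rpcWithAutograd =
  //       RpcWithAutograd::fromMessage(incoming);
  //   const AutogradMetadata& meta = rpcWithAutograd->autogradMetadata();
  //   std::vector<torch::Tensor>& tensors = rpcWithAutograd->tensors();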

  // Retrieves the tensors that are part of this RPC and need to be
  // considered for autograd computations.
  std::vector<torch::Tensor>& tensors();

  const AutogradMetadata& autogradMetadata() const;

  RpcCommandBase& wrappedRpc();

  void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);

  std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;

  // Message type of the wrapped RPC.
  rpc::MessageType wrappedMessageType() const;

  // Retrieves the worker id from which the RPC originated.
  rpc::worker_id_t fromWorkerId() const;

  // Retrieves the device map.
  const rpc::DeviceMap& deviceMap();

 private:
  // WorkerId from which this RPC originated. This is necessary for knowing
  // which worker we need to contact during the backward pass.
  rpc::worker_id_t fromWorkerId_;

  // Message type for this call.
  rpc::MessageType messageType_;

  AutogradMetadata autogradMetadata_;

  // Since wrappedMessage_ is destructively constructed from wrappedRpc_,
  // only one of the two is valid at any point; they serve different purposes:
  // wrappedRpc_ is used on the receive side and wrappedMessage_ on the send
  // side. When a receive-side RpcWithAutograd is constructed via fromMessage,
  // wrappedRpc_ is valid; when a send-side RpcWithAutograd is constructed
  // (before toMessage), it is nullptr.
  std::unique_ptr<RpcCommandBase> wrappedRpc_;

  // Serialized message representing wrappedRpc_. Used mostly as a cache to
  // avoid serializing the request twice. When a receive-side RpcWithAutograd
  // is constructed via fromMessage, it is nullptr; when a send-side
  // RpcWithAutograd is constructed (before toMessage), it is valid.
  c10::intrusive_ptr<rpc::Message> wrappedMessage_;
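
  // Illustrative summary of the invariant described above (not enforced by
  // the type system):
  //
  //   send side (before toMessage):      wrappedRpc_ == nullptr,
  //                                      wrappedMessage_ != nullptr
  //   receive side (after fromMessage):  wrappedRpc_ != nullptr,
  //                                      wrappedMessage_ == nullptr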

  // Message type of wrappedMessage_. Stored separately since wrappedMessage_
  // is not always guaranteed to be populated.
  rpc::MessageType wrappedMessageType_;

  // Tensors that are part of the wrapped RPC and need to be considered for
  // autograd.
  std::vector<torch::Tensor> tensors_;

  // Device mapping for tensors that are sent across an RPC to another node.
  rpc::DeviceMap deviceMap_;
};
} // namespace autograd
} // namespace distributed
} // namespace torch