1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
// Temporary solution for RRef operations.
// TODO: Remove all these messages and use rpc + registered functions instead.
// Common base for the RRef-related RPC commands declared in this header.
// It stores the RRefId and the MessageType handed to the constructor; both
// members are const and therefore fixed for the lifetime of the object.
class TORCH_API RRefMessageBase : public RpcCommandBase {
 public:
  // Copies `rrefId` and records `type`.
  RRefMessageBase(const RRefId& rrefId, MessageType type)
      : rrefId_(rrefId), type_(type) {}

  ~RRefMessageBase() override = default;

  // Accessor for the stored RRefId (defined out of line).
  const RRefId& rrefId();

 protected:
  // Id of the RRef this message refers to.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const RRefId rrefId_;

  // Message type tag supplied by the concrete subclass.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const MessageType type_;
};
// Base for RRef messages that carry a ForkId in addition to the RRefId held
// by RRefMessageBase.
class TORCH_API ForkMessageBase : public RRefMessageBase {
 public:
  // Forwards `rrefId` and `type` to RRefMessageBase and keeps a copy of
  // `forkId`.
  ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type)
      : RRefMessageBase(rrefId, type), forkId_(forkId) {}

  // Accessor for the stored ForkId (defined out of line).
  const ForkId& forkId();

  // Produces the Message for this command. The `&&` ref-qualifier means it
  // can only be called on an rvalue.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

  // Returns an (RRefId, ForkId) pair obtained from `message`. `type` is
  // presumably the MessageType the caller expects `message` to have --
  // confirm against the out-of-line definition.
  static std::pair<RRefId, ForkId> fromMessage(
      const Message& message,
      MessageType type);

 protected:
  // Id of the fork this message refers to.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const ForkId forkId_;
};
// A UserRRef sends this message to the owner in order to fetch the remote
// RRef value. The message is tagged MessageType::SCRIPT_RREF_FETCH_CALL.
class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase {
 public:
  // Stores `fromWorkerId` together with the id of the RRef being fetched.
  ScriptRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId)
      : RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL),
        fromWorkerId_(fromWorkerId) {}

  // Worker id that was passed to the constructor.
  worker_id_t fromWorkerId() const {
    return fromWorkerId_;
  }

  // Produces the Message for this command; callable on rvalues only.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

  // Builds a ScriptRRefFetchCall from `message` (defined out of line).
  static std::unique_ptr<ScriptRRefFetchCall> fromMessage(
      const Message& message);

 private:
  const worker_id_t fromWorkerId_;
};
// Fetch-call message tagged MessageType::PYTHON_RREF_FETCH_CALL. It has the
// same shape as ScriptRRefFetchCall: a worker id plus the id of the RRef.
class TORCH_API PythonRRefFetchCall final : public RRefMessageBase {
 public:
  // Stores `fromWorkerId` together with the id of the RRef being fetched.
  PythonRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId)
      : RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL),
        fromWorkerId_(fromWorkerId) {}

  // Worker id that was passed to the constructor. Mirrors
  // ScriptRRefFetchCall::fromWorkerId() so that the stored value is readable
  // on instances of this class as well.
  worker_id_t fromWorkerId() const {
    return fromWorkerId_;
  }

  // Produces the Message for this command; callable on rvalues only.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

  // Builds a PythonRRefFetchCall from `message` (defined out of line).
  static std::unique_ptr<PythonRRefFetchCall> fromMessage(
      const Message& message);

 private:
  const worker_id_t fromWorkerId_;
};
// An OwnerRRef uses this message to send the RRef value to a remote
// UserRRef. Concrete subclasses choose the MessageType.
class TORCH_API RRefFetchRet : public RpcCommandBase {
 public:
  // Takes `values` by value and moves it into the object; `type` is the
  // message type tag chosen by the subclass.
  RRefFetchRet(std::vector<at::IValue> values, MessageType type)
      : values_(std::move(values)), type_(type) {}

  // Accessor for the stored IValues (defined out of line).
  const std::vector<at::IValue>& values();

  // Produces the Message for this command; callable on rvalues only.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

 private:
  std::vector<at::IValue> values_;
  const MessageType type_;
};
// RRefFetchRet tagged MessageType::SCRIPT_RREF_FETCH_RET.
class TORCH_API ScriptRRefFetchRet final : public RRefFetchRet {
 public:
  // Moves `values` into the RRefFetchRet base.
  explicit ScriptRRefFetchRet(std::vector<at::IValue> values)
      : RRefFetchRet(std::move(values), MessageType::SCRIPT_RREF_FETCH_RET) {}

  // Builds a ScriptRRefFetchRet from `message` (defined out of line).
  static std::unique_ptr<ScriptRRefFetchRet> fromMessage(
      const Message& message);
};
// RRefFetchRet tagged MessageType::PYTHON_RREF_FETCH_RET.
class TORCH_API PythonRRefFetchRet final : public RRefFetchRet {
 public:
  // Moves `values` into the RRefFetchRet base.
  explicit PythonRRefFetchRet(std::vector<at::IValue> values)
      : RRefFetchRet(std::move(values), MessageType::PYTHON_RREF_FETCH_RET) {}

  // Builds a PythonRRefFetchRet from `message` (defined out of line).
  static std::unique_ptr<PythonRRefFetchRet> fromMessage(
      const Message& message);
};
// A UserRRef (regardless of whether it is the creator or not) uses this
// message to notify the OwnerRRef on delete. The message is tagged
// MessageType::RREF_USER_DELETE.
class TORCH_API RRefUserDelete final : public ForkMessageBase {
 public:
  // Forwards both ids to ForkMessageBase.
  RRefUserDelete(const RRefId& rrefId, const ForkId& forkId)
      : ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {}

  // Builds an RRefUserDelete from `message` (defined out of line).
  static std::unique_ptr<RRefUserDelete> fromMessage(const Message& message);
};
// Fork message tagged MessageType::REMOTE_RET; carries an RRefId and a
// ForkId via ForkMessageBase.
class TORCH_API RemoteRet final : public ForkMessageBase {
 public:
  // Forwards both ids to ForkMessageBase.
  RemoteRet(const RRefId& rrefId, const ForkId& forkId)
      : ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {}

  // Builds a RemoteRet from `message` (defined out of line).
  static std::unique_ptr<RemoteRet> fromMessage(const Message& message);
};
// A child RRef sends this message to its parent to signal that the child
// has been confirmed by the owner. Only a ForkId is carried.
class TORCH_API RRefChildAccept final : public RpcCommandBase {
 public:
  // Keeps a copy of `forkId`.
  explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {}

  // Accessor for the stored ForkId (defined out of line).
  const ForkId& forkId() const;

  // Produces the Message for this command; callable on rvalues only.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

  // Builds an RRefChildAccept from `message` (defined out of line).
  static std::unique_ptr<RRefChildAccept> fromMessage(const Message& message);

 private:
  const ForkId forkId_;
};
// A child RRef sends this message to the owner as a fork request. The
// message is tagged MessageType::RREF_FORK_REQUEST.
class TORCH_API RRefForkRequest final : public ForkMessageBase {
 public:
  // Forwards both ids to ForkMessageBase.
  RRefForkRequest(const RRefId& rrefId, const ForkId& forkId)
      : ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {}

  // Builds an RRefForkRequest from `message` (defined out of line).
  static std::unique_ptr<RRefForkRequest> fromMessage(const Message& message);
};
// RPC command that declares no data members of its own.
class TORCH_API RRefAck final : public RpcCommandBase {
 public:
  RRefAck() = default;

  // Produces the Message for this command; callable on rvalues only.
  c10::intrusive_ptr<Message> toMessageImpl() && override;

  // Builds an RRefAck from `message` (defined out of line).
  static std::unique_ptr<RRefAck> fromMessage(const Message& message);
};
} // namespace rpc
} // namespace distributed
} // namespace torch
|