#pragma once
#ifdef USE_TENSORPIPE
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
namespace torch {
namespace distributed {
namespace rpc {
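
// Backend options for the FaultyTensorPipeAgent. On top of the regular
// TensorPipeRpcBackendOptions, these specify which message types should be
// failed or delayed, and how many times a send is failed before it is
// allowed to succeed.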
struct TORCH_API FaultyTensorPipeRpcBackendOptions
    : public TensorPipeRpcBackendOptions {
  FaultyTensorPipeRpcBackendOptions(
      int num_worker_threads,
      float rpc_timeout,
      std::string init_method,
      std::vector<std::string> messages_to_fail,
      std::unordered_map<std::string, float> messages_to_delay,
      int num_fail_sends = 0)
      : TensorPipeRpcBackendOptions(
            num_worker_threads,
            optional<std::vector<std::string>>(),
            optional<std::vector<std::string>>(),
            rpc_timeout,
            std::move(init_method)),
        messagesToFail(std::move(messages_to_fail)),
        messagesToDelay(std::move(messages_to_delay)),
        numFailSends(num_fail_sends) {
    TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
  }

  std::vector<std::string> messagesToFail;
  std::unordered_map<std::string, float> messagesToDelay;
  int numFailSends;
};
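
// A minimal sketch of constructing these options (illustrative only; the
// concrete values and message-type strings below are assumptions, not
// defaults):
//
//   FaultyTensorPipeRpcBackendOptions opts(
//       /*num_worker_threads=*/8,
//       /*rpc_timeout=*/60.0f,
//       /*init_method=*/"tcp://127.0.0.1:29500",
//       /*messages_to_fail=*/{"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
//       /*messages_to_delay=*/{{"PYTHON_CALL", 1.5f}},
//       /*num_fail_sends=*/3);

// An agent derived from TensorPipeAgent that intentionally fails or delays
// sends for the message types configured in
// FaultyTensorPipeRpcBackendOptions. It is used by the Python tests to
// simulate network failures and delays.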
class TORCH_API FaultyTensorPipeAgent : public TensorPipeAgent {
 public:
  FaultyTensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      int worldSize,
      FaultyTensorPipeRpcBackendOptions opts,
      std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      std::vector<c10::Device> devices,
      std::unique_ptr<RequestCallback> callback);

  // Faulty send function that fails or delays messages based on the
  // configured options.
  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) override;

  // Adds the configured delay before writing to the pipe.
  void pipeWrite(
      const std::shared_ptr<tensorpipe::Pipe>& pipe,
      c10::intrusive_ptr<Message> rpcMessage,
      std::vector<c10::Device>&& devices,
      std::vector<c10::Stream> streams,
      std::function<void(const tensorpipe::Error&)> fn) noexcept override;
 protected:
  // Checks messageTypesToFail_ to determine whether the given message type
  // should use the faulty send.
  bool shouldFailMessage(MessageType type) const;

 private:
  // Parses the list of strings passed in by the Python tests and resolves
  // the MessageTypes that must use the faulty send.
  std::vector<MessageType> parseMessagesToFailInput(
      const std::vector<std::string>& messagesToFail) const;

  // Returns the amount of time in seconds to delay the send of the given
  // message type.
  float getDelayForMessage(MessageType type) const;

  // Parses the message types for which we should inject the configured
  // delays.
  std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
      const std::unordered_map<std::string, float>& messageTypesToDelay) const;

  // Number of sends to intentionally fail before allowing one to succeed.
  const int numFailSends_;

  // Vector of the MessageTypes that we must use the faulty send for. This is
  // parsed based on a list of strings passed in by the Python tests.
  const std::vector<MessageType> messageTypesToFail_;

  // Mapping of message types to the amount (in seconds) by which their send
  // should be delayed in the ::send() function.
  std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;

  // Map to track the number of sends we've failed for each RPC.
  std::unordered_map<std::string, int> failMessageCountMap_;

  // Mutex to guard failMessageCountMap_.
  std::mutex failMapMutex_;
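
  // Maps a message type name, as passed in by the Python tests, to the
  // corresponding MessageType enum value.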
  MessageType messageStringToType(const std::string& messageString) const;
};
} // namespace rpc
} // namespace distributed
} // namespace torch
#endif // USE_TENSORPIPE