1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <cstddef>
#include <utility>
#include <vector>
namespace c10d {
// Broadcast many tensors to all processes in the process group.
//
// Only the declaration lives here; no definition is visible in this header.
//
//   process_group: the process group whose processes receive the tensors.
//   tensors:       the tensors to broadcast.
//   buffer_size:   NOTE(review): presumably an upper bound on the size of each
//                  coalesced buffer -- confirm against the definition.
//   rank:          defaults to 0. NOTE(review): presumably the rank that acts
//                  as the broadcast source -- confirm against the definition.
TORCH_API void broadcast_coalesced(
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
at::TensorList tensors,
size_t buffer_size,
int rank = 0);
// This class passes bucket contents tensor to DDP communication hook.
class TORCH_API GradBucket {
public:
explicit GradBucket(
size_t index,
size_t bucket_count,
const at::Tensor& tensor,
const std::vector<size_t>& offsets,
const std::vector<size_t>& lengths,
const std::vector<c10::IntArrayRef>& sizes_vec,
const std::vector<at::Tensor>& parameters)
: index_(index),
bucket_count_(bucket_count),
buffer_(tensor),
offsets_(offsets),
lengths_(lengths),
sizes_vec_(sizes_vec),
parameters_(parameters) {}
// Returns the index of the bucket, which is unique across all the buckets.
size_t getIndex() const {
return index_;
}
const at::Tensor& getBuffer() const {
return buffer_;
}
// Returns a mutable buffer compared with the above method.
at::Tensor& getBufferRef() {
return buffer_;
}
// Overwrites the buffer at a specific index.
void setBuffer(at::Tensor& buffer) {
buffer_ = buffer;
}
// Each tensor in the list that getGradients corresponds to a
// parameter.
std::vector<at::Tensor> getGradients() const;
// Returns model parameters belonging to this bucket. They are returned in the
// same order as gradient tensors via getGradients(). For example,
// getParameters[i] will have its gradient stored in
// getGradients[i]
const std::vector<at::Tensor> getParameters() const {
return parameters_;
}
// Returns whther this bucket is the last bucket to allreduce in an iteration.
bool isLast() const {
return index_ == bucket_count_ - 1;
}
private:
size_t index_;
size_t bucket_count_;
at::Tensor buffer_;
// Per-variable info in buffer_.
std::vector<size_t> offsets_;
std::vector<size_t> lengths_;
std::vector<c10::IntArrayRef> sizes_vec_;
// Model parameters for this bucket.
const std::vector<at::Tensor> parameters_;
};
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing 1) a `runHook` method that communicates gradients
// asynchronously, and 2) a `parseHookResult` method that converts the hook
// result into a tensor.
class TORCH_API CommHookInterface {
public:
// Virtual so that derived hooks are destroyed correctly through a base
// pointer.
virtual ~CommHookInterface() = default;
// Passes the input grad bucket to the registered communication hook.
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
// The bucket is taken by non-const reference, so an implementation is allowed
// to modify it.
virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
GradBucket& bucket) = 0;
// Returns the resulting tensor once the communication hook result is
// ready. The resulting tensor will then be copied to the grads of
// individual parameters.
// `result` is the hook's result wrapped in an IValue.
virtual at::Tensor parseHookResult(
const c10::IValue& result) = 0;
};
namespace detail {
// Converts the IValue result of a C++ communication hook into a tensor.
// This helper function is called both by CppCommHookInterface below (from its
// parseHookResult override) and inside reducer.
// Declared here only; the definition is not in this header.
at::Tensor parseCppCommHookResult(const c10::IValue& result);
} // namespace detail
// Convenience base for communication hooks implemented in C++.
//
// It stores a state object of type T and implements parseHookResult, so a
// concrete hook only has to implement the inherited pure-virtual runHook,
// which may make use of state_.
template <typename T>
class CppCommHookInterface : public CommHookInterface {
 public:
  // Stores `state` in the hook.
  //
  // The state is taken by value and moved into the member: an rvalue argument
  // is moved rather than copied, and move-only state types are accepted. An
  // lvalue argument is copied once into the parameter and then moved.
  explicit CppCommHookInterface(T state) : state_(std::move(state)) {}

  ~CppCommHookInterface() override = default;

  // Converts the hook result into a tensor by delegating to
  // detail::parseCppCommHookResult.
  at::Tensor parseHookResult(const c10::IValue& result) override {
    return detail::parseCppCommHookResult(result);
  }

 protected:
  // Hook-specific state, accessible to derived classes.
  T state_;
};
} // namespace c10d
|