#include <torch/csrc/autograd/functions/comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <memory>
#include <vector>

namespace torch::autograd {
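
// Scatter splits one input tensor across a set of CUDA devices; Gather
// concatenates per-device tensors onto a single destination device. Each
// node installs the other as its grad_fn so that gradients are routed back
// across devices during backward.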

Scatter::Scatter(
    std::vector<at::Device> devices,
    std::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    std::optional<std::vector<std::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}

Scatter::~Scatter() = default;
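
// apply() chunks the input along `dim_` (using `chunk_sizes_` if given),
// copies each chunk to the matching entry of `devices_`, and, when the input
// requires grad, records a Gather node so that backward collects the
// gradients back onto the input's original device.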
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

  // torch::cuda::scatter expects device indices rather than at::Device objects.
  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
  auto tensors =
      torch::cuda::scatter(input, device_indices, chunk_sizes_, dim_, streams_);

  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      // Undo the unsqueeze that Gather performed in the forward pass: each
      // chunk is a one-element vector standing in for a scalar.
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  if (grad_fn) {
    set_history(variables, grad_fn);
  }

  return variables;
}

Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

Gather::~Gather() = default;
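
// apply() concatenates the per-device inputs along `dim_` onto
// `destination_device_` (the CPU is allowed as a destination). If the inputs
// require grad, a Scatter node is recorded so that backward splits the
// gradient back to the source devices with the original chunk sizes.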
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  std::shared_ptr<Node> grad_fn;
  // Compute this before moving variables out of `inputs`.
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/std::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable autograd during the actual computation. torch::cuda::gather does
  // not return a view or change anything in place, so no extra logic is
  // needed here.
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // Special case for torch::cuda::gather: a destination index of -1 means
    // "gather onto the CPU".
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }

  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}

} // namespace torch::autograd