#include <torch/csrc/autograd/functions/comm.h>
#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <cstddef>
#include <memory>
#include <vector>
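// Differentiable Scatter and Gather nodes used by the CUDA comm helpers:
// Scatter splits one variable into per-device chunks, Gather concatenates
// per-device variables onto a destination device. Each node installs the
// other as its grad_fn, so scatter and gather are each other's backward.
//
// A minimal usage sketch (for illustration only; `input` is a hypothetical
// Variable and real callers go through the comm bindings):
//
//   torch::autograd::Scatter scatter(
//       {at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)},
//       /*chunk_sizes=*/c10::nullopt,
//       /*dim=*/0,
//       /*streams=*/c10::nullopt,
//       /*unsqueeze_scalars=*/false);
//   auto chunks = scatter.apply({input}); // one chunk per device
//
//   torch::autograd::Gather gather(at::Device(at::kCUDA, 0), /*dim=*/0);
//   auto gathered = gather.apply(std::move(chunks)); // concatenated on cuda:0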
namespace torch {
namespace autograd {
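// Scatter distributes chunks of a single input variable to `devices_` along
// `dim_`, optionally copying each chunk on a caller-provided CUDA stream.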
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Scatter::Scatter(
std::vector<at::Device> devices,
// NOLINTNEXTLINE(modernize-pass-by-value)
const c10::optional<std::vector<int64_t>>& chunk_sizes,
int64_t dim,
// NOLINTNEXTLINE(modernize-pass-by-value)
const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
streams,
bool unsqueeze_scalars)
: devices_(std::move(devices)),
chunk_sizes_(chunk_sizes),
dim_(dim),
streams_(streams),
unsqueeze_scalars_(unsqueeze_scalars) {}
Scatter::~Scatter() = default;
variable_list Scatter::apply(variable_list&& inputs) {
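// Scatter takes exactly one input variable and returns one chunk per device.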
AT_ASSERT(inputs.size() == 1);
auto& input = inputs.front();
std::shared_ptr<Node> grad_fn;
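// The backward of a scatter is a gather back onto the input's original device.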
if (compute_requires_grad(input)) {
grad_fn =
std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
grad_fn->set_next_edges(collect_next_edges(input));
}
auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
return device.index();
});
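// Split the input along `dim_` (into `chunk_sizes_` if provided, otherwise
// into roughly equal chunks) and copy each chunk to its target device.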
auto tensors = torch::cuda::scatter(
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(input),
device_indices,
chunk_sizes_,
dim_,
streams_);
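// Wrap the chunks as variables. When this node is the backward of a Gather
// that unsqueezed 0-dim inputs, take element 0 of each single-element chunk
// so the gradients are returned as scalars again.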
std::vector<Variable> variables;
variables.reserve(tensors.size());
for (auto& tensor : tensors) {
AT_ASSERT(tensor.defined());
if (unsqueeze_scalars_) {
AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
variables.push_back(tensor[0]);
} else {
variables.push_back(std::move(tensor));
}
}
if (grad_fn) {
set_history(variables, grad_fn);
}
return variables;
}
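// Gather concatenates CUDA variables from several devices along `dim_` onto
// `destination_device_`, which may also be the CPU.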
Gather::Gather(const at::Device& destination_device, int64_t dim)
: destination_device_(destination_device), dim_(dim) {}
Gather::~Gather() = default;
variable_list Gather::apply(variable_list&& inputs) {
bool all_are_zero_dim = true;
for (const auto& input : inputs) {
TORCH_CHECK(
input.is_cuda(),
"All inputs to Gather must be CUDA tensors, got ",
input.toString());
if (input.dim() > 0) {
all_are_zero_dim = false;
}
}
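// If every input is a scalar and we were asked to gather along dimension 0,
// unsqueeze each input to a 1-element vector and return a 1-D result instead.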
const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
if (unsqueeze_scalars) {
TORCH_WARN(
"Was asked to gather along dimension 0, but all "
"input tensors were scalars; will instead unsqueeze "
"and return a vector.");
}
std::shared_ptr<Node> grad_fn;
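// The backward of a gather is a scatter back to the source devices, using the
// original per-input sizes along `dim_` as the chunk sizes.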
// Compute this before moving the variables out of `inputs`.
if (compute_requires_grad(inputs)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::Device> source_devices;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<int64_t> input_sizes;
for (auto& input : inputs) {
source_devices.push_back(input.device());
input_sizes.push_back(input.size(dim_));
}
grad_fn = std::make_shared<Scatter>(
std::move(source_devices),
std::move(input_sizes),
dim_,
/*streams=*/c10::nullopt,
/*unsqueeze_scalars=*/unsqueeze_scalars);
grad_fn->set_next_edges(collect_next_edges(inputs));
}
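// Materialize the list of tensors to gather; 0-dim inputs are viewed as
// 1-element vectors so they can be concatenated along dimension 0.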
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::Tensor> tensors;
tensors.reserve(inputs.size());
for (auto& variable : inputs) {
if (unsqueeze_scalars) {
tensors.push_back(variable.view(1));
} else {
tensors.push_back(std::move(variable));
}
}
// Disable autograd during the actual computation: torch::cuda::gather does
// not return a view and does not modify its inputs in place, so no extra
// bookkeeping is needed here.
at::Tensor variable;
{
at::AutoDispatchBelowAutograd mode;
// Special case for torch::cuda::gather: a destination index of -1 means
// "gather onto the CPU".
const auto destination_index =
destination_device_.is_cpu() ? -1 : destination_device_.index();
variable = torch::cuda::gather(tensors, dim_, destination_index);
}
if (grad_fn) {
set_history(variable, grad_fn);
}
return {variable};
}
} // namespace autograd
} // namespace torch