1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508
|
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/utils/tensor_flatten.h>
#ifdef USE_NCCL
#include <torch/csrc/cuda/nccl.h>
#endif
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>
#include <cstddef>
#include <vector>
namespace torch {
namespace cuda {
using namespace at;
using namespace torch::autograd;
// Tracks whether every type id passed to show() has been identical. Some
// operations can be done more cheaply when all tensors share one type; keeping
// that bookkeeping in this helper keeps the calling loops readable.
struct unique_type_checker {
  // Records `type_id`. The first call remembers it; any later differing id
  // clears `unique`, which then stays false for the rest of the object's life.
  void show(size_t type_id) {
    if (unique) {
      if (!type_id_.has_value()) {
        type_id_ = type_id;
      }
      unique = (*type_id_ == type_id);
    }
  }
  // First type id observed, if any.
  c10::optional<size_t> type_id_;
  // True while every id passed to show() has matched the first one.
  bool unique = true;
};
// ***************** Broadcast *******************
//
// Broadcast a source tensor (CPU or CUDA) to a list of CUDA devices, or CUDA
// tensors on one or more devices.
// Copies `tensor` into every tensor of `out_tensors` and returns
// `out_tensors`.
//
// No checks: this helper performs no validation of devices or shapes; the
// callers in this file are expected to have done that already.
//
// When built with USE_NCCL and nccl::is_available() accepts the combined list
// (source first, then all outputs), a single nccl::broadcast call is issued.
// Otherwise every output receives a non-blocking copy_ from the source.
static inline std::vector<Tensor>& _broadcast_out_impl(
const Tensor& tensor,
std::vector<Tensor>& out_tensors) {
#ifdef USE_NCCL
// NCCL expects the root (source) tensor at index 0 followed by the receivers.
std::vector<Tensor> nccl_list;
nccl_list.reserve(out_tensors.size() + 1);
nccl_list.push_back(tensor);
for (auto& out_tensor : out_tensors) {
nccl_list.push_back(out_tensor);
}
if (nccl::is_available(nccl_list)) {
nccl::broadcast(nccl_list);
} else {
#else
// Without NCCL, open a plain scope so the fallback body below is shared with
// the NCCL build's `else` branch above.
{
#endif
// Fallback: one asynchronous copy per destination.
for (auto& out_tensor : out_tensors) {
out_tensor.copy_(tensor, /*non_blocking=*/true);
}
}
return out_tensors;
}
std::vector<Tensor>& broadcast_out(
const Tensor& tensor,
std::vector<Tensor>& out_tensors) {
for (const auto i : c10::irange(out_tensors.size())) {
TORCH_CHECK(
out_tensors[i].is_cuda(),
"Expected all output tensors to be CUDA tensors, but output tensor at index ",
i,
" has device '",
out_tensors[i].device(),
"'");
TORCH_CHECK(
out_tensors[i].sizes() == tensor.sizes(),
"Expected all output tensors to have same shape as the source tensor ",
tensor.sizes(),
", but output tensor at index ",
i,
" has shape ",
out_tensors[i].sizes());
}
return _broadcast_out_impl(tensor, out_tensors);
}
std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Tensor> diff_device_dst_tensors;
diff_device_dst_tensors.reserve(devices.size());
for (auto device : devices) {
TORCH_CHECK(
device >= 0, "Expected non-negative device index, but got ", device);
if (device != tensor.get_device()) {
diff_device_dst_tensors.push_back(at::empty(
tensor.sizes(),
tensor.options().device(
at::Device(DeviceType::CUDA, device)))); // preserve memory format
}
}
_broadcast_out_impl(tensor, diff_device_dst_tensors);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Tensor> dst_tensors;
dst_tensors.reserve(devices.size());
auto it = diff_device_dst_tensors.begin();
for (auto device : devices) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (device != tensor.get_device()) {
dst_tensors.push_back(*it++);
} else {
dst_tensors.push_back(tensor);
}
}
TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
return dst_tensors;
}
// NOTE [ Version Counter in comm.*_coalesced ]
//
// broadcast_coalesced
// ~~~~~~~~~~~~~~~~~~~
//
// In broadcast_coalesced, multiple variables may be coalesced into a single
// large one, broadcast to other devices, and then split according to the
// original shapes.
//
// When splitting, the view operations will make all Variables broadcast
// together to share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and all
// these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed in
// backward, autograd engine will complain.
//
// We thus re-wrap these Variables after broadcasting (i.e., effectively doing
// what is equivalent to .data in Python), and give them individual version
// counters.
//
// NB: Just calling detach() on the variables is not sufficient
//
// NB: For `device[0]` in broadcast_coalesced, the input Variables are always
// returned as-is, so **do not** re-wrap them.
//
// reduce_add_coalesced
// ~~~~~~~~~~~~~~~~~~~~
//
// The same applies to reduce_add_coalesced when the outputs are newly created
// Variables.
// Broadcasts `tensors` (all on devices[0]) to every device in `devices`,
// coalescing them into flat buffers of at most `buffer_size` per chunk to
// reduce the number of transfers.
//
// Returns one list per device. outputs[0] is `tensors` itself (returned
// as-is). outputs[i] for i > 0 holds that device's copies, in the same order
// as `tensors`. See NOTE [ Version Counter in comm.*_coalesced ] for why the
// copies are re-wrapped with make_variable.
//
// Raises if `devices` is empty or any input is not on devices[0].
tensor_list2d broadcast_coalesced(
    TensorList tensors,
    IntArrayRef devices,
    size_t buffer_size) {
  // devices[0] is the source device and is indexed unconditionally below.
  TORCH_CHECK(
      !devices.empty(), "Expected at least one device to broadcast to");
  TORCH_CHECK(
      std::all_of(
          tensors.begin(),
          tensors.end(),
          [&](const at::Tensor& t) { return t.get_device() == devices[0]; }),
      "All tensors must be on devices[0]: ",
      devices[0]);
#ifdef USE_NCCL
  buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size);
#endif
  tensor_list2d outputs(devices.size());
  outputs[0] = tensors.vec();
  for (auto& o : outputs)
    o.reserve(tensors.size());

  unique_type_checker type_checker;
  at::cuda::CUDAGuard device_guard(devices[0]);
  for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
    type_checker.show(chunk.type_id());
    if (chunk.options().is_sparse()) {
      // Sparse chunks are flattened into separate index and value buffers,
      // each broadcast on its own.
      auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
      auto broadcast_indices = broadcast(flat_tuple.first, devices);
      auto broadcast_values = broadcast(flat_tuple.second, devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        auto& inds = broadcast_indices[i];
        auto& vals = broadcast_values[i];
        for (const auto& var :
             utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.push_back(make_variable(var.tensor_data(), false));
        }
      }
    } else {
      auto results =
          broadcast(utils::flatten_dense_tensors(chunk.tensors), devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        for (auto& var :
             utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.push_back(make_variable(var.tensor_data(), false));
        }
      }
    }
  }

  // take_tensors groups by type, so with mixed types the per-device outputs
  // may be out of input order. With a single type we can skip the expensive
  // reordering.
  if (!type_checker.unique) {
    for (auto& o : outputs)
      utils::reorder_tensors_like(o, tensors);
  }
  return outputs;
}
// ***************** Scatter *******************
//
// Scatter a source tensor (CPU or CUDA) to a list of CUDA tensors on one or
// more devices.
// Copies consecutive slices of `tensor` along `dim` into `out_tensors`.
//
// Requirements:
//  - `tensor` has at least one dimension.
//  - Each output is a CUDA tensor with the same number of dimensions as
//    `tensor`.
//  - Each output's sizes match `tensor` except along `dim`.
//  - The output sizes along `dim` sum to tensor.size(dim).
//
// If `streams` is provided, a set entry i must belong to out_tensors[i]'s
// device and is made current (via the stream guard) for that copy.
// Returns `out_tensors`.
std::vector<at::Tensor>& scatter_out(
    const at::Tensor& tensor,
    std::vector<at::Tensor>& out_tensors,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(
      !out_tensors.empty(),
      "Expected at least one output tensor to scatter to");
  // A 0-dim source would make `out_sizes[dim]` below index an empty vector.
  TORCH_CHECK(
      tensor.dim() > 0, "Expected at least a 1-dimensional tensor to scatter");
  dim = at::maybe_wrap_dim(dim, tensor);
  int64_t total_size = 0;
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(out_tensors.size());
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    auto out_sizes = out_tensors[i].sizes().vec();
    // size() is unsigned while dim() is int64_t; compare in the signed type.
    const bool same_ndim =
        static_cast<int64_t>(out_sizes.size()) == tensor.dim();
    if (same_ndim) {
      total_size += out_sizes[dim];
      chunk_sizes.push_back(out_sizes[dim]);
      // Overwrite the scatter-dim extent so the whole shape can be compared
      // against the source in one check.
      out_sizes[dim] = tensor.size(dim);
    }
    TORCH_CHECK(
        same_ndim && out_sizes == tensor.sizes(),
        "Output tensor at index ",
        i,
        " has incorrect shape: ",
        out_tensors[i].sizes(),
        ". Expected same "
        "shape except for scatter dim ",
        dim,
        " as the source tensor: ",
        at::IntArrayRef(tensor.sizes()));
  }
  TORCH_CHECK(
      total_size == tensor.size(dim),
      "Total size for output tensors along scatter dim ",
      dim,
      " does not match "
      "the source tensor size at dim ",
      dim,
      ". Expected ",
      tensor.size(dim),
      ", but got total size ",
      total_size);

  auto chunks =
      tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  const size_t num_streams = streams ? streams->size() : 0U;
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    if (i < num_streams && (*streams)[i]) {
      const auto device_index =
          static_cast<int16_t>(out_tensors[i].get_device());
      TORCH_CHECK(
          (*streams)[i]->device_index() == device_index,
          "Expected the device associated with the stream at index ",
          i,
          " (was ",
          (*streams)[i]->device_index(),
          ") ",
          "to match the device supplied at that index ",
          "(expected ",
          device_index,
          ")");
      // NOTE(review): the guard is only reset when a stream is supplied, so an
      // index without one copies under whatever stream was set last (if any).
      cuda_guard.reset_stream(*(*streams)[i]);
    }
    // NB: We don't detect the case where `out_tensor` is already the correct
    // view of `tensor` since that would be nontrivial and involve checking
    // ptr, offset, and strides. So `scatter_out(src, src.chunk(...))` does
    // more copying than `scatter(src)`.
    out_tensors[i].copy_(chunks[i], /*non_blocking=*/true);
  }
  return out_tensors;
}
// Splits `tensor` along `dim` and moves chunk i to CUDA device devices[i].
//
// If `chunk_sizes` is given it must have one entry per device and is used
// with split_with_sizes. Otherwise tensor.chunk(devices.size()) is used, which
// may yield fewer chunks than devices. A chunk whose target is the source's
// own device is returned as-is (a view). Every target index must be
// non-negative. If `streams` is provided, a set entry i must belong to
// devices[i] and is made current for that transfer.
std::vector<at::Tensor> scatter(
    const at::Tensor& tensor,
    at::IntArrayRef devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to");
  if (chunk_sizes.has_value()) {
    TORCH_CHECK(
        chunk_sizes->size() == devices.size(),
        "Expected devices and chunk_sizes to be of same length, but got "
        "len(devices) = ",
        devices.size(),
        " and len(chunk_sizes) = ",
        chunk_sizes->size());
  }
  dim = at::maybe_wrap_dim(dim, tensor);
  std::vector<at::Tensor> chunks = chunk_sizes
      ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
      : tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
  const size_t num_streams = streams ? streams->size() : 0U;
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    // Validate every target up front, before narrowing and regardless of
    // whether it matches the source device. Previously a CPU source
    // (get_device() == -1) paired with device -1 skipped this check and
    // yielded a CPU chunk.
    TORCH_CHECK(
        devices[i] >= 0,
        "Expected non-negative device index, but got ",
        devices[i]);
    const auto device_index = static_cast<int16_t>(devices[i]);
    if (device_index == tensor.get_device()) {
      continue;
    }
    if (i < num_streams && (*streams)[i]) {
      TORCH_CHECK(
          (*streams)[i]->device_index() == device_index,
          "Expected the device associated with the stream at index ",
          i,
          " (was ",
          (*streams)[i]->device_index(),
          ") ",
          "to match the device supplied at that index ",
          "(expected ",
          device_index,
          ")");
      cuda_guard.reset_stream(*(*streams)[i]);
    }
    chunks[i] = chunks[i].to(
        {DeviceType::CUDA, device_index},
        /*non_blocking=*/true,
        /*copy=*/false,
        /*memory_format=*/at::MemoryFormat::Preserve);
  }
  return chunks;
}
// ***************** Gather *******************
//
// Gather a list of CUDA tensors on one or more devices to a target tensor or
// device, either CPU or CUDA.
// Concatenates `tensors` along `dim` into `out_tensor`, which must already
// have the concatenated shape. No checks: callers validate shapes and
// devices. Copies are non-blocking only when the destination is a CUDA tensor.
// Returns `out_tensor`.
static inline at::Tensor& _gather_out_impl(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  // Extent of each input along the gather dimension.
  std::vector<int64_t> split_sizes;
  split_sizes.reserve(tensors.size());
  for (const at::Tensor& piece : tensors) {
    split_sizes.push_back(piece.size(dim));
  }
  // Views into `out_tensor`, one slot per input.
  auto slots =
      out_tensor.split_with_sizes(/*split_sizes=*/split_sizes, /*dim=*/dim);
  const bool async_copy = out_tensor.is_cuda();
  for (size_t k = 0; k < tensors.size(); ++k) {
    slots[k].copy_(tensors[k], /*non_blocking=*/async_copy);
  }
  return out_tensor;
}
// Concatenates the CUDA `tensors` along `dim` into the preallocated
// `out_tensor`.
//
// Every input must be a CUDA tensor with the same number of dimensions as the
// first input and matching sizes on every dimension except `dim`.
// `out_tensor` must have exactly the concatenated shape. `dim` may be negative
// and is wrapped against the first input. Returns `out_tensor`.
at::Tensor& gather_out(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device '",
        tensor.device(),
        "'");
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions, (expected ",
        expected_size.size(),
        ")");
    // Any extent is allowed along the gather dim; all others must match.
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
  }
  expected_size[dim] = total_size;
  TORCH_CHECK(
      out_tensor.sizes() == expected_size,
      "Expected out tensor to have shape ",
      at::IntArrayRef(expected_size),
      ", but got ",
      out_tensor.sizes());
  return _gather_out_impl(tensors, out_tensor, dim);
}
// Concatenates the CUDA `tensors` along `dim` into a newly allocated tensor.
//
// Input validation matches gather_out: all inputs are CUDA tensors with equal
// rank and matching sizes except along `dim`.
//
// Destination, chosen from `destination_index`:
//  - == -1: the result is allocated on the CPU.
//  - nullopt: the result is a CUDA tensor with device index -1 (presumably the
//    current CUDA device; confirm with c10::Device semantics).
//  - otherwise: the result is allocated on that CUDA device.
//
// The result uses the first input's suggested memory format if every input
// shares it, and contiguous otherwise.
at::Tensor gather(
    at::TensorList tensors,
    int64_t dim,
    c10::optional<int32_t> destination_index) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  auto memory_format = first.suggest_memory_format();
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device ",
        tensor.device());
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions, (expected ",
        expected_size.size(),
        ")");
    // Any extent is allowed along the gather dim; all others must match.
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
    // Fall back to contiguous as soon as any input disagrees on layout.
    if (memory_format != MemoryFormat::Contiguous &&
        tensor.suggest_memory_format() != memory_format) {
      memory_format = MemoryFormat::Contiguous;
    }
  }
  expected_size[dim] = total_size;
  at::Device device(DeviceType::CPU);
  if (!destination_index || *destination_index != -1) {
    device = at::Device(
        DeviceType::CUDA, destination_index ? *destination_index : -1);
  }
  at::Tensor result =
      at::empty(expected_size, first.options().device(device), memory_format);
  return _gather_out_impl(tensors, result, dim);
}
} // namespace cuda
} // namespace torch
|