1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
|
#include <torch/csrc/cuda/python_nccl.h>
#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/utils/pybind.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>
#include <sstream>
#include <unordered_map>
using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;
// Name tag used when wrapping an NCCL communicator in a PyCapsule
// (PyCapsule_New in THCPModule_nccl_init_rank) and when unwrapping it again
// (PyCapsule_GetPointer in unpack_nccl_comm). Declared constexpr so the
// pointer itself cannot be reassigned at runtime.
static constexpr const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
// Python binding: returns the value of version() as a Python int.
// NOTE(review): unlike the other bindings in this file this function is not
// wrapped in HANDLE_TH_ERRORS; that is only safe if version() cannot throw —
// confirm.
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  const auto nccl_version = version();
  return PyInt_FromLong(nccl_version);
}
// Python binding: fills an ncclUniqueId via get_unique_id() and returns its
// first NCCL_UNIQUE_ID_BYTES bytes as a Python bytes object.
PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId unique_id;
  get_unique_id(unique_id);
  const char* raw_bytes = reinterpret_cast<const char*>(&unique_id);
  return PyBytes_FromStringAndSize(raw_bytes, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}
// Extracts the ncclComm_t stored in a communicator capsule tagged with
// COMM_CAPSULE_NAME. When PyCapsule_GetPointer yields null, the Python error
// it left pending is propagated by throwing python_error.
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  void* raw_pointer = PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (raw_pointer == nullptr) {
    throw python_error();
  }
  return static_cast<ncclComm_t>(raw_pointer);
}
// PyCapsule destructor attached by THCPModule_nccl_init_rank. Unpacks the
// communicator and calls comm_destroy on it with the GIL released; C++
// exceptions are converted by the HANDLE_TH_ERRORS macros rather than being
// allowed to escape the callback.
static void destroy_nccl_comm(PyObject* capsule) {
  HANDLE_TH_ERRORS
  ncclComm_t doomed_comm = unpack_nccl_comm(capsule);
  {
    // Reacquire the GIL at the end of this scope, before leaving the handler.
    pybind11::gil_scoped_release gil_release;
    comm_destroy(doomed_comm);
  }
  END_HANDLE_TH_ERRORS_RET()
}
// Converts the Python `streams` argument into one optional CUDA stream per
// input. None yields `size` empty optionals; anything else is handed to
// THPUtils_PySequence_to_CUDAStreamList and must produce exactly `size`
// entries, otherwise std::runtime_error is thrown.
static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams(
    PyObject* obj,
    size_t size) {
  using StreamList = std::vector<c10::optional<at::cuda::CUDAStream>>;
  if (obj != Py_None) {
    auto stream_list = THPUtils_PySequence_to_CUDAStreamList(obj);
    if (stream_list.size() == size) {
      return stream_list;
    }
    throw std::runtime_error(
        "number of streams is not equal to number of inputs");
  }
  return StreamList(size, c10::nullopt);
}
// Forward declarations: the Python bindings below use these tensor-unpacking
// helpers, whose definitions appear at the end of this file.
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj);
static inline at::Tensor extract_tensor(PyObject* obj);
// Converts the Python `comms` argument into a vector of NCCL communicators.
//
// `obj` may be:
//   * None                              -> empty vector (no user comms);
//   * a single capsule (PyCapsule_CheckExact) -> one communicator;
//   * any other object accepted by PySequence_Fast, each item being a
//     communicator capsule.
// In the non-None cases the number of communicators must equal `size`
// (the number of inputs), otherwise std::runtime_error is thrown. Python
// errors raised while unpacking are propagated as python_error.
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }
  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms.push_back(unpack_nccl_comm(obj));
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq) {
      throw python_error();
    }
    // Deliberately not named `size`: shadowing the parameter would make the
    // length check below ambiguous to readers.
    const Py_ssize_t num_comms = PySequence_Fast_GET_SIZE(seq.get());
    comms.reserve(static_cast<size_t>(num_comms));
    for (const auto i : c10::irange(num_comms)) {
      comms.push_back(unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i)));
    }
  }
  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}
// Python binding: nccl_init_rank(int nranks, bytes unique_id, int rank).
// Creates a communicator for `rank` out of `nranks` from a unique id of
// exactly NCCL_UNIQUE_ID_BYTES bytes. It returns the communicator wrapped in a
// PyCapsule whose destructor (destroy_nccl_comm) destroys it.
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks;
  const char* id;
  Py_ssize_t id_len;
  int rank;
  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  THPUtils_assert(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected %d bytes, got %zd)",
      NCCL_UNIQUE_ID_BYTES,
      id_len);
  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm;
  {
    // Initialization may block, so release the GIL while it runs.
    pybind11::gil_scoped_release no_gil;
    comm = comm_init_rank(nranks, commId, rank);
  }
  PyObject* capsule =
      PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  if (!capsule) {
    // No capsule owns `comm`, so destroy_nccl_comm will never run for it.
    // Release the communicator here instead of leaking it, then report the
    // failure that PyCapsule_New raised.
    if (comm) {
      pybind11::gil_scoped_release no_gil;
      comm_destroy(comm);
    }
    return nullptr;
  }
  return capsule;
  END_HANDLE_TH_ERRORS
}
// Python binding: nccl_reduce(inputs, output, root, op, streams, comms).
// Converts the arguments while holding the GIL, then calls
// torch::cuda::nccl::reduce with the GIL released.
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_output, *_streams, *_comms;
  int root, op;
  if (!PyArg_ParseTuple(
          args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
    // The usage string must describe all six entries of the "OOiiOO" format.
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce",
        1,
        "(sequence[Tensor] inputs, Tensor output, int root,"
        " int op, sequence[torch.cuda.Stream or None] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }
  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  auto output = extract_tensor(_output);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());
  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: nccl_all_reduce(inputs, outputs, op, streams, comms).
// All Python objects are unpacked while the GIL is held; the collective
// itself runs with the GIL released.
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* py_inputs = nullptr;
  PyObject* py_outputs = nullptr;
  PyObject* py_streams = nullptr;
  PyObject* py_comms = nullptr;
  int op = 0;
  const bool parsed = PyArg_ParseTuple(
      args, "OOiOO", &py_inputs, &py_outputs, &op, &py_streams, &py_comms);
  if (!parsed) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }
  std::vector<at::Tensor> input_tensors = extract_tensors(py_inputs);
  std::vector<at::Tensor> output_tensors = extract_tensors(py_outputs);
  const size_t num_inputs = input_tensors.size();
  auto stream_list = unpack_streams(py_streams, num_inputs);
  auto comm_list = unpack_comms(py_comms, num_inputs);
  {
    pybind11::gil_scoped_release gil_release;
    all_reduce(input_tensors, output_tensors, op, stream_list, comm_list);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: nccl_broadcast(inputs, root, streams, comms).
// Checks that `root` indexes into `inputs`, then calls
// torch::cuda::nccl::broadcast with the GIL released.
// NOTE(review): `root` is validated here but not forwarded to broadcast();
// confirm that broadcast() always uses a fixed source and that this is intended.
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_streams, *_comms;
  int root;
  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_broadcast",
        1,
        "(sequence[Tensor] inputs, int root,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }
  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  THPUtils_assert(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());
  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::broadcast(inputs, streams, user_comms);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: nccl_all_gather(inputs, outputs, streams, comms).
// Converts the arguments while holding the GIL, then calls all_gather with
// the GIL released.
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_gather",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }
  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());
  {
    pybind11::gil_scoped_release no_gil;
    all_gather(inputs, outputs, streams, user_comms);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: nccl_reduce_scatter(inputs, outputs, op, streams, comms).
// Converts the arguments while holding the GIL, then calls reduce_scatter
// with the GIL released.
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;
  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce_scatter",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }
  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());
  {
    pybind11::gil_scoped_release no_gil;
    reduce_scatter(inputs, outputs, op, streams, user_comms);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Returns the at::Tensor wrapped by the Python object `obj`. Throws
// torch::TypeError, naming the actual Python type, when `obj` does not pass
// THPVariable_Check.
static inline at::Tensor extract_tensor(PyObject* obj) {
  if (THPVariable_Check(obj)) {
    return THPVariable_Unpack(obj);
  }
  throw torch::TypeError("expected Tensor (got %s)", Py_TYPE(obj)->tp_name);
}
// Converts a Python sequence of Tensors into a std::vector<at::Tensor>.
// A python_error is thrown if `obj` is not accepted by PySequence_Fast.
// torch::TypeError, reporting the offending index and Python type, is thrown
// for the first item that is not a Tensor.
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
  if (!seq) {
    throw python_error();
  }
  const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  std::vector<at::Tensor> list;
  list.reserve(static_cast<size_t>(length));
  for (const auto i : c10::irange(length)) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!THPVariable_Check(item)) {
      // %zd matches Py_SSIZE_T directly, avoiding a narrowing cast to int.
      throw torch::TypeError(
          "expected Tensor at %zd (got %s)", i, Py_TYPE(item)->tp_name);
    }
    list.emplace_back(THPVariable_Unpack(item));
  }
  return list;
}
|