File: cuda_nccl_op_gpu.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (275 lines) | stat: -rw-r--r-- 7,698 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"

#include "caffe2/contrib/nccl/cuda_nccl_gpu.h"

namespace caffe2 {

// Builds the nccl::NCCLExecution descriptor used by every NCCL op below.
//
// The execution runs on `context`'s stream and device, and `root` is taken
// from the operator's "root" argument (default 0). Each input i becomes one
// element holding:
//   * src    - input i (fetched as a CUDA tensor);
//   * device - the id GetGPUIDForPointer reports for input i's data pointer;
//   * dst    - for an N-N op (OutputSize() == InputSize()), output i;
//              for an N-1 op (OutputSize() == 1), output 0 on element `root`
//              only; dst is not assigned for the remaining elements.
// Fails via CAFFE_ENFORCE if the op is neither N-N nor N-1, or if an N-1
// op's `root` is not a valid input index.
nccl::NCCLExecution getNCCLElements(
    OperatorBase* op,
    const CUDAContext& context) {
  const int num_inputs = op->InputSize();
  const int num_outputs = op->OutputSize();
  // We either do an N-N op, or an N-1 op.
  CAFFE_ENFORCE(num_inputs == num_outputs || num_outputs == 1);

  const int root = op->template GetSingleArgument<int>("root", 0);
  if (num_outputs == 1) {
    // The lone output is attached to element `root` only; an out-of-range
    // root would leave it attached to no element at all, so fail loudly.
    CAFFE_ENFORCE(
        root >= 0 && root < num_inputs,
        "NCCL op argument 'root' (",
        root,
        ") must be in [0, ",
        num_inputs,
        ")");
  }

  nccl::NCCLExecution ex;
  ex.stream_gpu_id = context.device_id();
  ex.stream = context.cuda_stream();
  ex.root = root;
  ex.elements.resize(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    auto& el = ex.elements[i];
    el.src = &(op->Input<Tensor>(i, CUDA));
    if (num_outputs == 1) {
      // Reduce op
      if (i == root) {
        el.dst = op->Output<Tensor>(0, CUDA);
      }
    } else if (i < num_outputs) {
      el.dst = op->Output<Tensor>(i, CUDA);
    }
    // TODO - expensive (>1ms) - cache these.
    // Reuse the tensor fetched above instead of looking the input up again.
    el.device = GetGPUIDForPointer(el.src->raw_data());
  }

  return ex;
}

namespace {

// Returns true if every input of `op`, fetched as a CUDA tensor, holds
// elements of type T (this file instantiates it with float and at::Half).
// Returns false as soon as one input has a different element type; an op
// with no inputs yields true.
template <typename T>
bool AllInputsAre(OperatorBase* op) {
  for (int i = 0; i < op->InputSize(); ++i) {
    if (!op->Input<Tensor>(i, CUDA).IsType<T>()) {
      return false;
    }
  }
  return true;
}

// Manual count of all instantiated NCCL ops.
// If this drops to zero after destructing the last NCCL op,
// it means we can safely destroy all lazily created NCCL contexts.
std::atomic<int> kNCCLOpCounter(0);

} // namespace

// Common base class for the NCCL operators in this file.
//
// Each live instance is counted in kNCCLOpCounter. When the last instance is
// destroyed, nccl::destroyContexts() is called to release the lazily created
// NCCL contexts.
class NCCLBaseOp : public Operator<CUDAContext> {
 public:
  using Operator::Operator;

  NCCLBaseOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws) {}

 private:
  // Registers one live NCCL op for the lifetime of its owner.
  //
  // The count lives in a data member rather than in a constructor body. Every
  // constructor therefore increments it, including the ones inherited through
  // `using Operator::Operator`, and each increment pairs with exactly one
  // decrement on destruction. Previously only the hand-written constructor
  // incremented, while the destructor always decremented.
  struct LiveOpRegistration {
    LiveOpRegistration() {
      kNCCLOpCounter++;
    }
    ~LiveOpRegistration() {
      if (--kNCCLOpCounter == 0) {
        nccl::destroyContexts();
      }
    }
    // A copy would decrement on destruction without having incremented.
    LiveOpRegistration(const LiveOpRegistration&) = delete;
    LiveOpRegistration& operator=(const LiveOpRegistration&) = delete;
  };

  LiveOpRegistration live_registration_;
};

class NCCLAllreduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;

    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }

  static std::vector<TensorShape> ShapeInference(
      const OperatorDef& def,
      const std::vector<TensorShape>& in) {
    auto n_outputs = def.output_size();
    CAFFE_ENFORCE(
        n_outputs == 1 || n_outputs == in.size(),
        "NCCLAllreduce only supports N-1 or N-N reductions");

    for (auto i = 0; i < in.size(); i++) {
      CAFFE_ENFORCE(
          in[0].dims_size() == in[i].dims_size(),
          "NCCLAllreduce requires inputs of same dimension");
      for (auto j = 0; j < in[0].dims_size(); j++) {
        CAFFE_ENFORCE(
            in[0].dims(j) == in[i].dims(j),
            "NCCLAllreduce requires inputs to be of same shape");
      }
    }

    std::vector<TensorShape> out(n_outputs);
    for (auto i = 0; i < out.size(); i++) {
      out[i] = in[0];
    }
    return out;
  }

  static struct OpSchema::Cost CostInference(
      const OperatorDef& def,
      const vector<TensorShape>& inputs) {
    CAFFE_ENFORCE_GE(inputs.size(), 1, "Conv requires at least 1 input");
    const TensorShape X0 = inputs[0];
    const auto nElem = nElemFromDim(inputs[0]);

    struct OpSchema::Cost c;
    c.flops = (inputs.size() - 1) * nElem;
    c.bytes_read = inputs.size() * nElem;
    c.bytes_written = def.output_size() * nElem;
    c.params_bytes = 0;
    return c;
  }
};

// NCCLBroadcast: runs nccl::NCCL<T>::Broadcast over all inputs, where T is
// float or at::Half and the source element is chosen by the "root" argument.
class NCCLBroadcastOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  // Returns true after dispatching to the matching NCCL specialization, or
  // immediately when there is only one input. Returns false if the inputs are
  // neither all float nor all at::Half.
  bool RunOnDevice() override {
    if (InputSize() == 1) {
      // A lone input: no NCCL call is issued.
      return true;
    }
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Broadcast(getNCCLElements(this, context_));
      return true;
    }
    if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Broadcast(getNCCLElements(this, context_));
      return true;
    }
    // Unsupported (or mixed) element type.
    return false;
  }
};

class NCCLReduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    const auto& ex = getNCCLElements(this, context_);

    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Reduce(ex);
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Reduce(ex);
      return true;
    } else {
      return false;
    }
  }
};

// NCCLAllGather: runs nccl::NCCL<T>::AllGather over all inputs, where T is
// float or at::Half.
class NCCLAllGatherOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  // Returns false if the inputs are neither all float nor all at::Half.
  bool RunOnDevice() override {
    if (InputSize() == 1) {
      // A lone input: no NCCL call is issued.
      // NOTE(review): the NCCLAllGather schema declares no in-place
      // allowance, so on this path the output blob(s) are never written.
      // Confirm that this is intended.
      return true;
    }
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllGather(getNCCLElements(this, context_));
      return true;
    }
    if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllGather(getNCCLElements(this, context_));
      return true;
    }
    // Unsupported (or mixed) element type.
    return false;
  }
};

// NCCLReduceScatter: runs nccl::NCCL<T>::ReduceScatter over all inputs, where
// T is float or at::Half. Unlike the sibling ops there is no single-input
// shortcut: NCCL is invoked for any number of inputs.
class NCCLReduceScatterOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  // Returns false if the inputs are neither all float nor all at::Half.
  bool RunOnDevice() override {
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    }
    if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    }
    // Unsupported (or mixed) element type.
    return false;
  }
};

namespace {

std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
    const OperatorDef& def) {
  std::vector<DeviceOption> opt;
  for (int i = 0; i < def.input().size(); ++i) {
    DeviceOption dev;
    dev.set_device_type(1);
    dev.set_device_id(i);
    opt.push_back(dev);
  }
  return std::make_pair(opt, opt);
}

REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
OPERATOR_SCHEMA(NCCLAllreduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .CostInferenceFunction(NCCLAllreduceOp::CostInference)
    .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .AllowOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);

REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
OPERATOR_SCHEMA(NCCLBroadcast)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .EnforceOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);

SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);

REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
OPERATOR_SCHEMA(NCCLReduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0)
    .InputsCanCrossDevices()
    .AllowInplace([](int /*in*/, int out) -> bool { return (out == 0); })
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduce);

REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
OPERATOR_SCHEMA(NCCLAllGather)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllGather);

REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
OPERATOR_SCHEMA(NCCLReduceScatter)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);

} // namespace
} // namespace caffe2