1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
|
#include "caffe2/contrib/gloo/broadcast_ops.h"
#include "caffe2/core/context_gpu.h"
#include <gloo/cuda_broadcast_one_to_all.h>
namespace caffe2 {
namespace gloo {
// Builds the Gloo CUDA one-to-all broadcast algorithm for this operator.
//
// The element type recorded in `init_` is probed in a fixed order (float,
// long, int, at::Half) and the first match decides which
// ::gloo::CudaBroadcastOneToAll<T> instantiation is stored in `algorithm_`.
// Every instantiation is constructed from the same four pieces of state:
// `init_.context`, the typed output pointers from `init_.getOutputs<T>()`,
// `init_.size`, and `root_`.
//
// If none of the probed types match, CAFFE_ENFORCE fails with a message that
// includes `init_.meta.name()`.
template <class Context>
void BroadcastOp<Context>::initializeAlgorithm() {
  if (init_.template IsType<float>()) {
    using Algorithm = ::gloo::CudaBroadcastOneToAll<float>;
    algorithm_.reset(new Algorithm(
        init_.context, init_.template getOutputs<float>(), init_.size, root_));
    return;
  }

  if (init_.template IsType<long>()) {
    using Algorithm = ::gloo::CudaBroadcastOneToAll<long>;
    algorithm_.reset(new Algorithm(
        init_.context, init_.template getOutputs<long>(), init_.size, root_));
    return;
  }

  if (init_.template IsType<int>()) {
    using Algorithm = ::gloo::CudaBroadcastOneToAll<int>;
    algorithm_.reset(new Algorithm(
        init_.context, init_.template getOutputs<int>(), init_.size, root_));
    return;
  }

  if (init_.template IsType<at::Half>()) {
    // Tensors typed as at::Half are handed to Gloo as ::gloo::float16.
    using Algorithm = ::gloo::CudaBroadcastOneToAll<::gloo::float16>;
    algorithm_.reset(new Algorithm(
        init_.context,
        init_.template getOutputs<::gloo::float16>(),
        init_.size,
        root_));
    return;
  }

  // No supported element type matched.
  CAFFE_ENFORCE(false, "Unhandled type: ", init_.meta.name());
}
// File-local registration: the anonymous namespace keeps whatever the macro
// expands to out of other translation units.
namespace {
// Associates BroadcastOp<CUDAContext> with the operator name `Broadcast` and
// the engine name `GLOO` via the CUDA registration macro.
// NOTE(review): the macro's expansion is not visible in this file; presumably
// it creates a static registrar object -- confirm against the macro definition.
REGISTER_CUDA_OPERATOR_WITH_ENGINE(Broadcast, GLOO, BroadcastOp<CUDAContext>);
} // namespace
} // namespace gloo
} // namespace caffe2
|