#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
#include <torch/library.h>

#include <algorithm>
#include <cstdint>

// TODO: The kernels are copied from fbgemm_gpu; they should be deduplicated later.
// FP32 -> BF16 kernel
__global__ void _float_to_bfloat16_cuda_kernel(
const float* __restrict__ input,
const size_t nrows,
const size_t ncols,
uint16_t* __restrict__ output) {
const auto row_incre = blockDim.y * gridDim.y;
const auto col_incre = blockDim.x * gridDim.x;
for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
row += row_incre) {
const float* input_row = input + row * ncols;
uint16_t* output_row = output + row * ncols;
for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
col += col_incre) {
      // BF16 keeps the top 16 bits of the FP32 bit pattern. Adding 2^15
      // before the 16-bit right shift rounds to the nearest BF16 value
      // instead of truncating toward zero.
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
}
}
}
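
// For reference, the same rounding step on the host (a minimal sketch, not
// part of this file; `f` is the FP32 input):
//
//   uint32_t bits;
//   std::memcpy(&bits, &f, sizeof(bits));
//   const uint16_t bf16 = static_cast<uint16_t>((bits + (1u << 15)) >> 16);
//
// e.g. 1.0f (0x3F800000) keeps its top half 0x3F80, while any value whose
// low 16 bits are 0x8000 or more rounds up to the next representable BF16.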
// BF16 -> FP32 kernel
__global__ void _bfloat16_to_float_cuda_kernel(
const uint16_t* __restrict__ input,
const size_t nrows,
const size_t ncols,
float* __restrict__ output) {
const auto row_incre = blockDim.y * gridDim.y;
const auto col_incre = blockDim.x * gridDim.x;
for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
row += row_incre) {
    const uint16_t* input_row = input + row * ncols;
    float* output_row = output + row * ncols;
    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      // Shift the BF16 bits into the high half of a 32-bit word; the low 16
      // mantissa bits stay zero, so this direction of the conversion is exact.
      const uint32_t val_fp32 = static_cast<uint32_t>(input_row[col]) << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
}
}
namespace torch::distributed::c10d::quantization {
at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
TENSOR_ON_CUDA_GPU(input);
  // Currently only 2D inputs are supported
TENSOR_NDIM_EQUALS(input, 2);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(input.get_device());
const auto nrows = input.size(0);
const auto ncols = input.size(1);
const size_t output_columns = ncols;
auto output = at::empty(
{nrows, ncols},
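      // If NCCL lacks a native BF16 type (HAS_NCCL_BF16_DATATYPE unset), the
      // raw BF16 bits travel in a Half tensor: same 16-bit element size, only
      // the dtype tag differs.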
#if HAS_NCCL_BF16_DATATYPE
input.options().dtype(at::kBFloat16));
#else
input.options().dtype(at::kHalf));
#endif
if (nrows == 0 || ncols == 0) {
return output;
}
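  // Launch one 256-thread block shaped so x walks columns and y walks rows;
  // the grid-stride loops in the kernel cover whatever the grid does not.
  // gridDim.y is clamped to 65535, the CUDA hardware limit for that axis.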
constexpr size_t threads_per_block = 256;
const auto blockDim_x = std::min(output_columns, threads_per_block);
dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
const auto gridDim_y =
std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
dim3 gridDim(gridDim_x, gridDim_y);
_float_to_bfloat16_cuda_kernel<<<
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<float>(),
nrows,
ncols,
#if HAS_NCCL_BF16_DATATYPE
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
#else
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
#endif
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
TENSOR_ON_CUDA_GPU(input);
  // Currently only 2D inputs are supported
TENSOR_NDIM_EQUALS(input, 2);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(input.get_device());
const auto nrows = input.size(0);
const auto ncols = input.size(1);
const size_t output_columns = ncols;
auto output = at::empty(
      {nrows, ncols},
      input.options().dtype(at::kFloat));
if (nrows == 0 || ncols == 0) {
return output;
}
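  // Same launch shape as in _float_to_bfloat16_cuda above: x over columns,
  // y over rows, with the grid-stride loops covering the remainder.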
constexpr size_t threads_per_block = 256;
const auto blockDim_x = std::min(output_columns, threads_per_block);
dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
const auto gridDim_y =
std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
dim3 gridDim(gridDim_x, gridDim_y);
_bfloat16_to_float_cuda_kernel<<<
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream()>>>(
#if HAS_NCCL_BF16_DATATYPE
reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
#else
reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
#endif
nrows,
ncols,
output.mutable_data_ptr<float>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
#define DISPATCH_TO_CUDA(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
TORCH_LIBRARY_IMPL(quantization, CUDA, m) {
DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
}
} // namespace torch::distributed::c10d::quantization
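
// Example usage from Python once the extension is loaded (a sketch; the op
// names follow the registrations above):
//
//   x = torch.randn(4, 8, device="cuda")
//   q = torch.ops.quantization._FloatToBfloat16Quantized(x)
//   y = torch.ops.quantization._Bfloat16QuantizedToFloat(q)
//   # y matches x to BF16 precision (8 significand bits, ~2-3 decimal digits)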