File: quantization.cpp

#include <torch/csrc/distributed/c10d/quantization/quantization.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
#include <torch/library.h>

namespace torch::distributed::c10d::quantization {

// TODO: These kernels are copied from fbgemm_gpu; we should dedup them later

static void FloatToBFloat16Quantized_ref(
    const float* const input,
    const size_t nrows,
    const size_t ncols,
    uint16_t* const output) {
  for (const auto row : c10::irange(nrows)) {
    const float* input_row = input + row * ncols;
    uint16_t* output_row = output + row * ncols;

    for (const auto col : c10::irange(ncols)) {
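      // Round to nearest bfloat16: add 1 << 15 to the float's bit pattern so
      // the shift below keeps the nearest high 16 bits (ties round up).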
      output_row[col] =
          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
          16;
    }
  }
}

static void BFloat16QuantizedToFloat_ref(
    const at::BFloat16* const input,
    const size_t nrows,
    const size_t ncols,
    float* const output) {
  for (const auto row : c10::irange(nrows)) {
    const at::BFloat16* input_row = input + row * ncols;
    float* output_row = output + row * ncols;

    for (const auto col : c10::irange(ncols)) {
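      // Widen the stored 16-bit pattern back into the high 16 bits of a float;
      // the mantissa bits dropped during quantization come back as zero.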
      uint32_t val_fp32 = static_cast<uint32_t>(
                              reinterpret_cast<const uint16_t*>(input_row)[col])
          << 16;
      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
    }
  }
}

at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently only 2D inputs are supported
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];
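  // kHalf is used only as a 16-bit element type; the output buffer actually
  // holds raw bfloat16 bit patterns (hence the reinterpret_cast below).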
  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf));

  FloatToBFloat16Quantized_ref(
      input.const_data_ptr<float>(),
      nrows,
      ncols,
      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));

  return output;
}

at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // Currently only 2D inputs are supported
  TENSOR_NDIM_EQUALS(input, 2);

  const auto input_sizes = input.sizes();
  const auto nrows = input_sizes[0];
  const auto ncols = input_sizes[1];

  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat));
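  // The kHalf input is just a 16-bit carrier; its bits are reinterpreted as
  // bfloat16 patterns.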
  BFloat16QuantizedToFloat_ref(
      reinterpret_cast<const at::BFloat16*>(input.const_data_ptr<at::Half>()),
      nrows,
      ncols,
      output.mutable_data_ptr<float>());

  return output;
}

TORCH_LIBRARY(quantization, m) {
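  // Declare the op schemas under the `quantization` namespace.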
  m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
  m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(quantization, CPU, m) {
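  // Bind the reference CPU kernels to the schemas declared above.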
  m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu);
  m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
}
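
// Example round trip through the reference kernels (sketch; shapes and values
// are illustrative, and the ops should also be reachable through the dispatcher
// as torch.ops.quantization._FloatToBfloat16Quantized /
// torch.ops.quantization._Bfloat16QuantizedToFloat):
//   at::Tensor fp32 = at::rand({4, 8});
//   at::Tensor bf16 = _float_to_bfloat16_cpu(fp32);  // kHalf tensor carrying bf16 bits
//   at::Tensor back = _bfloat16_to_float_cpu(bf16);  // fp32 with low mantissa bits zeroed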

} // namespace torch::distributed::c10d::quantization