import functools
import torch
import torch.distributed as dist
from enum import Enum

TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Different quantization methods for auto_quantize API are identified here.

    auto_quantize API currently supports fp16 and bfp16 methods.
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value
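# Illustrative sketch of the enum's string form (not part of the public API):
#   str(DQuantType.FP16)  -> "fp16"
#   str(DQuantType.BFP16) -> "bfp16"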
def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    # Clamp to the representable FP16 range before down-casting to avoid inf/-inf.
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )
def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list
def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif quant_loss is None:
            return tensor.float()
        else:
            # Divide by the provided quantization-loss factor to compensate for
            # precision lost in the FP16 round trip.
            return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16 "
                f"(BFP16-quantized tensors are stored with FP16 dtype)."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )
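# Illustrative round trip through _quantize_tensor/_dequantize_tensor (a sketch;
# accuracy depends on the value range of the input):
#   t = torch.randn(4, 8)
#   q = _quantize_tensor(t, DQuantType.FP16)    # dtype: torch.float16, clamped to the FP16 range
#   d = _dequantize_tensor(q, DQuantType.FP16)  # dtype: torch.float32
#   torch.allclose(t, d, atol=1e-2)             # typically True, up to FP16 rounding error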
def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
    return dequantized_tensor_list
def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, chooses the precision type,
    passes the other necessary arguments, and then dequantizes the output.

    Currently it only supports:
        . FP16 and BFP16 quantization methods for the gloo and nccl backends
        . all_gather, all_to_all, all_to_all_single collective ops
    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): A function representing collective operations.
        qtype (DQuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.

    Returns:
        (Callable): the same collective as func but enables automatic quantization/dequantization.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get('group', None)
        async_op = kwargs.get('async_op', False)
        if async_op is True:
            raise RuntimeError(
                'The async_op=True mode is not supported yet.'
            )
        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            # Dequantize the gathered results back into the caller's output list.
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get('out_splits', None)
            in_splits = kwargs.get('in_splits', None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
            # Dequantize and copy back row by row into the caller's output tensor.
            for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
        else:
            raise RuntimeError(
                f"The collective op {func} is not supported yet"
            )
    return wrapper
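# Example usage (an illustrative sketch, not an official demo). It assumes a
# process group has already been initialized, e.g. via
# dist.init_process_group("gloo", rank=rank, world_size=world_size):
#
#   quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
#   output = [torch.zeros(8) for _ in range(dist.get_world_size())]
#   quantized_all_gather(output, torch.randn(8))
#   # `output` now holds dequantized FP32 tensors; the payloads were
#   # communicated in FP16.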