import torch
import torch.distributed as dist
from torch import nn

def _quantize_per_tensor_cuda(x, scale, zero_point):
    # Affine per-tensor quantization to uint8: q = round(x / scale) + zero_point.
    y = torch.round(x / scale) + zero_point
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


def _dequantize_per_tensor_cuda(y, scale, zero_point):
    # Inverse mapping: x = scale * (q - zero_point).
    x = scale * (y.to(torch.float32) - zero_point)
    return x
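
# A minimal per-tensor round-trip sketch (illustrative only, not part of the
# hook API; assumes a CUDA device is available). The observer supplies the
# affine qparams, and the reconstruction error stays within about one
# quantization step:
#
# >>> # xdoctest: +SKIP
# >>> x = torch.randn(1024, device="cuda")
# >>> obs = torch.quantization.MinMaxObserver().cuda(x.device)
# >>> obs(x)
# >>> s, z = obs.calculate_qparams()
# >>> q = _quantize_per_tensor_cuda(x, s, z)
# >>> x_hat = _dequantize_per_tensor_cuda(q, s, z)
# >>> (x - x_hat).abs().max() <= s  # error within about one step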

def _quantize_per_channel_cuda(x, scale, zero_point):
    # Row-wise affine quantization: each channel (row) of ``x`` uses its own
    # scale and zero point.
    y = torch.zeros(x.size(), device=x.device)
    for i in range(x.size()[0]):
        y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


def _dequantize_per_channel_cuda(y, scale, zero_point):
    # Inverse of ``_quantize_per_channel_cuda``, applied row by row.
    y = y.to(torch.float32).cuda(y.device)
    x = torch.zeros_like(y, device=y.device)
    for i in range(x.size()[0]):
        x[i, :] = scale[i] * (y[i, :] - zero_point[i])
    return x
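
# The per-channel analogue (illustrative; assumes a CUDA device). Here
# ``PerChannelMinMaxObserver`` yields one (scale, zero_point) pair per row,
# which is exactly what the per-channel helpers above expect:
#
# >>> # xdoctest: +SKIP
# >>> x = torch.randn(4, 512, device="cuda")
# >>> obs = torch.quantization.PerChannelMinMaxObserver().cuda(x.device)
# >>> obs(x)
# >>> s, z = obs.calculate_qparams()
# >>> q = _quantize_per_channel_cuda(x, s, z)
# >>> x_hat = _dequantize_per_channel_cuda(q, s, z)
# >>> (x - x_hat).abs().max() <= s.max()  # error within about one step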

def _get_allgather_out_list(all_gather_in_list, world_size):
    # Preallocate one output tensor per rank for ``dist.all_gather``.
    out_list = [
        torch.zeros_like(
            all_gather_in_list,
            device=all_gather_in_list.device,
            dtype=all_gather_in_list.dtype,
        )
        for _ in range(world_size)
    ]
    return out_list

def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_tensor`` logic to DDP using the ``allgather``
    protocol. Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to the quantization. After all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``, quantizes
    the worker's own gradient tensor and uses ``allgather`` to communicate these
    across all workers. The final ``then`` callback, ``dequantize_and_aggregate``,
    dequantizes and aggregates each quantized gradient tensor locally and returns
    the mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Scales and zero points from all workers.
        all_ranks_s_and_z = fut.wait()[0]

        # Each worker quantizes its own ``GradBucket`` tensor.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using the previously allgathered scales and zero points, dequantize each
        # rank's gradient tensor locally, then aggregate.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
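
# A hedged end-to-end registration sketch (``MyModel`` is a hypothetical model
# class; assumes NCCL with one GPU per process). Passing ``None`` as the hook
# state falls back to the default process group:
#
# >>> # xdoctest: +SKIP
# >>> dist.init_process_group("nccl", rank=rank, world_size=world_size)
# >>> model = MyModel().cuda(rank)
# >>> ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
# >>> ddp_model.register_comm_hook(None, quantization_pertensor_hook)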

def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_channel`` logic to DDP using the ``allgather``
    protocol. Compared to per-tensor quantization, the main motivation of
    per-channel quantization is resolution: for considerably large tensors, such
    as one containing 6 million elements, quantizing in chunks of 512 (or 128)
    elements may significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points
    of their own ``GradBucket`` prior to the quantization. After all workers have
    that information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate these across all workers. The final ``then`` callback,
    ``dequantize_and_aggregate``, dequantizes, flattens, and aggregates each
    quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Zero-pad the flattened gradient so it splits evenly into channels of
    # ``bucket_size`` elements; the padding is trimmed off after aggregation.
    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, bucket_size - len(tensor) % bucket_size),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device
    )
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Scales and zero points from all workers.
        all_ranks_s_and_z = fut.wait()[0]

        # Each worker quantizes its own ``GradBucket`` channels.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using the previously allgathered scales and zero points, dequantize each
        # rank's gradient tensor locally, then aggregate.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        # Flatten, trim the zero padding added earlier, and average.
        return (
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        )

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
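
# Since DDP passes only ``(state, bucket)`` to a comm hook, a non-default
# ``bucket_size`` can be bound with ``functools.partial`` (a sketch, not the
# only way to do this):
#
# >>> # xdoctest: +SKIP
# >>> from functools import partial
# >>> ddp_model.register_comm_hook(
# ...     None, partial(quantization_perchannel_hook, bucket_size=128)
# ... )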