# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor
from typing import cast, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor

aten = torch.ops.aten


def _requires_data_exchange(padding):
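    """Return True if the width-sharded convolution needs a halo exchange.

    Overlap data only has to be exchanged between neighboring ranks when the
    convolution pads the sharded (last) dimension, i.e. ``padding[1] != 0``.
    """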
    # TODO: whether data exchange is required is currently determined solely by padding
    return padding[1] != 0


def _is_supported(input_size, kernel_size, stride, padding, dilation):
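    """Check that the convolution configuration is supported when width-sharded.

    Along the sharded width dimension: dilation must be 1; if there is
    padding, stride must be 1 and the half kernel width must not exceed the
    local input width; if there is no padding, the local input width must be
    divisible by the stride and the stride must equal the kernel width
    (non-overlapping windows). Raises ``RuntimeError`` otherwise, and returns
    True when the configuration is supported.
    """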
    if dilation[1] != 1:
        raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
    if padding[1] != 0:
        if stride[1] != 1:
            raise RuntimeError(
                "Stride must be 1 when there is padding for tensor parallel convolution."
            )
        if kernel_size[3] // 2 > input_size[3]:
            raise RuntimeError(
                "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
            )
    else:
        if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
            raise RuntimeError(
                "input_size[3] must be divisible by stride[1] and stride[1] must equal kernel_size[3] "
                "when there is no padding for tensor parallel convolution."
            )
    return True


def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
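    """Exchange halo columns with neighbors and rebuild the local input tensor.

    Each rank sends its rightmost ``d1`` columns to the right neighbor and its
    leftmost ``d2`` columns to the left neighbor, then concatenates the halos
    it receives onto its own shard: rank 0 only appends a right halo, the last
    rank only prepends a left halo, and interior ranks do both. All four P2P
    ops are posted on every rank so the batched send/recv calls match; edge
    ranks simply discard the wrap-around halo they receive.
    """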
    # dist comms and reconstruct local input tensor
    send_to_right = in_tensor[:, :, :, -d1:].contiguous()
    send_to_left = in_tensor[:, :, :, :d2].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
    elif rank == size - 1:
        in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
    else:
        in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)

    return in_tensor


def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
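    """Send halo-column gradients back to their owning ranks and accumulate.

    Inverse of ``_ring_send_recv_construct``: the gradient columns computed
    for a neighbor's halo are sent back to that neighbor and added onto its
    edge columns, and the local halo columns are trimmed off. Returns the
    trimmed, accumulated local input gradient.
    """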
    # dist comms and aggregate gradients for edge pixels
    send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
    send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
    elif rank == size - 1:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:]
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )
    else:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )

    # hand the trimmed, accumulated local gradient back to the caller
    return grad_in_tensor


def tp_convolution(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
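    """Compute ``aten.convolution.default`` on a width-sharded local input.

    ``local_tensor_args[0]`` is expected to be the local shard of an input
    that is sharded along its last (width) dimension across all ranks of the
    default process group. When the convolution pads the width dimension,
    neighboring ranks first exchange halo columns so each rank sees the
    overlap pixels it needs; the local convolution is then run and the extra
    output columns introduced by the overlap/padding are sliced off so that
    the per-rank results concatenate into the unsharded result.
    """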
    assert op_call == aten.convolution.default
    assert len(local_tensor_args) == 9

    rank = dist.get_rank()
    size = dist.get_world_size()
    in_tensor = cast(torch.Tensor, local_tensor_args[0])
    weight = cast(torch.Tensor, local_tensor_args[1])
    stride, padding, dilation = local_tensor_args[3:6]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step2 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step3 remove extra outputs from the results
        padding_w = padding[1]
        w = local_results.size(3)
        if rank == 0:
            local_results = local_results[:, :, :, : w - padding_w]
        elif rank == size - 1:
            local_results = local_results[:, :, :, padding_w:]
        else:
            local_results = local_results[:, :, :, padding_w : w - padding_w]

        return local_results


def tp_convolution_backward(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
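    """Compute ``aten.convolution_backward.default`` for a width-sharded conv.

    Mirrors ``tp_convolution``: the local input shard is extended with halo
    columns from neighboring ranks and the local grad_output is zero-padded to
    line up with the padded forward output; the local backward op is then run,
    and gradient contributions computed for halo columns are shipped back to
    their owning ranks and accumulated by ``_ring_send_recv_aggregate``.
    """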
    assert op_call == aten.convolution_backward.default
    assert len(local_tensor_args) == 11

    rank = dist.get_rank()
    size = dist.get_world_size()
    grad_out_tensor = cast(torch.Tensor, local_tensor_args[0])
    in_tensor = cast(torch.Tensor, local_tensor_args[1])
    weight = cast(torch.Tensor, local_tensor_args[2])
    stride, padding, dilation = local_tensor_args[4:7]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step2 reconstruct local gradient output tensor
        padding_w = padding[1]
        if rank == 0:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (0, padding_w), "constant", 0
            )
        elif rank == size - 1:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, 0), "constant", 0
            )
        else:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, padding_w), "constant", 0
            )

        # step3 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = grad_out_tensor
        local_tensor_args_list[1] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step4 aggregate gradients for edge pixels
        grad_in_tensor = local_results[0]
        grad_in_tensor = _ring_send_recv_aggregate(
            grad_in_tensor, d1, d2, left, right, rank, size
        )

        local_results = list(local_results)
        local_results[0] = grad_in_tensor
        local_results = cast(Tuple[object, ...], local_results)

        return local_results


def convolution_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
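    """DTensor dispatch handler for ``aten.convolution.default``.

    Unwraps the DTensor arguments into local tensors plus sharding metadata,
    runs sharding propagation to determine the output placements, computes the
    local result via ``tp_convolution``, and wraps it back into a DTensor.
    """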
    # extract local tensors and sharding info into an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local propagation
    local_results = tp_convolution(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )


def convolution_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
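    """DTensor dispatch handler for ``aten.convolution_backward.default``.

    Redistributes grad_output to the same placement as the input tensor, then
    unwraps to local tensors, runs sharding propagation, computes local
    results via ``tp_convolution_backward``, and wraps them back into DTensors.
    """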
    # redistribute grad_output tensor to the same placement as the input tensor
    args = list(args)
    assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
    args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
    args = tuple(args)

    # extract local tensors and sharding info into an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local propagation
    local_results = tp_convolution_backward(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )
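

# Illustrative usage sketch (assumptions, not part of this module): the two
# handlers above are meant to be invoked by DTensor's op dispatcher for
# aten.convolution.default / aten.convolution_backward.default when the conv
# input is a DTensor sharded along its last (width) dimension. Roughly, and
# assuming a 1-D device mesh named `mesh` spanning all ranks:
#
#   from torch.distributed.tensor import distribute_tensor, Replicate, Shard
#
#   x = distribute_tensor(torch.randn(8, 3, 32, 32), mesh, [Shard(3)])
#   w = distribute_tensor(torch.randn(16, 3, 3, 3), mesh, [Replicate()])
#   out = torch.nn.functional.conv2d(x, w, padding=1)  # routed to convolution_handler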