# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
from torch.distributed.tensor._ops.utils import register_prop_rule


aten = torch.ops.aten


@register_prop_rule(aten.convolution.default)
def convolution_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for aten.convolution: the output keeps the input's
    placements; shape/stride metadata is recomputed (assumes a 4-D NCHW
    input, i.e. conv2d)."""
(
input_spec,
weight_spec,
bias_spec,
stride,
padding,
dilation,
_transposed,
_output_padding,
_groups,
) = op_schema.args_schema
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
assert isinstance(bias_spec, DTensorSpec)
assert input_spec.tensor_meta is not None
assert weight_spec.tensor_meta is not None
in_shape = input_spec.tensor_meta.shape
weight_shape = weight_spec.tensor_meta.shape
    # stride/padding/dilation arrive as plain Python lists of ints.
    assert isinstance(stride, list)
    assert isinstance(padding, list)
    assert isinstance(dilation, list)
    assert isinstance(weight_shape, torch.Size)

    # Standard conv2d output-size formula:
    #   out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3]
    C_out = weight_shape[0]
    H_out = (
        H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1
    ) // stride[0] + 1
    W_out = (
        W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1
    ) // stride[1] + 1
    output_shape = [N, C_out, H_out, W_out]
    # Contiguous NCHW strides for the output's tensor metadata.
    output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1)
    # The output is sharded exactly like the input: reuse its dim_map and any
    # pending partial-sum reductions unchanged.
    output_dim_map = input_spec.dim_map
    pending_sums = input_spec.sums
tensor_meta = TensorMeta(
torch.Size(output_shape),
output_stride,
input_spec.tensor_meta.dtype,
)
return OutputSharding(
DTensorSpec.from_dim_map(
input_spec.mesh,
output_dim_map,
pending_sums,
tensor_meta=tensor_meta,
)
)
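

# A hedged, illustrative sanity check of the output-size formula used above.
# The helper name `_check_conv_output_shape` and the concrete shapes are
# assumptions made for demonstration, not part of any DTensor API; it simply
# confirms the closed form against a local F.conv2d call.
def _check_conv_output_shape() -> None:
    import torch.nn.functional as F

    x = torch.randn(2, 3, 16, 20)  # (N, C_in, H_in, W_in)
    w = torch.randn(4, 3, 3, 5)  # (C_out, C_in, kernel_h, kernel_w)
    stride, padding, dilation = [2, 1], [1, 2], [1, 1]
    out = F.conv2d(x, w, stride=stride, padding=padding, dilation=dilation)
    # Same closed-form computation as in convolution_rules above.
    h_out = (16 + 2 * padding[0] - dilation[0] * (w.shape[2] - 1) - 1) // stride[0] + 1
    w_out = (20 + 2 * padding[1] - dilation[1] * (w.shape[3] - 1) - 1) // stride[1] + 1
    assert out.shape == (2, 4, h_out, w_out)  # matches [N, C_out, H_out, W_out]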


@register_prop_rule(aten.convolution_backward.default)
def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for aten.convolution_backward: grad_input keeps the
    input's placements; grad_weight and grad_bias come out replicated with a
    pending partial sum on mesh dim 0."""
(
grad_output_spec,
input_spec,
weight_spec,
bias_shape_opt,
_stride,
_padding,
_dilation,
_transposed,
_output_padding,
_groups,
_output_mask,
) = op_schema.args_schema
assert isinstance(grad_output_spec, DTensorSpec)
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
    assert isinstance(bias_shape_opt, list)
assert input_spec.tensor_meta is not None
    weight_tensor_meta = weight_spec.tensor_meta
    # grad_bias is 1-D with shape (C_out,), so its only stride is 1.
    bias_tensor_meta = TensorMeta(
        torch.Size(bias_shape_opt),
        (1,),
        input_spec.tensor_meta.dtype,
    )

    # grad_input keeps the same sharding as the forward input.
    grad_input_spec = input_spec
    # grad_weight and grad_bias are computed from local shards, so both are
    # replicated across tensor dims (dim_map of -1s) with a pending partial
    # sum (allreduce) on mesh dim 0. Note that _output_mask is not consulted
    # here: all three gradients receive specs unconditionally.
grad_weight_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1, -1, -1, -1],
[0],
tensor_meta=weight_tensor_meta,
)
grad_bias_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1],
[0],
tensor_meta=bias_tensor_meta,
)
return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
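

# Hedged sketch of what the backward specs decode to. `_demo_backward_specs`,
# the single-rank gloo setup, and the CPU mesh are illustrative assumptions;
# the point is that a dim_map of all -1 plus sums=[0] means "replicated over
# every tensor dim, with a partial-sum reduction pending on mesh dim 0".
def _demo_backward_specs() -> None:
    import os

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))
    spec = DTensorSpec.from_dim_map(mesh, [-1, -1, -1, -1], [0])
    # On a 1-D mesh this yields a single Partial (sum) placement: each rank's
    # local values are partial results that still need an all-reduce.
    print(spec.placements)
    dist.destroy_process_group()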