1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
# Number of SGD steps each convolution test runs on both the DTensor and the
# plain-tensor model.
ITER_TIME = 10
# Learning rate of the SGD optimizers used for both models.
LR = 1e-3
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """Partition function for ``distribute_module``: replicate parameters.

    ``distribute_module`` invokes its partition function for every submodule,
    so only the parameters owned directly by ``module`` are converted here;
    child modules are handled by their own invocation. Each such parameter is
    replaced by an ``nn.Parameter`` wrapping a DTensor with a ``Replicate()``
    placement on ``device_mesh``, registered under the same name.

    Args:
        name: Qualified name of ``module`` within the distributed root (unused).
        module: Module whose directly-owned parameters are replaced in place.
        device_mesh: Mesh on which the replicated parameters are placed.
    """
    replicate_spec = [Replicate()]
    # Snapshot first: register_parameter writes into the very ``_parameters``
    # dict that a live named_parameters() iterator would be walking.
    # recurse=False keeps children from being re-registered on this module
    # under flattened names.
    for param_name, param in list(module.named_parameters(recurse=False)):
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, replicate_spec)
        )
        module.register_parameter(param_name, dist_param)
class DistConvolutionOpsTest(DTensorTestBase):
    """Check DTensor convolution training against plain-tensor training.

    Each test trains an ``nn.Conv2d`` for ``ITER_TIME`` SGD steps twice on the
    same data. The first run uses DTensor parameters (placed by ``_conv_fn``)
    with inputs and upstream gradients sharded along dim 3. The second run
    uses an ordinary module. The final weights and biases must agree within a
    small absolute and relative mean-squared error.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _train_and_compare(
        self,
        model: nn.Module,
        input_shape,
        grad_output_shape,
    ) -> None:
        """Train ``model`` with DTensor and with plain tensors, then compare.

        Args:
            model: Freshly constructed conv module already on
                ``self.device_type``. Its weight is reset to ones and its bias
                to zeros. It is then converted in place by
                ``distribute_module``; a deep copy taken beforehand serves as
                the plain-tensor reference.
            input_shape: Shape of one global (unsharded) input batch.
            grad_output_shape: Shape of one global upstream gradient. It must
                equal ``model``'s output shape for an ``input_shape`` input.
        """
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]

        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        # Copy before distribute_module swaps the parameters for DTensors.
        model_gt = copy.deepcopy(model).to(self.device_type)
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)

        # Seed explicitly so every rank draws identical global tensors.
        # Draw one batch per step instead of pre-allocating all ITER_TIME
        # batches up front, which costs several GB per rank for the shapes
        # used in this class. Both models are stepped on the same batch before
        # the next batch is drawn.
        torch.manual_seed(0)
        for _ in range(ITER_TIME):
            inp = torch.rand(input_shape).to(self.device_type)
            grad_output = (torch.rand(grad_output_shape) * 1e-3).to(
                self.device_type
            )

            # training with dtensor
            optimizer.zero_grad()
            inp_dtensor = distribute_tensor(
                inp.clone().requires_grad_(), device_mesh, shard_spec
            )
            output = model(inp_dtensor)
            output.backward(
                distribute_tensor(grad_output, device_mesh, shard_spec)
            )
            optimizer.step()

            # training with plain tensor
            optimizer_gt.zero_grad()
            output_gt = model_gt(inp.clone().requires_grad_())
            output_gt.backward(grad_output)
            optimizer_gt.step()

        # Metric computation only; no autograd graph is needed.
        with torch.no_grad():
            results = []
            for param_name in ("weight", "bias"):
                ref = getattr(model_gt, param_name)
                diff_abs = getattr(model, param_name).to_local() - ref
                # 1e-8 keeps the division finite where a reference entry is ~0.
                diff_rel = diff_abs / (torch.abs(ref) + 1e-8)
                results.append(
                    ("absolute", param_name, torch.mean(diff_abs * diff_abs).item())
                )
                results.append(
                    ("relative", param_name, torch.mean(diff_rel * diff_rel).item())
                )

        for kind, param_name, mse in results:
            self.assertLessEqual(
                mse,
                1e-6,
                f"Too large {kind} mse for {param_name} tensor, "
                f"expected less equal 1e-6, got {mse}",
            )

    @with_comms
    def test_downsampling_convolution(self):
        # 4x4 kernel with stride 4: (7, 3, 512, 1024) -> (7, 256, 128, 256).
        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        self._train_and_compare(
            model,
            input_shape=(7, 3, 512, 1024),
            grad_output_shape=(7, 256, 128, 256),
        )

    # TODO: test_depthwise_convolution is broken in CI with the gloo backend.
    # It is temporarily gated on >= 2 GPUs (so it is skipped on CPU-only
    # gloo runs) to unblock CI.
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_depthwise_convolution(self):
        # 7x7 depthwise kernel with padding 3 keeps the spatial size unchanged.
        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        self._train_and_compare(
            model,
            input_shape=(7, 256, 128, 256),
            grad_output_shape=(7, 256, 128, 256),
        )
if __name__ == "__main__":
    # Run this file's tests through PyTorch's shared test entry point when it
    # is executed directly as a script.
    run_tests()
|