1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
# Number of random samples generated per test; each test loops over all of them.
ITER_TIME = 10
# NOTE(review): LR is not referenced anywhere in the visible part of this file —
# confirm whether it is still needed before relying on or removing it.
LR = 0.001
class DistOtherOpsTest(DTensorTestBase):
    """Exercise slicing, bernoulli sampling and cross-entropy on replicated DTensors.

    Every test builds a 1-D ``DeviceMesh`` over ``range(self.world_size)``,
    distributes its inputs with a single ``Replicate()`` placement and repeats
    the check ``ITER_TIME`` times on freshly generated random data.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _assert_mse_within_tolerance(self, actual, reference, label):
        """Assert that ``actual`` matches ``reference`` in absolute and relative MSE.

        Both mean-squared errors must be at most ``1e-6``.

        Args:
            actual: Tensor produced by the code under test.
            reference: Tensor to compare against; its magnitude is also the
                denominator of the relative error.
            label: Name of the compared quantity, inserted into the failure
                messages (e.g. ``"output"``, ``"gradient"``, ``"loss"``).
        """
        diff_abs = actual - reference
        # The small epsilon keeps the division finite where the reference is zero.
        diff_rel = diff_abs / (torch.abs(reference) + 1e-8)
        mse_abs = torch.mean(diff_abs * diff_abs).item()
        mse_rel = torch.mean(diff_rel * diff_rel).item()
        self.assertLessEqual(
            mse_abs,
            1e-6,
            f"Too large absolute mse for {label}, expected less equal 1e-6, got {mse_abs}",
        )
        self.assertLessEqual(
            mse_rel,
            1e-6,
            f"Too large relative mse for {label}, expected less equal 1e-6, got {mse_rel}",
        )

    @with_comms
    def test_slice(self):
        """Check ``[:, :5]`` slicing of a replicated DTensor against a plain tensor.

        Both the sliced output and the gradient accumulated on the input are
        compared with the plain-tensor result.
        """
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 5) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # slice with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = inp_dtensor[:, :5]
            output.backward(grad_output_dtensor)

            # slice with plain tensor
            output_gt = inp[:, :5]
            output_gt.backward(grad_output)

            self._assert_mse_within_tolerance(output.to_local(), output_gt, "output")
            self._assert_mse_within_tolerance(
                inp_dtensor.grad.to_local(), inp.grad, "gradient"
            )

    @with_comms
    def test_bernoulli(self):
        """Check that ``torch.bernoulli`` on a replicated DTensor agrees across ranks.

        Each rank samples and back-propagates locally, then swaps its local
        output and input gradient with the peer rank and compares the two.
        """
        rank = dist.get_rank()
        # world_size is fixed to 2, so flipping the lowest bit maps 0 <-> 1.
        peer = 1 ^ rank
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 10) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # bernoulli with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = torch.bernoulli(inp_dtensor)
            output.backward(grad_output_dtensor)

            send_output_tensor = output.to_local()
            recv_output_tensor = torch.zeros_like(send_output_tensor)

            send_grad_tensor = inp_dtensor.grad.to_local()
            recv_grad_tensor = torch.zeros_like(send_grad_tensor)

            # Exchange output and gradient with the peer in one batched call;
            # the op order (both sends, then both receives) is kept as-is.
            p2p_ops = [
                dist.P2POp(dist.isend, send_output_tensor, peer),
                dist.P2POp(dist.isend, send_grad_tensor, peer),
                dist.P2POp(dist.irecv, recv_output_tensor, peer),
                dist.P2POp(dist.irecv, recv_grad_tensor, peer),
            ]
            for req in dist.batch_isend_irecv(p2p_ops):
                req.wait()

            self._assert_mse_within_tolerance(
                send_output_tensor, recv_output_tensor, "output"
            )
            self._assert_mse_within_tolerance(
                send_grad_tensor, recv_grad_tensor, "gradient"
            )

    @with_comms
    def test_nll(self):
        """Check ``CrossEntropyLoss`` on replicated DTensors against plain tensors.

        Both the loss value and the gradient accumulated on the predictions
        are compared with the plain-tensor result.
        """
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        pred_list = torch.rand(ITER_TIME, 1024, 10)
        target_list = torch.randint(0, 10, (ITER_TIME, 1024), dtype=torch.long)

        criterion = torch.nn.CrossEntropyLoss()

        for i in range(ITER_TIME):
            pred = pred_list[i].to(self.device_type).requires_grad_()
            target = target_list[i].to(self.device_type)

            # nll with dtensor
            pred_dtensor = distribute_tensor(pred, device_mesh, shard_spec)
            target_dtensor = distribute_tensor(target, device_mesh, shard_spec)
            loss = criterion(pred_dtensor, target_dtensor)
            loss.backward()

            # nll with plain tensor
            loss_gt = criterion(pred, target)
            loss_gt.backward()

            self._assert_mse_within_tolerance(loss.to_local(), loss_gt, "loss")
            self._assert_mse_within_tolerance(
                pred_dtensor.grad.to_local(), pred.grad, "gradient"
            )
if __name__ == "__main__":
    # Executed directly as a script: hand control to the test runner imported
    # from torch.testing._internal.common_utils.
    run_tests()
|