# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy

from model_registry import MLPModule

import torch
from torch.distributed.pipelining._backward import (
    stage_backward,
    stage_backward_input,
    stage_backward_weight,
)
from torch.testing._internal.common_utils import run_tests, TestCase
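
# These tests exercise the manual-backward helpers used by pipeline stages:
# stage_backward mirrors a plain autograd backward for a single stage, while
# stage_backward_input / stage_backward_weight split the backward pass into an
# input-gradient phase and a deferred weight-gradient phase (e.g., for
# zero-bubble-style schedules).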

d_hid = 512
batch_size = 256


class StageBackwardTests(TestCase):
    def test_stage_backward(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward and backward in a stage-wise manner
        out = mod(x)
        loss = loss_fn(out, target)
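        # stage_backward runs autograd for this stage. The stage output here
        # is the loss itself, so output_grads is None; the returned tuple
        # holds the gradients of input_values, and parameter .grad fields are
        # accumulated in place, just as with a plain loss.backward().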
        grad_inputs = stage_backward(
            stage_output=loss,
            output_grads=None,
            input_values=(x,),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(grad_inputs[0], ref_x.grad)

        # Check parameter gradients against the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_input(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
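        # stage_backward_input computes only the gradients of the stage
        # inputs and defers the weight gradients: it returns dinputs plus
        # param_groups, which a later stage_backward_weight call uses to
        # finish the backward for the given weights. Parameter .grad fields
        # should remain unset after this call.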
        dinputs, param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(x.grad, ref_x.grad)
        torch.testing.assert_close(dinputs[0], ref_x.grad)

        for name, p in mod.named_parameters():
            # Check that the weight gradients were not updated
            self.assertEqual(p.grad, None)

    def test_stage_backward_weight(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
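        # Two-phase backward: stage_backward_input computes the input
        # gradients and collects param_groups; stage_backward_weight then
        # uses those groups to compute the gradients of mod's parameters,
        # which should now match a plain loss.backward() on the reference.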
        dinputs, param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )
        # backward of loss with respect to weights
        stage_backward_weight(mod.parameters(), param_groups, retain_graph=True)

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        # Check parameter gradients against the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_weight_multiple_iters(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        inputs = []
        for _ in range(10):
            x = torch.randn(batch_size, d_hid)
            inputs.append(x)
            # As in a pipeline stage, the inputs to this stage require gradients
            x.requires_grad_(True)

        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_inputs = []
        for x in inputs:
            ref_inputs.append(x.detach().requires_grad_(x.requires_grad))
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        for x in inputs:
            out = mod(x)
            loss = loss_fn(out, target)
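            # Each iteration performs the split backward; weight gradients are
            # expected to accumulate in .grad across the 10 iterations, just
            # as 10 plain backward() calls accumulate on the reference model.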
            dinputs, param_groups = stage_backward_input(
                stage_outputs_or_loss=(loss,),
                output_grads=None,
                input_values=[x],
                weights=mod.parameters(),
            )
            # backward of loss with respect to weights
            stage_backward_weight(mod.parameters(), param_groups)

        # Run reference
        for ref_x in ref_inputs:
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, ref_target)
            ref_loss.backward()

        # Check parameter gradients against the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise


if __name__ == "__main__":
    run_tests()