1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs
import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
merge_chunks,
split_args_kwargs_into_chunks,
TensorChunkSpec,
)
from torch.testing._internal.common_utils import run_tests, TestCase
# Seed torch's global RNG so the torch.randn inputs created by the tests below
# are reproducible from run to run.
torch.manual_seed(0)

# Size of the second dimension of every tensor the tests create.
d_hid = 512
class MicrobatchTests(TestCase):
    """Tests for microbatch splitting/merging in ``torch.distributed.pipelining``."""

    def test_split_and_merge(self):
        """Split args/kwargs into 2 microbatches along dim 0, check the shape of
        every chunk, then merge the chunks back and compare with the inputs."""
        num_chunks = 2
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)

        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(
            args, kwargs, num_chunks
        )
        # Use TestCase assertions instead of bare `assert`: bare asserts are
        # removed under `python -O`, which would make these checks no-ops.
        self.assertEqual(len(arg_chunks), num_chunks)
        self.assertEqual(len(kwarg_chunks), num_chunks)

        # With 2 chunks each input is halved along dim 0. Check every chunk of
        # every positional and keyword input, not just a hand-picked sample.
        expected_rows = {"x0": 64, "x1": 128, "x2": 256}
        for chunk_idx in range(num_chunks):
            for arg_idx, rows in enumerate(expected_rows.values()):
                self.assertEqual(
                    arg_chunks[chunk_idx][arg_idx].shape,
                    torch.Size([rows, d_hid]),
                )
            for name, rows in expected_rows.items():
                self.assertEqual(
                    kwarg_chunks[chunk_idx][name].shape,
                    torch.Size([rows, d_hid]),
                )

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        """Build a pipeline from ModelWithKwargs using one explicitly chunked
        microbatch as example input, and check that running it on the full
        inputs matches the eager module."""
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        # Explicit chunk specs: positional arg 0 and kwarg "y" split on dim 0.
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        # Only the first microbatch is passed as example input to pipeline().
        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )

        ref = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
if __name__ == "__main__":
    # Executed directly (not imported): hand off to PyTorch's test runner.
    run_tests()
|