# Owner(s): ["oncall: distributed"]
import sys
import copy
import torch
import torch.nn as nn
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of EmbeddingBag modules
class CustomEmbeddingBagCollection(nn.Module):
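    """Collects ``num_bags`` sum-mode EmbeddingBag modules and concatenates their outputs."""
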
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag,
                num_dims,
                mode="sum",
            )

    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of EBC
class CustomShardedEBC(nn.Module):
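    """Shards the first ``split_idx`` embedding bags row-wise and the rest column-wise."""
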
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the spec
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            if i < self.split_idx:
                shard_module(
                    ebc,
                    plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": row_spec}),
                )
            else:
                shard_module(
                    ebc,
                    plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": col_spec}),
                )

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]
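
    # A minimal forward pass, mirroring CustomEmbeddingBagCollection.forward,
    # so the sharded module can be called when the test compares outputs below.
    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)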


class CustomSharder(Sharder):
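    """Shards a CustomEmbeddingBagCollection using row-wise and column-wise ChunkShardingSpecs."""
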
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError(
                "The custom sharder only supports CustomEmbeddingBagCollection"
            )

        return CustomShardedEBC(
            ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec)
        )


class TestCustomSharder(ShardedTensorTestBase):
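    """Tests sharding a module through a custom Sharder supplied in a ShardingPlan."""
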
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )
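
        # A ShardingPlan entry can map a submodule path ("ebc") to a Sharder;
        # shard_module() then delegates that whole submodule to CustomSharder.shard().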
        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            }
        )

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(
            emb_bags["embedding_bag_0"].weight.sharding_spec(),
            custom_sharder.rowwise_spec,
        )
        self.assertEqual(
            emb_bags["embedding_bag_9"].weight.sharding_spec(),
            custom_sharder.colwise_spec,
        )

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            }
        )

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)
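
        # the plan above assigns the custom sharder to the empty (root) path,
        # which shard_module() should reject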
        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            }
        )
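
        # "embedding_bags" is claimed by the custom sharder while one of its own
        # parameters also gets a separate spec, so the two plan entries conflict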
        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)