# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)


class TestShardedTensorNNInit(ShardedTensorTestBase):
    """ Testing torch.nn.init functions for ShardedTensor """
    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """ Test torch.nn.init.uniform_(ShardedTensor, a, b) """
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
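        # h=8 is chunked along dim 0 across 4 ranks, so each rank holds a
        # 2 x 2 local shard (expected_h) on its own GPU (expected_device).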
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(st.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(st, a=a, b=b)

        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_normal(self):
        """ Test torch.nn.init.normal_(ShardedTensor, mean, std) """
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        mean, std = 10, 5

        seed = 1234
        dtype = torch.double

        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(st.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.normal_(st, mean=mean, std=std)

        torch.manual_seed(seed)
        torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_kaiming_uniform(self):
        """ Test torch.nn.init.kaiming_uniform_(ShardedTensor, a, mode, nonlinearity) """
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, mode, nonlinearity = 0, 'fan_in', 'leaky_relu'

        seed = 1234
        dtype = torch.double

        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(st.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.kaiming_uniform_(st, a=a, mode=mode, nonlinearity=nonlinearity)

        torch.manual_seed(seed)
        torch.nn.init.kaiming_uniform_(local_tensor_clone, a=a, mode=mode, nonlinearity=nonlinearity)
        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)


if __name__ == '__main__':
    run_tests()