File: test_replicated_tensor.py

package info
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (336 lines) | stat: -rw-r--r-- | 12,754 bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# Owner(s): ["oncall: distributed"]
import io

import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    gen_binary_op_func
)
from torch.testing._internal.distributed._shard.sharded_tensor import TEST_GPU_NUM


class TestReplicatedTensor(ShardedTensorTestBase):
    """Multi-process NCCL tests for ``ReplicatedTensor``.

    Each test runs once per rank with one GPU per rank (``cuda:{rank}``).
    The sharding spec and the rank-specific branches below name ranks 0-3
    explicitly; NOTE(review): this presumably matches ``TEST_GPU_NUM`` —
    confirm before changing either.
    """

    # Binary operators (as understood by ``gen_binary_op_func``) exercised
    # between ShardedTensor and ReplicatedTensor operands.
    _BINARY_OPS = ("torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/")

    @staticmethod
    def _chunk_sharding_spec():
        """Return a ChunkShardingSpec that splits dim 0 across ranks 0-3.

        Shared by every test that builds a ShardedTensor, so the placement
        list is defined in one place.
        """
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_basics(self):
        """validate() detects divergence; ops with scalars return plain Tensors."""
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)
        # Every rank built the same values, so cross-rank validation passes.
        validated = replica_tensor.validate()
        self.assertEqual(validated, True)

        # ReplicatedTensor op Python scalar -> plain torch.Tensor.
        res = replica_tensor + 2
        self.assertIsInstance(res, torch.Tensor)
        self.assertNotIsInstance(res, ReplicatedTensor)
        self.assertEqual(res, torch.ones(3, 3) * 6)

        # Mutate the local tensor in place on a single rank. The test relies
        # on replica_tensor observing this change (i.e. sharing storage with
        # local_tensor), so the replicas now diverge and validate() must raise.
        if self.rank == 2:
            local_tensor += 3

        with self.assertRaisesRegex(ValueError, 'have different values'):
            replica_tensor.validate()

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_inter_op_replicated_tensor(self):
        """ReplicatedTensor op ReplicatedTensor stays replicated; mixed pgs fail."""
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
        replica_tensor2 = ReplicatedTensor(local_tensor * 6)

        new_tensor = replica_tensor1 * replica_tensor2
        self.assertIsInstance(new_tensor, ReplicatedTensor)
        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)

        # Operands bound to different process groups must be rejected.
        # new_group() is called on every rank, including rank 0 which is not
        # a member of the new group.
        new_pg = dist.new_group(ranks=[1, 2, 3])
        replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)

        with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
            replica_tensor_new_group * replica_tensor1

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_inter_op_tensor(self):
        """ReplicatedTensor op plain Tensor yields a plain (non-replicated) Tensor."""
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        # Unseeded random values: the result is only compared locally.
        local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")

        new_tensor = replica_tensor + local_rand_tensor
        self.assertIsInstance(new_tensor, torch.Tensor)
        self.assertNotIsInstance(new_tensor, ReplicatedTensor)

        self.assertEqual(new_tensor, local_tensor + local_rand_tensor)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_inter_op_sharded_tensor(self):
        """ShardedTensor op ReplicatedTensor (both operand orders) is sharded."""
        # Seed per rank. Only rank 0's local_tensor1 is scattered
        # (src_rank=0) and only rank 0 checks the gathered result, so the
        # different values on other ranks do not matter.
        torch.manual_seed(self.rank)

        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        local_tensor2 = torch.ones(12, 3, device=f"cuda:{self.rank}") * 4

        st = _shard_tensor(local_tensor1, self._chunk_sharding_spec(), src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        for op in self._BINARY_OPS:
            binary_op = gen_binary_op_func(op)

            # ShardedTensor on the left.
            res = binary_op(st, replica_tensor)
            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(res, ReplicatedTensor)
            # gather() needs a destination buffer only on the dst rank.
            output = torch.empty((12, 3), device=self.rank) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)

            # Reflected: ReplicatedTensor on the left.
            reflect_res = binary_op(replica_tensor, st)
            self.assertIsInstance(reflect_res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(reflect_res, ReplicatedTensor)
            reflect_output = torch.empty((12, 3), device=self.rank) if self.rank == 0 else None
            reflect_res.gather(dst=0, out=reflect_output)

            if self.rank == 0:
                reflect_local_output = binary_op(local_tensor2, local_tensor1)
                self.assertEqual(reflect_output, reflect_local_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_implicit_broadcasting(self):
        """A 1-D ReplicatedTensor is broadcast against each local shard."""
        # Seed per rank (NOT the same seed on every rank). This is harmless
        # because only rank 0's local_tensor1 is scattered and only rank 0
        # compares the gathered output.
        torch.manual_seed(self.rank)

        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        # Size (3,) forces the implicit-broadcasting path; the ops below
        # would fail if the replicated operand were not broadcast.
        local_tensor2 = torch.ones(3, device=f"cuda:{self.rank}")

        st = _shard_tensor(local_tensor1, self._chunk_sharding_spec(), src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        for op in self._BINARY_OPS:
            binary_op = gen_binary_op_func(op)
            # The replicated tensor should be broadcast automatically.
            res = binary_op(st, replica_tensor)

            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            output = torch.empty((12, 3), device=self.rank) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
        """Incompatible shapes and unsupported operators raise RuntimeError."""
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        torch.manual_seed(self.rank)
        spec = self._chunk_sharding_spec()

        # Leading dims 20 vs 30 cannot be broadcast together.
        st1 = sharded_tensor.rand(spec, (20, 3, 3))
        st2 = sharded_tensor.rand(spec, (30, 3, 3))

        with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting'):
            st1 + st2

        # `%` is not among the operators ShardedTensor supports here.
        with self.assertRaisesRegex(RuntimeError, 'not supported for ShardedTensor'):
            st1 % replica_tensor

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_with_ddp(self):
        """DDP accepts ReplicatedTensor input and trains, zeroes and saves normally."""
        replica_tensor = ReplicatedTensor(torch.rand(4, 8, device=self.rank))
        model = torch.nn.Linear(8, 2).cuda(self.rank)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        ddp = DDP(model)

        # The DDP wrapper and the raw module expose the same two parameters
        # (Linear: weight, then bias).
        params = list(ddp.parameters())
        self.assertEqual(2, len(params))
        self.assertEqual(ddp.module.weight, params[0])
        self.assertEqual(ddp.module.bias, params[1])

        params = list(model.parameters())
        self.assertEqual(2, len(params))
        self.assertEqual(model.weight, params[0])
        self.assertEqual(model.bias, params[1])

        # A ReplicatedTensor input produces a ReplicatedTensor output.
        out = ddp(replica_tensor)
        self.assertIsInstance(out, ReplicatedTensor)

        # Validate backward: gradients reach the parameters and agree across
        # the module, the DDP wrapper, and the optimizer's param groups.
        out.sum().backward()
        self.assertIsNotNone(model.weight.grad)
        self.assertIsNotNone(model.bias.grad)
        self.assertIsNotNone(ddp.module.weight.grad)
        self.assertIsNotNone(ddp.module.bias.grad)

        original_params = []
        for param_group in optim.param_groups:
            for original_param in param_group['params']:
                self.assertIsNotNone(original_param.grad)
                original_params.append(original_param)

        self.assertEqual(model.weight.grad, original_params[0].grad)
        self.assertEqual(model.bias.grad, original_params[1].grad)
        self.assertEqual(model.weight.grad, ddp.module.weight.grad)
        self.assertEqual(model.bias.grad, ddp.module.bias.grad)

        # Validate optimizer step updates the shared parameters consistently.
        optim.step()
        self.assertEqual(model.weight, ddp.module.weight)
        self.assertEqual(model.weight, original_params[0])

        self.assertEqual(model.bias, ddp.module.bias)
        self.assertEqual(model.bias, original_params[1])

        # Validate zero_grad producing zero-filled gradients. set_to_none is
        # passed explicitly because these assertions need tensors, not None,
        # and the optimizer's default for set_to_none differs across PyTorch
        # releases.
        optim.zero_grad(set_to_none=False)
        self.assertEqual(model.weight.grad, torch.zeros_like(model.weight.grad))
        self.assertEqual(model.weight.grad, ddp.module.weight.grad)
        self.assertEqual(model.weight.grad, original_params[0].grad)

        self.assertEqual(model.bias.grad, torch.zeros_like(model.bias.grad))
        self.assertEqual(model.bias.grad, ddp.module.bias.grad)
        self.assertEqual(model.bias.grad, original_params[1].grad)

        # Validate zero_grad(set_to_none=True) clearing gradients to None.
        optim.zero_grad(set_to_none=True)
        self.assertIsNone(model.weight.grad)
        self.assertEqual(model.weight.grad, ddp.module.weight.grad)
        self.assertEqual(model.weight.grad, original_params[0].grad)

        self.assertIsNone(model.bias.grad)
        self.assertEqual(model.bias.grad, ddp.module.bias.grad)
        self.assertEqual(model.bias.grad, original_params[1].grad)

        # Multiple forward passes keep returning ReplicatedTensor.
        for _ in range(5):
            out = ddp(replica_tensor)
            self.assertIsInstance(out, ReplicatedTensor)

        # Context manager: _ddp_replicated_tensor presumably toggles DDP's
        # ReplicatedTensor handling. An inner True inside an outer False
        # must still yield a ReplicatedTensor output.
        from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor
        with _ddp_replicated_tensor(False):
            for _ in range(5):
                with _ddp_replicated_tensor(True):
                    ddp = DDP(model)
                    out = ddp(replica_tensor)
                self.assertIsInstance(out, ReplicatedTensor)

        # Save/load round-trip of the whole DDP wrapper with the flag off and
        # on. Both loaded copies must match the state_dict captured with the
        # flag off (expected_state_dict is reused by the second block).
        with _ddp_replicated_tensor(False):
            ddp = DDP(model)
            expected_state_dict = ddp.state_dict()
            buffer = io.BytesIO()
            torch.save(ddp, buffer)

            buffer.seek(0)
            obj = torch.load(buffer)
            self.assertEqual(expected_state_dict, obj.state_dict())

        with _ddp_replicated_tensor(True):
            ddp = DDP(model)
            buffer = io.BytesIO()
            torch.save(ddp, buffer)

            buffer.seek(0)
            obj = torch.load(buffer)
            self.assertEqual(expected_state_dict, obj.state_dict())

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_unsqueeze(self):
        """unsqueeze (method and torch function) preserves ReplicatedTensor."""
        local_tensor = torch.rand(3, 3, device=self.rank)
        replicated_tensor = ReplicatedTensor(local_tensor)

        unsqueezed_replicated_tensor = replicated_tensor.unsqueeze(0)
        unsqueezed_local_tensor = local_tensor.unsqueeze(0)

        self.assertIsInstance(unsqueezed_replicated_tensor, ReplicatedTensor)
        self.assertIsInstance(torch.unsqueeze(replicated_tensor, 0), ReplicatedTensor)
        self.assertEqual(unsqueezed_local_tensor, unsqueezed_replicated_tensor)
        self.assertEqual(torch.unsqueeze(replicated_tensor, 0), unsqueezed_replicated_tensor)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_getitem(self):
        """Indexing a ReplicatedTensor returns a ReplicatedTensor view."""
        local_tensor = torch.rand(3, 3, device=self.rank)
        replicated_tensor = ReplicatedTensor(local_tensor)

        replicated_tensor_view = replicated_tensor[0]
        local_tensor_view = local_tensor[0]

        self.assertIsInstance(replicated_tensor_view, ReplicatedTensor)
        self.assertEqual(local_tensor_view, replicated_tensor_view)