File: test_tp_random_state.py

# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)


class TensorParallelRandomStateTests(DTensorTestBase):
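    # Return the idx-th of n equal row-blocks of `large_tensor`. After an
    # all-gather along dim 0, block idx is the shard contributed by rank idx
    # of the gathered group.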
    def get_tensor_slice(self, idx, n, large_tensor):
        shape = large_tensor.shape
        assert shape[0] % n == 0
        local_shape = [shape[0] // n, shape[1]]

        slice_idx = (
            slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
            slice(local_shape[1]),
        )
        return large_tensor[slice_idx]

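    # Compare this rank's block of a gathered tensor against every other rank's
    # block, using `assertFunc` to express the expected relation (equal or not).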
    def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):
        for other_rank in range(size):
            if self_rank != other_rank:
                assertFunc(
                    self.get_tensor_slice(self_rank, size, gathered_tensors),
                    self.get_tensor_slice(other_rank, size, gathered_tensors),
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_model_init(self):
        dp_size = 2
        tp_size = self.world_size // dp_size
        mesh_2d = init_device_mesh(
            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        dp_rank = dp_mesh.get_coordinate()[0]
        tp_rank = tp_mesh.get_coordinate()[0]
        self.assertEqual(dp_rank, self.rank // tp_size)
        self.assertEqual(tp_rank, self.rank % tp_size)
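        # e.g. with world_size=4 and dp_size=2 the mesh is [[0, 1], [2, 3]]:
        # ranks 0 and 1 form TP group 0 (dp_rank 0), ranks 2 and 3 form TP
        # group 1 (dp_rank 1), and tp_rank is 0 or 1 within each group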

        for enable_distribute_flag in [False, True]:
            # a local model on meta device
            model = MLPModule(device="meta")
            # the col-wise parallel style shards the weight over tensor dim 0
            model_tp = parallelize_module(
                model,
                tp_mesh,
                {
                    "net1": ColwiseParallel(output_layouts=Replicate()),
                    "net2": ColwiseParallel(output_layouts=Replicate()),
                },
            )
            # in most cases, the random number generator state is set by the data
            # loader in the following way:
            #   - within a tensor parallel group, the RNG is set with the same seed
            #   - across the data parallel dimension, the RNG is set with different seeds
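            # e.g. on the 2x2 mesh above, ranks 0 and 1 (TP group 0) are seeded
            # with 0 and ranks 2 and 3 (TP group 1) with 1, so the seed is shared
            # within a TP group and differs across the DP dimension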
            torch.cuda.manual_seed(dp_rank)

            # disable/enable parallel RNG feature
            random._rng_tracker.distribute_region_enabled = enable_distribute_flag
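            # roughly: when enabled, DTensor's RNG tracker coordinates the per-rank
            # CUDA RNG inside random ops so that each shard of a DTensor draws
            # distinct values while replicated tensors stay identical across ranks;
            # when disabled, each rank simply uses its locally seeded RNG state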
            self.assertTrue(model_tp.net1.weight.is_meta)
            # initialize the model's local shard
            model_tp.to_empty(device=self.device_type)
            model_tp.reset_parameters()
            # check that the weight initialization adheres to the DP/TP setup
            for dtensor in [model_tp.net1.weight, model_tp.net2.weight]:
                # check within the TP group
                # the 1d mesh represents the TP group
                _1d_mesh = dtensor.device_mesh
                assert _1d_mesh.ndim == 1
                self.assertEqual(_1d_mesh, tp_mesh)

                tensor_local = dtensor.to_local()

                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=_1d_mesh,
                )
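                # the gather concatenates each TP rank's local shard along dim 0, so
                # get_tensor_slice(r, tp_size, tensor_gather) recovers rank r's shard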
                self.assertEqual(_1d_mesh.get_coordinate()[0], tp_rank)

                # compare local shards within the TP group
                def tp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # each rank within a TP group shall initialize local weights differently
                        self.assertNotEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # each rank within a TP group has the same initial weights
                        self.assertEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    tp_rank, tp_size, tensor_gather, tp_weights_assert
                )

                # check across TP groups
                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=dp_mesh,
                )
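                # here the gather runs over the DP group, so slice r of tensor_gather
                # holds the shard owned by DP rank r, i.e. the peer in TP group r that
                # covers the same rows of the weight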

                # compare local shards across TP groups
                def dp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # local weights shall be initialized the same across TP groups
                        self.assertEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the 2D (DP/TP)
                        # setup: local weights are initialized differently across TP groups because
                        # each DP rank was given a different seed in the data-loading step above
                        self.assertNotEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    dp_rank, dp_size, tensor_gather, dp_weights_assert
                )


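# To run this file directly (assuming a host with at least 4 GPUs), something like
#   python test_tp_random_state.py
# should work; with fewer GPUs the test is skipped by skip_if_lt_x_gpu(4).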
if __name__ == "__main__":
    run_tests()