File: test_tp_random_state.py

Package: pytorch 2.9.1+dfsg-1~exp2
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)


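# These tests verify DTensor's parallel RNG behavior during tensor-parallel model
# initialization on a 2D (dp, tp) device mesh: when the RNG tracker's distribute-region
# handling is enabled, each rank in a TP group initializes a distinct local shard of the
# same weight, while corresponding shards stay identical across the DP dimension.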
class TensorParallelRandomStateTests(DTensorTestBase):
    def get_tensor_slice(self, idx, n, large_tensor):
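        # return the idx-th of n equal chunks of large_tensor along dim 0,
        # keeping the full second dimension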
        shape = large_tensor.shape
        assert shape[0] % n == 0
        local_shape = [shape[0] // n, shape[1]]

        slice_idx = (
            slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
            slice(local_shape[1]),
        )
        return large_tensor[slice_idx]

    def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):
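        # compare this rank's chunk of the gathered tensor against every other
        # rank's chunk using the given assertion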
        for other_rank in range(size):
            if self_rank != other_rank:
                assertFunc(
                    self.get_tensor_slice(self_rank, size, gathered_tensors),
                    self.get_tensor_slice(other_rank, size, gathered_tensors),
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_model_init(self):
        dp_size = 2
        tp_size = self.world_size // dp_size
        mesh_2d = init_device_mesh(
            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        dp_rank = dp_mesh.get_coordinate()[0]
        tp_rank = tp_mesh.get_coordinate()[0]
        self.assertEqual(dp_rank, self.rank // tp_size)
        self.assertEqual(tp_rank, self.rank % tp_size)

        for enable_distribute_flag in [True, False]:
            # a local model on meta device
            model = MLPModule(device="meta")
            # the col-wise parallel style shards the weight over tensor dim 0
            model_tp = parallelize_module(
                model,
                tp_mesh,
                {
                    "net1": ColwiseParallel(output_layouts=Replicate()),
                    "net2": ColwiseParallel(output_layouts=Replicate()),
                },
            )
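            # after parallelization, net1.weight and net2.weight are DTensors sharded
            # along dim 0 over tp_mesh; their local shards are still on the meta device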
            # in most cases, the random number generator state is set by the data loader
            # in the following way:
            #   - within a tensor parallel group, the RNG is set with the same seed
            #   - across data parallel groups, the RNG is set with different seeds
            # in this test, however, every rank starts from the same seed
            torch.get_device_module(self.device_type).manual_seed(0)

            # disable/enable parallel RNG feature
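            # (when enabled, DTensor coordinates RNG state inside distribute regions so
            # that shards of the same tensor get distinct random values; when disabled,
            # every rank draws from its own local RNG stream, which here was seeded
            # identically on all ranks)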
            if random._rng_tracker:
                random._rng_tracker.distribute_region_enabled = enable_distribute_flag

            self.assertTrue(model_tp.net1.weight.is_meta)
            # initialize the model's local shard
            model_tp.to_empty(device=self.device_type)
            model_tp.reset_parameters()
            # check that the initialized weights adhere to the DP/TP layout
            for dtensor in [model_tp.net1.weight, model_tp.net2.weight]:
                # check within the TP group
                # the 1d mesh represents the TP group
                _1d_mesh = dtensor.device_mesh
                assert _1d_mesh.ndim == 1
                self.assertEqual(_1d_mesh, tp_mesh)

                tensor_local = dtensor.to_local()

                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=_1d_mesh,
                )
                self.assertEqual(_1d_mesh.get_coordinate()[0], tp_rank)

                # compare local shards within the TP group
                def tp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # each rank within a TP group should initialize its local weights differently
                        self.assertNotEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # each rank within a TP group has the same initial weights
                        self.assertEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    tp_rank, tp_size, tensor_gather, tp_weights_assert
                )

                # check across TP groups
                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=dp_mesh,
                )

                # compare local shards across TP groups
                def dp_weights_assert(tensor1, tensor2):
                    # local weights should be initialized identically across TP groups;
                    # whether DTensor's RNG infra is enabled does not matter here, since
                    # all SPMD ranks started from the same seed.
                    self.assertEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    dp_rank, dp_size, tensor_gather, dp_weights_assert
                )


if __name__ == "__main__":
    run_tests()