File: test_matrix_ops.py

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (618 lines) | stat: -rw-r--r-- 24,189 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
import unittest
from typing import cast, Optional

import torch
import torch.nn.functional as F
from torch.distributed import init_device_mesh
from torch.distributed.tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_ROCM,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


# Handle to the ``c10d_functional`` op namespace. test_grouped_mm indexes
# CommDebugMode.get_comm_counts() with ops from it (e.g.
# ``funcol.all_gather_into_tensor``) to count specific collectives.
funcol = torch.ops.c10d_functional


def scale_for_fp8(
    t: torch.Tensor, scale_shape: tuple[int]
) -> tuple[torch.Tensor, torch.Tensor]:
    if all(d == 1 for d in scale_shape):
        t = t.unsqueeze(0).unsqueeze(-2)
    else:
        t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1))

    scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS
    t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type)

    return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape)


class DistMatrixOpsTest(DTensorTestBase):
    """Tests for DTensor matrix ops.

    Each test computes a reference result on regular ``torch.Tensor`` inputs,
    then runs the same op on DTensors distributed over a device mesh with
    various placements. It checks that the gathered result matches and,
    where covered, also checks gradients, output placements and
    communication counts.
    """

    @with_comms
    def test_addmm(self):
        """addmm with a row-sharded mat1 and replicated mat2 / input."""
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 8)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(8, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        # Named ``inp`` (not ``input``) to avoid shadowing the builtin.
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(inp, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_empty_operand(self):
        """addmm where the contraction dim is 0 (mat1 is 12x0, mat2 is 0x4)."""
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 0)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(0, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(inp, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_auto_redistribute(self):
        """addmm with Shard(1) x Shard(0) operands yields a Partial output.

        Also checks the gathered value and mat2's gradient against the local
        reference.
        """
        device_mesh = self.build_device_mesh()
        shard0_spec = [Shard(0)]
        shard1_spec = [Shard(1)]
        replica_spec = [Replicate()]

        tensor_to_shard1 = torch.randn(12, 8, requires_grad=True)
        mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec)
        tensor_to_shard0 = torch.randn(8, 4, requires_grad=True)
        mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec)
        input_tensor = torch.randn(4, requires_grad=True)
        # Named ``inp`` (not ``input``) to avoid shadowing the builtin.
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0)
        dist_res = torch.addmm(inp, mat1, mat2)

        # test if addmm output is a partial
        self.assertIsInstance(dist_res, DTensor)
        self.assertIsInstance(dist_res.placements[0], Partial)

        # test if result is the same as tensor
        dist_local_res = dist_res.full_tensor()
        self.assertEqual(local_res, dist_local_res)

        # backward checks
        dist_local_res.sum().backward()
        local_res.sum().backward()
        self.assertIsNotNone(mat2.grad)
        self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad)

    @with_comms
    def test_mm(self):
        """mm over every Shard(0)/Shard(1)/Replicate pair of operand placements."""
        device_mesh = self.build_device_mesh()
        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        replica_spec = Replicate()

        t1 = torch.randn(12, 8, requires_grad=True)
        t2 = torch.randn(8, 16, requires_grad=True)
        local_res = torch.mm(t1, t2)

        def test_placement_comb(
            placements1: list[Placement], placements2: list[Placement]
        ) -> None:
            dt1 = distribute_tensor(t1, device_mesh, placements1)
            dt2 = distribute_tensor(t2, device_mesh, placements2)
            dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
                device_mesh, [replica_spec]
            )
            self.assertEqual(dist_res.to_local(), local_res)
            # backward
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(dt1.grad)

        placement_specs = [shard0_spec, shard1_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    @skip_unless_torch_gpu
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    def test_scaled_mm(self):
        """_scaled_mm with tensor-wise and row-wise scales.

        Covers replicated, column-parallel and row-parallel layouts. Checks
        the output placement, numerics (with loose FP8 tolerances) and that
        no collectives are issued.
        """
        device_mesh = self.build_device_mesh()
        shrd0 = Shard(0)
        shrd1 = Shard(1)
        repl = Replicate()
        part = Partial()

        ws = self.world_size
        # _scaled_mm requires all dimensions to be multiples of 16. Since we'll
        # shard along n and k, we need to ensure this stays true on each rank.
        m, n, k = 16, 32 * ws, 16 * ws

        t1 = torch.randn(m, k, device=self.device_type, dtype=torch.bfloat16)
        t2 = torch.randn(n, k, device=self.device_type, dtype=torch.bfloat16)

        for (
            output_spec,
            t1_spec,
            t2_spec,
            scale1_shape,
            scale2_shape,
            scale1_spec,
            scale2_spec,
        ) in [
            # Tensor-wise scaling
            # Replicated, zero-dim scale
            (repl, repl, repl, (), (), repl, repl),
            # Column-parallel, two-dim scale
            (shrd1, repl, shrd0, (1, 1), (1, 1), repl, repl),
            # Row-parallel, one-dim scale
            (part, shrd1, shrd1, (1,), (1,), repl, repl),
            # Row-wise scaling
            # Replicated
            (repl, repl, repl, (m, 1), (n, 1), repl, repl),
            # Column-parallel
            (shrd1, repl, shrd0, (m, 1), (n, 1), repl, shrd0),
            # Row-parallel (which actually ends up doing sub-row-wise scaling)
            (part, shrd1, shrd1, (m, ws), (n, ws), shrd1, shrd1),
        ]:
            full_ref_res = t1 @ t2.t()

            t1_fp8, scale1 = scale_for_fp8(t1, scale1_shape)
            t2_fp8, scale2 = scale_for_fp8(t2, scale2_shape)

            dist_t1_fp8 = distribute_tensor(t1_fp8, device_mesh, [t1_spec])
            dist_t2_fp8 = distribute_tensor(t2_fp8, device_mesh, [t2_spec])
            dist_scale1 = distribute_tensor(scale1, device_mesh, [scale1_spec])
            dist_scale2 = distribute_tensor(scale2, device_mesh, [scale2_spec])

            with CommDebugMode() as comm_mode:
                dist_res = cast(
                    DTensor,
                    torch._scaled_mm(
                        dist_t1_fp8,
                        dist_t2_fp8.t(),
                        scale_a=dist_scale1,
                        scale_b=dist_scale2.t(),
                        out_dtype=torch.bfloat16,
                    ),
                )

            self.assertEqual(dist_res.placements[0], output_spec)

            full_dist_res = dist_res.full_tensor()
            # Fp8 matmuls are quite inaccurate, we need high tolerances
            self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2)

            self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_matmul(self):
        """matmul of a replicated input with a row-sharded weight."""
        device_mesh = self.build_device_mesh()
        dim = 128
        x = torch.randn(8, dim)
        A = torch.randn(dim, dim)
        y = torch.matmul(x, A)

        # Prepare DTensors
        dx = distribute_tensor(x, device_mesh, [Replicate()])
        dA = distribute_tensor(A, device_mesh, [Shard(0)])

        # Use `inference_mode` to test DTensor's capability of decomposing
        # `matmul` op
        with torch.inference_mode():
            dy = torch.matmul(dx, dA)

        self.assertEqual(y, dy.full_tensor())

    @with_comms
    def test_t(self):
        """t() swaps the tensor dims and maps Shard(0) <-> Shard(1)."""
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
        mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
        transposed_mat = mat.t()
        self.assertEqual(transposed_mat.size(), torch.Size([8, 12]))
        self.assertEqual(transposed_mat.placements, [Shard(1)])
        transposed_mat2 = transposed_mat.t()
        self.assertEqual(transposed_mat2.size(), torch.Size([12, 8]))
        self.assertEqual(transposed_mat2.placements, shard_spec)

    @with_comms
    def test_t_partial(self):
        """Transposing a Partial mm result keeps it Partial and numerically correct."""
        device_mesh = self.build_device_mesh()

        a = torch.randn(12, 8)
        b = torch.randn(8, 4)
        c = torch.mm(a, b).t()

        da = distribute_tensor(a, device_mesh, [Shard(1)])
        db = distribute_tensor(b, device_mesh, [Shard(0)])

        # mm(da, db) should return a Partial tensor.
        # transposing it should keep it Partial
        dc = torch.mm(da, db).t()

        self.assertIsInstance(dc.placements[0], Partial)

        # check that the local and distributed op results match
        self.assertEqual(
            c,
            dc.redistribute(device_mesh, [Replicate()]).to_local(),
        )

    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_baddbmm(self):
        """baddbmm over all Shard(0/1/2)/Replicate operand combinations.

        Each combination is run with both zero and non-zero ``beta``.
        """
        device_mesh = self.build_device_mesh()
        tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)

        def test_placement_comb(
            tensor_placements: list[Placement],
            batch_1_placements: list[Placement],
            batch_2_placements: list[Placement],
            beta: float,
            alpha: float,
            local_result: torch.Tensor,
            batch_1_grad: Optional[torch.Tensor],
        ) -> None:
            # ``local_result`` is passed explicitly rather than captured from the
            # enclosing loop, so the reference cannot silently go stale.
            tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements)
            batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements)
            batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements)
            dist_res = cast(
                DTensor,
                torch.baddbmm(
                    tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha
                ),
            ).redistribute(device_mesh, [Replicate()])
            dist_local_res = dist_res.to_local()
            # TestCase assertions (unlike bare ``assert``) are not stripped by -O.
            self.assertFalse(torch.isnan(local_result).any())
            self.assertFalse(torch.isnan(dist_local_res).any())
            self.assertEqual(dist_local_res.detach(), local_result.detach())

            # TODO: add test backward
            # grad_dist_res = torch.ones_like(dist_res)
            # dist_res.backward(grad_dist_res)
            # self.assertIsNotNone(batch_1_dt.grad)
            # batch_1_grad_local = batch_1_dt.grad.redistribute(
            #     device_mesh, [Replicate()]
            # ).to_local()
            # self.assertEqual(batch_1_grad_local, batch_1_grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(
            itertools.product(shard_specs, shard_specs, shard_specs)
        )
        # If beta is 0, input tensor will be ignored
        numeric_params_comb = [
            (0.0, 0.5),  # zero-beta
            (0.8, 0.5),  # non-zero-beta
        ]

        for beta, alpha in numeric_params_comb:
            local_result = torch.baddbmm(
                tensor, batch_1, batch_2, beta=beta, alpha=alpha
            )
            # Drop gradients from the previous (beta, alpha) pair so that
            # batch_1.grad reflects only this configuration instead of
            # accumulating across iterations.
            tensor.grad = None
            batch_1.grad = None
            batch_2.grad = None
            grad_local_res = torch.ones_like(local_result)
            local_result.backward(grad_local_res)
            # test all combos
            for spec in shard_specs_comb:
                test_placement_comb(
                    [spec[0]],
                    [spec[1]],
                    [spec[2]],
                    beta,
                    alpha,
                    local_result,
                    batch_1.grad,
                )

    @with_comms
    def test_bmm(self):
        """bmm forward/backward over all Shard(0/1/2)/Replicate placement pairs."""
        device_mesh = self.build_device_mesh()
        mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True)
        mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        local_result = torch.bmm(mat1, mat2)
        grad_local_res = torch.ones_like(local_result)
        local_result.backward(grad_local_res)

        def test_placement_comb(
            placements1: list[Placement],
            placements2: list[Placement],
        ) -> None:
            mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
            mat2_dt = distribute_tensor(mat2, device_mesh, placements2)
            dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute(
                device_mesh, [Replicate()]
            )
            dist_local_res = dist_res.to_local()
            self.assertEqual(dist_local_res, local_result)

            # test backward
            # TODO: figure out (replicate, shard1) fail on backward
            # it generates a different grad shape
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(mat1_dt.grad)
            mat1_dt_grad = cast(DTensor, mat1_dt.grad)
            mat1_grad_local = mat1_dt_grad.redistribute(
                device_mesh, [Replicate()]
            ).to_local()
            self.assertEqual(mat1_grad_local, mat1.grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))

        # tests that currently pass
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    @skip_unless_torch_gpu
    def test_scaled_dot_product_attention(self):
        """Causal SDPA on flash/efficient backends with Replicate/Shard(0)/Shard(1).

        Forward and backward must issue no collectives. Outputs and gradients
        must keep the input placements and match the local reference.
        """
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()
        # bsz, n_heads, slen, head_dim
        query = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        key = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        value = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )

        from torch.nn.attention import sdpa_kernel, SDPBackend

        available_backends = []
        dropout_p = 0.0
        # TODO: Add test cases where is_causal=False and an attention mask is provided.
        #       Gaps include missing op support for aten.masked_fill_.Scalar.
        is_causal = True
        enable_gqa = False
        params = torch.backends.cuda.SDPAParams(
            query, key, value, None, dropout_p, is_causal, enable_gqa
        )
        if torch.backends.cuda.can_use_flash_attention(params, debug=False):
            available_backends.append(SDPBackend.FLASH_ATTENTION)
        if torch.backends.cuda.can_use_efficient_attention(params, debug=False):
            available_backends.append(SDPBackend.EFFICIENT_ATTENTION)

        placement_specs = [(Replicate(),), (Shard(0),), (Shard(1),)]
        for backend, input_placements in itertools.product(
            available_backends, placement_specs
        ):
            dist_query = distribute_tensor(query, device_mesh, input_placements)
            dist_key = distribute_tensor(key, device_mesh, input_placements)
            dist_value = distribute_tensor(value, device_mesh, input_placements)
            with sdpa_kernel(backends=[backend]):
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=dropout_p, is_causal=is_causal
                )
                with comm_mode:
                    dist_out = F.scaled_dot_product_attention(
                        dist_query,
                        dist_key,
                        dist_value,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertEqual(dist_out.placements, input_placements)
                    self.assertEqual(dist_out.full_tensor(), out)

                out.sum().backward()
                with comm_mode:
                    dist_out.sum().backward()
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertEqual(dist_query.grad.placements, input_placements)
                    self.assertEqual(dist_query.grad.full_tensor(), query.grad)
                    self.assertEqual(dist_key.grad.placements, input_placements)
                    self.assertEqual(dist_key.grad.full_tensor(), key.grad)
                    self.assertEqual(dist_value.grad.placements, input_placements)
                    self.assertEqual(dist_value.grad.full_tensor(), value.grad)
                    query.grad.zero_()
                    key.grad.zero_()
                    value.grad.zero_()

    @skip_unless_torch_gpu
    @with_comms()
    def test_dtensor_mm(self):
        """
        Test mm with DTensor on a 2D mesh.
        This test lives here because test_dtensor_ops.py only covers 1D meshes.
        It also covers the corner case where one of the 2D mesh dimensions is 1.

        # TODO: we need to test more DTensor ops with 2D mesh, especially when 1 of the
        mesh dimension of the 2D mesh is 1.
        """
        mesh_0 = init_device_mesh(self.device_type, (self.world_size // 2, 2))
        mesh_1 = init_device_mesh(self.device_type, (self.world_size, 1))
        mesh_2 = init_device_mesh(self.device_type, (1, self.world_size))

        for mesh in [mesh_0, mesh_1, mesh_2]:
            lhs = torch.randn(256, 128)
            rhs = torch.randn(128, 256)
            mm_result = lhs @ rhs

            lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])
            rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])
            dtensor_result = lhs_dtensor @ rhs_dtensor
            self.assertEqual(
                dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6
            )

    @with_comms
    @skip_unless_torch_gpu
    def test_tensordot_shampoo(self):
        """
        Simple test for Shampoo's use case.

        tensordot contracts dim 0 of both operands and is checked across all
        Replicate/Shard(0)/Shard(1) placement pairs.
        """
        device_mesh = self.build_device_mesh()

        local_a = torch.randn(4, 4)
        local_b = torch.randn(4, 15)
        dims = ([0], [0])
        local_result = torch.tensordot(local_a, local_b, dims=dims)

        placements = [Replicate(), Shard(0), Shard(1)]
        placements_tuples = itertools.product(placements, repeat=2)

        for placement1, placement2 in placements_tuples:
            dist_a = distribute_tensor(local_a, device_mesh, [placement1])
            dist_b = distribute_tensor(local_b, device_mesh, [placement2])
            dist_result = torch.tensordot(dist_a, dist_b, dims=dims)
            dist_result_full = dist_result.full_tensor()
            self.assertEqual(local_result, dist_result_full)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
    @with_comms
    @skip_unless_torch_gpu
    @parametrize(
        "kwargs",
        [
            {
                # 2D x 3D case from MoE layer
                "inp_shape": (64, 16),
                "w1_shape": (2, 16, 32),
                "w2_shape": (2, 32, 16),
                "inp_placements": [Replicate()],
                "w1_placements": [Shard(2)],
                "w2_placements": [Shard(1)],
                "expected_comm_counts_fwd": 0,
                "expected_comm_counts_bwd": 1,
                "expected_out_placements": [Partial()],
            },
            {
                # Case that would have invalid strides on inp * mat1 when sharded
                "inp_shape": (64, 16),
                "w1_shape": (2, 16, 16),
                "w2_shape": (2, 16, 16),
                "inp_placements": [Replicate()],
                "w1_placements": [Shard(2)],
                "w2_placements": [Shard(1)],
                "expected_comm_counts_fwd": 2,
                "expected_comm_counts_bwd": 4,
                "expected_out_placements": [Replicate()],
            },
        ],
    )
    def test_grouped_mm(self, kwargs):
        """Two chained _grouped_mm calls (2D input x 3D weights).

        Checks the forward/backward collective counts, the output placement
        and the numerics against the local reference.
        """
        # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D)
        # More tests need to be added.
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()
        dtype = torch.bfloat16
        inp = torch.rand(
            *kwargs["inp_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        w1 = torch.rand(
            *kwargs["w1_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        w2 = torch.rand(
            *kwargs["w2_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32)

        h = torch._grouped_mm(inp, w1, offs=offs)
        out = torch._grouped_mm(h, w2, offs=offs)

        dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"])
        # colwise sharded
        dist_w1 = distribute_tensor(w1, device_mesh, kwargs["w1_placements"])
        # rowwise sharded
        dist_w2 = distribute_tensor(w2, device_mesh, kwargs["w2_placements"])
        dist_offs = distribute_tensor(offs, device_mesh, [Replicate()])

        with comm_mode:
            dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs)
            dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs)
            self.assertEqual(
                comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"]
            )
            self.assertEqual(dist_out.placements, kwargs["expected_out_placements"])
            self.assertEqual(dist_out.full_tensor(), out)

        out_grad = torch.ones_like(out)
        out.backward(out_grad)

        dist_out = dist_out.redistribute(device_mesh, [Shard(0)])
        dist_out_grad = distribute_tensor(out_grad, device_mesh, [Shard(0)])

        with comm_mode:
            dist_out.backward(dist_out_grad)
            self.assertEqual(
                comm_mode.get_total_counts(), kwargs["expected_comm_counts_bwd"]
            )
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                kwargs["expected_comm_counts_bwd"],
            )
        self.assertEqual(dist_inp.grad.full_tensor(), inp.grad)
        self.assertEqual(dist_w1.grad.full_tensor(), w1.grad)
        self.assertEqual(dist_w2.grad.full_tensor(), w2.grad)


# Apply the @parametrize decorations in DistMatrixOpsTest (used by
# test_grouped_mm); presumably this generates one concrete test per parameter
# set — confirm against torch.testing._internal.common_utils.
instantiate_parametrized_tests(DistMatrixOpsTest)

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()