File: test_linear.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (274 lines) | stat: -rw-r--r-- 11,482 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# Owner(s): ["oncall: distributed"]

import copy
import sys

import torch
import torch.distributed as dist
from torch.distributed._shard.api import (
    shard_parameter,
    _collect_local_shard,
    _reshard_output,
)
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
)
from torch.distributed._shard.sharded_tensor import (
    empty,
)
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    clone_module_parameter,
    generate_chunk_sharding_specs_for_test,
    generate_local_weight_sharding_params_for_test,
)

if TEST_WITH_DEV_DBG_ASAN:
    # torch + multiprocessing spawn are known to misbehave under dev-asan, so
    # stop loading this module right away and exit with status 0.
    sys.stderr.write(
        "Skip dev-asan as torch + multiprocessing spawn have known issues" + "\n"
    )
    sys.exit(0)


class TestShardedTensorOpsLinear(ShardedTensorTestBase):
    """Tests for ``torch.nn.Linear`` / ``F.linear`` with a ShardedTensor weight."""

    def _run_sharded_linear(
        self, spec, input_size, linear_size, sharded_dim, dtype
    ):
        """Compare a linear layer with a sharded weight against a local reference.

        A reference ``torch.nn.Linear`` is placed on this rank's GPU. A second
        layer receives cloned copies of its parameters and has its ``weight``
        sharded with ``spec``. The two are then compared on:

        * forward output, both via the module call and via
          ``torch.nn.functional.linear``;
        * weight and bias gradients after ``output.sum().backward()``;
        * weight and bias values after one SGD step (``lr=0.1``).

        Args:
            spec: sharding spec applied to the sharded layer's ``weight``.
            input_size: shape of the random input tensor.
            linear_size: ``(in_features, out_features)`` for ``nn.Linear``.
            sharded_dim: dimension of ``weight`` that ``spec`` shards; used to
                narrow the reference weight/grad down to this rank's shard.
            dtype: dtype used for the parameters and the input.
        """
        # Use the same seed so both layers are built identically.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size, dtype=dtype).cuda(self.rank)
        sharded_linear = torch.nn.Linear(*linear_size, dtype=dtype)

        # Copy the weights and bias from the local linear.
        sharded_linear.weight = clone_module_parameter(local_linear, "weight")
        sharded_linear.bias = clone_module_parameter(local_linear, "bias")

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

        # Run sharded computation.
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
        # Output resharding spec: same placements as `spec`, but along dim 0
        # and ordered by rank.
        reshard_spec = copy.deepcopy(spec)
        reshard_spec.dim = 0
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        sharded_linear = _collect_local_shard(
            _reshard_output(sharded_linear, reshard_spec)
        )
        sharded_output = sharded_linear(inp)

        # Run local computation.
        local_output = local_linear(inp)

        # Verify.
        self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

        # Validate the torch.nn.functional.linear path as well.
        local_output = torch.nn.functional.linear(
            inp, local_linear.weight, local_linear.bias
        )
        sharded_output = torch.nn.functional.linear(
            inp, sharded_linear.weight, sharded_linear.bias
        )
        sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
        # When the local tensor has only one dimension, an extra dimension is
        # added for reshard; squeeze it back out manually.
        if inp.dim() == 1:
            sharded_output = sharded_output.squeeze(reshard_spec.dim)
        self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

        # Compute loss and run the backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        local_grad = local_linear.weight.grad

        # Both weight and bias of the sharded linear must have received a grad.
        sharded_weight = sharded_linear.weight.local_tensor()
        self.assertIsNotNone(sharded_linear.bias.grad)
        self.assertIsNotNone(sharded_weight.grad)

        # Inputs differ per rank, so the reference grads are summed across
        # ranks before comparison (presumably matching how the sharded
        # backward aggregates contributions). all_reduce is in-place, so the
        # reference layer's .grad tensors -- and hence its SGD step below --
        # use the summed values.
        dist.all_reduce(local_grad)
        (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test(
            local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
        )
        local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size)
        local_bias_grad = local_linear.bias.grad
        dist.all_reduce(local_bias_grad)

        # Test backward gradient calculation.
        self.assertEqual(sharded_linear.bias.grad, local_bias_grad, atol=1e-3, rtol=1e-3)
        self.assertEqual(sharded_weight.grad, local_grad_narrowed, atol=1e-3, rtol=1e-3)

        # Test optimizer: one SGD step on each side must change the parameters
        # and keep the sharded shard equal to the matching slice of the reference.
        previous = local_linear.weight.clone().detach()
        optim = torch.optim.SGD(local_linear.parameters(), lr=0.1)
        optim.step()
        self.assertNotEqual(previous, local_linear.weight)
        previous_sharded_weight = sharded_weight.clone()
        previous_sharded_bias = sharded_linear.bias.clone()
        sharded_optim = ShardedOptimizer(
            dict(sharded_linear.named_parameters()),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        sharded_weight = sharded_linear.weight.local_tensor()
        local_weight_narrowed = local_linear.weight.narrow(
            sharded_dim, start_pos, chunk_size
        )
        self.assertEqual(sharded_weight.size(), local_weight_narrowed.size())
        self.assertNotEqual(previous_sharded_weight, sharded_weight)
        self.assertEqual(sharded_weight, local_weight_narrowed, atol=1e-3, rtol=1e-3)
        self.assertNotEqual(previous_sharded_bias, sharded_linear.bias)
        self.assertEqual(sharded_linear.bias, local_linear.bias, atol=1e-3, rtol=1e-3)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_linear_colwise(self):
        """Weight sharded along dim 0 with 2-D, N-D and 1-D inputs in several dtypes."""
        for spec in generate_chunk_sharding_specs_for_test(0):
            self._run_sharded_linear(spec, [2, 17], [17, 12], 0, torch.float16)
            self._run_sharded_linear(spec, [8, 21], [21, 11], 0, torch.float32)
            self._run_sharded_linear(spec, [7, 23], [23, 13], 0, torch.float64)
            self._run_sharded_linear(spec, [4, 15], [15, 14], 0, torch.float16)

            # Test multiple input dims
            self._run_sharded_linear(spec, [10, 2, 17], [17, 12], 0, torch.float32)
            self._run_sharded_linear(spec, [13, 8, 21], [21, 11], 0, torch.float64)
            self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0, torch.float16)
            self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0, torch.float32)

            # Test single input dim
            self._run_sharded_linear(spec, [17], [17, 12], 0, torch.float64)
            self._run_sharded_linear(spec, [21], [21, 11], 0, torch.float16)
            self._run_sharded_linear(spec, [23], [23, 13], 0, torch.float32)
            self._run_sharded_linear(spec, [15], [15, 14], 0, torch.float64)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_linear_rowwise(self):
        """Weight sharded along dim 1, covering even and uneven splits."""
        for spec in generate_chunk_sharding_specs_for_test(1):
            # Test even split.
            self._run_sharded_linear(spec, [8, 16], [16, 11], 1, torch.float16)

            # Test uneven split.
            self._run_sharded_linear(spec, [5, 19], [19, 11], 1, torch.float32)
            self._run_sharded_linear(spec, [10, 21], [21, 11], 1, torch.float64)

            # Test multiple input dims
            self._run_sharded_linear(spec, [13, 8, 16], [16, 11], 1, torch.float16)
            self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1, torch.float32)
            self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1, torch.float64)

            # Test single input dim
            self._run_sharded_linear(spec, [16], [16, 11], 1, torch.float16)
            self._run_sharded_linear(spec, [19], [19, 11], 1, torch.float32)
            self._run_sharded_linear(spec, [21], [21, 11], 1, torch.float64)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_linear_errors(self):
        """Unsupported sharded-linear configurations raise the expected errors."""
        for spec in generate_chunk_sharding_specs_for_test(0):
            # Sharded bias.
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "weight", spec)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            # 0-dim input.
            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            # 3-D weight.
            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            # 2-D bias.
            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            # Input's last dim does not match the weight's in_features.
            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            # Weight sharded with an EnumerableShardingSpec (2x2 grid of 5x5
            # blocks on ranks 0-3 -- assumes TEST_GPU_NUM >= 4; TODO confirm).
            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            # ShardedTensor metadata in the message has a parenthesis imbalance
            # issue when used with re.compile, so the args are matched with `.*`.
            # `(?s)` lets `.` cross newlines (e.g. multi-line tensor reprs); as a
            # global inline flag it must lead the pattern (Python 3.11+ rejects
            # it elsewhere). The parentheses join both raw literals into one
            # pattern -- without them the second literal was a no-op statement.
            error_msg = (
                r"(?s)torch function 'linear', with args: .* "
                r"and kwargs: None not supported for ShardedTensor!"
            )
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fc6(torch.rand(10, 10).cuda(self.rank))

            # More than one local shard per rank.
            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))


if __name__ == "__main__":
    # When executed directly, run this module's tests through the shared
    # `run_tests` entry point imported from torch.testing._internal.common_utils.
    run_tests()