File: test_pipe.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (821 lines) | stat: -rw-r--r-- 23,910 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
# Owner(s): ["oncall: distributed"]

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from copy import deepcopy
import time

import pytest
import random
import torch
from torch import nn
from torch import Tensor

from torch.distributed.pipeline.sync import Pipe, NoChunk, WithDevice
from torch.distributed.pipeline.sync.pipe import PipeSequential

# Decorator that skips a test when torch reports that CUDA is unavailable.
skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="cuda required",
)


def test_pipe_without_rpc():
    """Expect ``Pipe`` construction to fail when the ``setup_rpc`` fixture is not requested.

    The constructor must raise ``RuntimeError`` with a message matching
    'Please initialize RPC framework'.
    """
    model = nn.Sequential(nn.Linear(1, 1))
    with pytest.raises(RuntimeError, match='Please initialize RPC framework'):
        # The instance is never used, so it is not bound to a name.
        Pipe(model, chunks=1)

def test_parameters(setup_rpc):
    """A Pipe wrapping one Linear layer must expose a non-empty parameter list."""
    wrapped = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=1)
    exposed = list(wrapped.parameters())
    assert exposed != []


def test_public_attrs(setup_rpc):
    """Check the values and types of ``devices``, ``chunks`` and ``checkpoint``."""

    class MyString:
        """Non-``str`` object whose ``str()`` yields the wrapped value."""

        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    layers = nn.Sequential(nn.Linear(1, 1))
    # Deliberately pass a float chunk count and a non-str checkpoint mode.
    pipe = Pipe(layers, chunks=42.000, checkpoint=MyString("always"))

    expected_devices = [torch.device("cpu")]
    assert pipe.devices == expected_devices

    # The float chunk count must be exposed as an int.
    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)

    # The checkpoint mode must be exposed as a plain str.
    assert pipe.checkpoint == "always"
    assert isinstance(pipe.checkpoint, str)


def test_sequential_like(setup_rpc):
    """Pipe must support ``len``, iteration and indexing like ``nn.Sequential``."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    pipe = Pipe(nn.Sequential(first, second))

    assert len(pipe) == 2
    assert list(pipe) == [first, second]

    # Non-negative indices, including one past the end.
    assert pipe[0] is first
    assert pipe[1] is second
    with pytest.raises(IndexError):
        _ = pipe[2]

    # Negative indices count from the end.
    assert pipe[-1] is second
    assert pipe[-2] is first

def test_chunks_less_than_1(setup_rpc):
    """Chunk counts of 0 and -1 must be rejected with ``ValueError``."""
    model = nn.Sequential(nn.Linear(1, 1))

    for bad_chunks in (0, -1):
        with pytest.raises(ValueError):
            Pipe(model, chunks=bad_chunks)

def test_batch_size_indivisible(setup_rpc):
    """A batch of 7 rows split into 4 chunks must run without emitting warnings."""
    import warnings

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    # ``pytest.warns(None)`` is deprecated since pytest 7 and rejected by
    # pytest 8; record warnings with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Record every warning regardless of the active filters.
        warnings.simplefilter("always")
        model(torch.rand(7, 1))

    # Indivisible batch size is legal.
    assert not record


def test_batch_size_small(setup_rpc):
    """A batch of 2 rows with ``chunks=4`` must run without emitting warnings."""
    import warnings

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    # ``pytest.warns(None)`` is deprecated since pytest 7 and rejected by
    # pytest 8; record warnings with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Record every warning regardless of the active filters.
        warnings.simplefilter("always")
        model(torch.rand(2, 1))

    # Batch size smaller than chunks is legal.
    assert not record


def test_checkpoint_mode(setup_rpc):
    """Count ``CheckpointBackward`` nodes produced by each checkpoint mode.

    With two chunks the test expects 2 nodes for "always", 1 for
    "except_last" and 0 for "never".
    """

    def count_grad_fn(grad_fn, name, visited=None):
        # Walk the autograd graph with an explicit stack. Each node is
        # examined once, and the walk does not continue past a node whose
        # class name matches ``name``.
        seen = set() if visited is None else visited
        pending = [grad_fn]
        total = 0
        while pending:
            node = pending.pop()
            if node in seen:
                continue
            seen.add(node)

            if node is None:
                continue
            if node.__class__.__name__ == name:
                total += 1
                continue

            pending.extend(parent for parent, _ in node.next_functions)
        return total

    model = nn.Sequential(nn.Linear(1, 1))
    batch = torch.rand(2, 1)

    expectations = (("always", 2), ("except_last", 1), ("never", 0))

    # Build every pipe first, then run them in the same order.
    pipes = [Pipe(model, chunks=2, checkpoint=mode) for mode, _ in expectations]
    outputs = [pipe(batch) for pipe in pipes]

    for output, (_, expected) in zip(outputs, expectations):
        found = count_grad_fn(output.local_value().grad_fn, "CheckpointBackward")
        assert found == expected


def test_checkpoint_mode_invalid(setup_rpc):
    """An unknown checkpoint mode must raise ``ValueError`` with the expected message."""
    layers = nn.Sequential(nn.Linear(1, 1))
    expected_message = "checkpoint is not one of 'always', 'except_last', or 'never'"

    with pytest.raises(ValueError, match=expected_message):
        Pipe(layers, chunks=2, checkpoint="INVALID_CHECKPOINT")


def test_checkpoint_mode_when_chunks_1(setup_rpc):
    """Every checkpoint mode must be accepted when ``chunks=1``."""
    model = nn.Sequential(nn.Linear(1, 1))

    # All checkpoint modes are fine.
    for mode in ("except_last", "always", "never"):
        Pipe(model, chunks=1, checkpoint=mode)


def test_checkpoint_eval(setup_rpc):
    """Checkpoint/recompute nodes must appear in train mode and not in eval mode."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    batch = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        # Depth-first search of the autograd graph for a node whose class
        # name equals ``name``.
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        return any(find_grad_fn(parent, name) for parent, _ in grad_fn.next_functions)

    node_names = ("CheckpointBackward", "RecomputeBackward")

    pipe.train()
    train_output = pipe(batch)
    for node_name in node_names:
        assert find_grad_fn(train_output.local_value().grad_fn, node_name)

    pipe.eval()
    eval_output = pipe(batch)
    for node_name in node_names:
        assert not find_grad_fn(eval_output.local_value().grad_fn, node_name)


def test_checkpoint_non_float_input(setup_rpc):
    """Forward and backward must succeed when a bool tensor flows between stages."""

    class ForkNonFloat(nn.Module):
        def forward(self, input):
            # Emit a float tensor together with a boolean tensor.
            doubled = input * 2
            return (doubled, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input, non_float):
            # The boolean tensor is accepted but not used.
            return input * 2

    stages = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    pipe = Pipe(stages, chunks=1, checkpoint="always")

    sample = torch.rand(1, requires_grad=True)
    result = pipe(sample)
    result.backward()


def test_no_grad(setup_rpc):
    """Partition outputs produced under ``torch.no_grad()`` must have no ``grad_fn``."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    batch = torch.rand(2, 1)

    last_output = None

    def hook(module, input, output):
        # Remember the most recent output of the hooked partition.
        nonlocal last_output
        last_output = output

    pipe.partitions[0].register_forward_hook(hook)

    with torch.no_grad():
        pipe(batch)

    assert last_output.grad_fn is None


def test_exception(setup_rpc):
    """An exception raised inside a partition must propagate to the caller."""

    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        def forward(self, *_):
            raise ExpectedException()

    pipe = Pipe(nn.Sequential(Raise()), chunks=1)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(1))


def test_exception_early_stop_asap(setup_rpc):
    """Even if the first partitions have finished processing, the partition
    before the failing partition should be stopped as soon as possible.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    calls = 0

    class Counter(nn.Module):
        def forward(self, x):
            # Sleep first, then record that this stage processed a micro-batch.
            time.sleep(0.1)

            nonlocal calls
            calls += 1

            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    stages = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    pipe = Pipe(stages, chunks=3)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(3))

    # If the early stop doesn't work, it would be 3 instead.
    assert calls == 2


def test_nested_input(setup_rpc):
    """Inputs containing a nested tuple or list must be rejected with ``TypeError``."""

    class NestedInput(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, inp):
            return inp

    pipe = Pipe(nn.Sequential(NestedInput()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    # First case: "expected Tensor, but got tuple".
    # Second case: "expected Tensor, but got list".
    for nested in ((a, (a, b)), (a, [a, b])):
        with pytest.raises(TypeError):
            pipe(nested).local_value()


def test_input_pair(setup_rpc):
    """Two tensor inputs must both receive gradients after backward."""

    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a, b):
            out_a = self.fc_a(a)
            out_b = self.fc_b(b)
            return (out_a, out_b)

    pipe = Pipe(nn.Sequential(Two()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = pipe(a, b).local_value()
    (a_out + b_out).mean().backward()

    assert a.grad is not None
    assert b.grad is not None

def test_multi_sequence_input(setup_rpc):
    """Passing two lists of tensors as positional inputs must raise ``TypeError``."""

    class MultiSeq(nn.Module):
        def forward(self, tup1, tup2):
            return tup1, tup2

    pipe = Pipe(nn.Sequential(MultiSeq()))
    first_seq = [torch.rand(10), torch.rand(10)]
    second_seq = [torch.rand(10), torch.rand(10)]

    with pytest.raises(TypeError):
        pipe(first_seq, second_seq)

def test_input_singleton(setup_rpc):
    """A module returning a one-element tuple must backpropagate to params and input."""

    class One(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, a):
            return (self.fc(a),)

    pipe = Pipe(nn.Sequential(One()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)

    (a_out,) = pipe(a).local_value()
    a_out.mean().backward()

    assert all(param.grad is not None for param in pipe.parameters())
    assert a.grad is not None


def test_input_varargs(setup_rpc):
    """Passing two tensors to a single-input Linear pipeline must raise ``TypeError``."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)))

    first = torch.rand(1)
    second = torch.rand(1)

    # TypeError: forward() takes 2 positional arguments but 3 were given
    with pytest.raises(TypeError):
        pipe(first, second)


def test_non_tensor(setup_rpc):
    """A string output and a string input must each raise ``TypeError``."""

    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    pipe = Pipe(nn.Sequential(NonTensor()))
    tensor_input = torch.rand(1)

    # First a tensor input (the module returns a str), then a str input.
    for argument in (tensor_input, "hello"):
        with pytest.raises(TypeError):
            pipe(argument)


def test_non_tensor_sequence(setup_rpc):
    """Sequence inputs mixing tensors and strings, and all-non-tensor args, must fail."""

    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    class NonTensorArgs(nn.Module):
        def forward(self, x: str, y: bool):
            return x, y

    pipe = Pipe(nn.Sequential(NonTensorTuple()))
    tensor = torch.rand(1)

    # A tuple and a list that each mix a tensor with a string.
    for mixed in ((tensor, "hello"), [tensor, "hello"]):
        with pytest.raises(TypeError):
            pipe(mixed)

    pipe = Pipe(nn.Sequential(NonTensorArgs()))

    with pytest.raises(TypeError):
        # Need atleast one Tensor.
        pipe("hello", True)


@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_valid_non_tensor(checkpoint, setup_rpc):
    """Non-tensor values (int, bool, str, None) may accompany tensors through a Pipe.

    With ``chunks=5`` the test expects each non-tensor output to come back
    as a list of five copies, and tensor outputs to match the values
    computed directly from the inputs.
    """

    class NonTensor1(nn.Module):
        def forward(self, a: int, b: Tensor, c: bool, d: Tensor):
            res = (b + a) if c else (b * a)
            if d is not None:
                res += d
            return res, c, a, b, "hello", d

    class NonTensor2(nn.Module):
        def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor):
            res = (a * c) if b else (a + c)
            res += d
            fourth = (d + f) if f is not None else d
            return c, res, a, fourth, b, e, f

    stages = nn.Sequential(NonTensor1(), NonTensor2())
    model = Pipe(stages, chunks=5, checkpoint=checkpoint)
    scalar = random.randint(0, 10)
    base = torch.rand(10, 10)
    flag = random.randint(0, 1) == 0
    extra = torch.rand(10, 10)

    res = model(scalar, base, flag, extra).local_value()
    assert 7 == len(res)
    assert [scalar] * 5 == res[0]
    # Recompute both stages by hand, in the same operation order as the modules.
    if flag:
        stage1 = base + scalar + extra
        stage2 = (stage1 * scalar) + base
    else:
        stage1 = base * scalar + extra
        stage2 = (stage1 + scalar) + base
    assert torch.allclose(stage2, res[1])
    assert torch.allclose(stage1, res[2])
    assert torch.allclose(base + extra, res[3])
    assert [flag] * 5 == res[4]
    assert ["hello"] * 5 == res[5]
    assert torch.allclose(extra, res[6])

    # Test one of the tensors can be None
    res = model(scalar, base, flag, None).local_value()
    assert 7 == len(res)
    assert [scalar] * 5 == res[0]
    if flag:
        stage1 = base + scalar
        stage2 = (stage1 * scalar) + base
    else:
        stage1 = base * scalar
        stage2 = (stage1 + scalar) + base
    assert torch.allclose(stage2, res[1])
    assert torch.allclose(stage1, res[2])
    assert torch.allclose(base, res[3])
    assert [flag] * 5 == res[4]
    assert ["hello"] * 5 == res[5]
    assert [None] * 5 == res[6]

    # Need atleast one tensor.
    with pytest.raises(TypeError):
        model(scalar, None, flag, None)

@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_tensor_output(checkpoint, setup_rpc):
    """A partition whose output contains no tensor must raise ``TypeError``.

    The parametrized ``checkpoint`` mode is forwarded to ``Pipe`` so that
    each parametrization exercises a different mode instead of repeating
    the same default configuration.
    """

    class Model1(nn.Module):
        def forward(self, a: int, b: Tensor, c: bool):
            # Drops the tensor ``b``; only non-tensor values are passed on.
            return a, c, "hello"

    class Model2(nn.Module):
        def forward(self, a: int, b: bool, c: str):
            return a, c, b

    model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5, checkpoint=checkpoint)
    a = random.randint(0, 10)
    b = torch.rand(10, 10)
    c = random.randint(0, 1) == 0

    # Need atleast one tensor across partitions too.
    with pytest.raises(TypeError):
        model(a, b, c).local_value()


@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_uneven_batch_size(checkpoint, setup_rpc):
    """Tensors of different batch sizes work only if they split into equally many chunks."""

    class Model(nn.Module):
        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    # Batch sizes 3 and 6: the test expects three chunks for both tensors.
    pipe = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
    small = torch.rand(3, 10)
    number = random.randint(0, 10)
    large = torch.rand(6, 10)
    result = pipe(small, number, large).local_value()
    assert torch.allclose(small, result[0])
    assert [number] * 3 == result[1]  # 3 chunks
    assert torch.allclose(large, result[2])

    # Two tensors producing uneven chunks would fail.
    pipe = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
    small = torch.rand(3, 10)
    number = random.randint(0, 10)
    large = torch.rand(4, 10)

    with pytest.raises(RuntimeError, match='Found different number of chunks'):
        pipe(small, number, large)

@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_chunk(checkpoint, setup_rpc):
    """A tensor wrapped in ``NoChunk`` is passed whole to every micro-batch."""

    class Model(nn.Module):
        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    pipe = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
    chunked = torch.rand(10, 10)
    number = random.randint(0, 10)
    whole = torch.rand(10, 10)

    result = pipe(chunked, number, NoChunk(whole)).local_value()
    assert torch.allclose(chunked, result[0])
    assert [number] * 5 == result[1]
    # The NoChunk tensor gets replicated, so the same tensor is concatenated
    # 5 times in the output.
    replicated = torch.cat((whole, whole, whole, whole, whole))
    assert torch.allclose(replicated, result[2])

    # Test invalid type for NoChunk
    with pytest.raises(TypeError, match='NoChunk only supported for tensors'):
        NoChunk(number)


@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
    """Running statistics with ``deferred_batch_norm=True`` must match plain BatchNorm."""
    reference_bn = nn.BatchNorm2d(3)
    piped_bn = deepcopy(reference_bn)
    pipe = Pipe(
        nn.Sequential(piped_bn),
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    batch = torch.rand(4, 3, 10, 10)
    # Run the pipelined copy first, then the reference module.
    pipe(batch).local_value().mean().backward()
    reference_bn(batch).mean().backward()

    for stat in ("running_mean", "running_var"):
        assert torch.allclose(getattr(pipe[0], stat), getattr(reference_bn, stat), atol=1e-4)


@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint, setup_rpc):
    """Weight and bias gradients with ``deferred_batch_norm=True`` must match plain BatchNorm."""
    reference_bn = nn.BatchNorm2d(3)
    piped_bn = deepcopy(reference_bn)
    pipe = Pipe(
        nn.Sequential(piped_bn),
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    batch = torch.rand(4, 3, 10, 10)
    # Run the pipelined copy first, then the reference module.
    pipe(batch).local_value().mean().backward()
    reference_bn(batch).mean().backward()

    assert pipe[0].weight.grad is not None
    assert pipe[0].bias.grad is not None

    assert torch.allclose(pipe[0].weight.grad, reference_bn.weight.grad, atol=1e-4)
    assert torch.allclose(pipe[0].bias.grad, reference_bn.bias.grad, atol=1e-4)


def test_devices(setup_rpc):
    """Three Linear layers built without a device must report three CPU devices."""
    layers = [nn.Linear(1, 1) for _ in range(3)]
    pipe = Pipe(nn.Sequential(*layers))

    cpu = torch.device("cpu")
    # One device entry is expected per layer.
    assert pipe.devices == [cpu, cpu, cpu]


def test_partitions(setup_rpc):
    """``partitions`` must be a ModuleList of Sequentials registered in the state dict."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    pipe = Pipe(nn.Sequential(first, second))

    assert isinstance(pipe.partitions, nn.ModuleList)
    for index in (0, 1):
        assert isinstance(pipe.partitions[index], nn.Sequential)

    state_keys = pipe.state_dict()
    assert "partitions.0.0.weight" in state_keys


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_merged_partitions(setup_rpc):
    """Adjacent modules on CUDA device 0 must merge into a single flattened partition."""
    a = nn.Linear(1, 1).to(0)
    b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0)
    # ``c`` and ``d`` are not moved to a CUDA device.
    c = nn.Linear(1, 1)
    d = nn.Linear(1, 2)

    pipe = Pipe(nn.Sequential(a, b, c, d))
    partitions = pipe.partitions

    assert isinstance(partitions, nn.ModuleList)
    assert isinstance(partitions[0], PipeSequential)
    assert isinstance(partitions[1], PipeSequential)

    # The nested Sequential ``b`` is expected to be flattened into partition 0.
    assert list(partitions[0]) == [a, b[0], b[1]]
    assert list(partitions[1]) == [c]
    assert list(partitions[2]) == [d]


def test_deny_moving(setup_rpc):
    """Moving a Pipe between devices must raise ``TypeError``; dtype casts must work."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    model = Pipe(nn.Sequential(first, second))

    # Moving is denied. Each callable is evaluated inside ``pytest.raises``
    # so its arguments are built at the same point as the call itself.
    denied_moves = (
        lambda: model.cuda(),
        lambda: model.cpu(),
        lambda: model.to(torch.device("cuda")),
        lambda: model.to(0),
        lambda: model.to("cuda"),
        lambda: model.to(device=0),
        lambda: model.to(torch.rand(1)),
        lambda: model.to(tensor=torch.rand(1)),
    )
    for move in denied_moves:
        with pytest.raises(TypeError):
            move()

    # Casting is allowed.
    model.half()
    model.to(torch.double)
    model.to(dtype=torch.float)


def test_empty_module(setup_rpc):
    """An empty Sequential passes a tensor through unchanged and rejects a plain int."""
    # Empty sequential module is not illegal.
    pipe = Pipe(nn.Sequential())

    result = pipe(torch.tensor(42)).local_value()
    assert result == torch.tensor(42)

    # But only tensor or tensors is legal in Pipe.
    with pytest.raises(TypeError):
        pipe(42)


def test_named_children(setup_rpc):
    """Children appear under ``partitions.<i>.0`` and are not reachable by their names."""
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    named_layers = OrderedDict([("a", a), ("b", b)])
    model = Pipe(nn.Sequential(named_layers))

    module_names = {name for name, _ in model.named_modules()}
    assert "partitions.0.0" in module_names
    assert "partitions.1.0" in module_names

    # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
    # several methods in its namespace.
    with pytest.raises(AttributeError):
        model.a


def test_verify_module_non_sequential(setup_rpc):
    """Wrapping a bare ``nn.Module`` must raise ``TypeError`` with the expected message."""
    expected_message = "module must be nn.Sequential to be partitioned"
    with pytest.raises(TypeError, match=expected_message):
        Pipe(nn.Module())


def test_verify_module_duplicate_children(setup_rpc):
    """A Sequential containing the same child twice must raise ``ValueError``."""
    shared_conv = nn.Conv2d(3, 3, 1)
    duplicated = nn.Sequential(shared_conv, shared_conv)
    expected_message = "module with duplicate children is not supported"

    with pytest.raises(ValueError, match=expected_message):
        Pipe(duplicated)


@skip_if_no_cuda
def test_verify_module_params_on_same_device(setup_rpc):
    """A child holding parameters on both CPU and CUDA must raise ``ValueError``."""

    class Surrogate(nn.Module):
        def __init__(self, param1, param2):
            super().__init__()
            self.param1 = param1
            self.param2 = param2

    cpu_conv = nn.Conv2d(3, 3, 1)
    cuda_conv = nn.Conv2d(3, 3, 1)
    # Only the second convolution is moved to CUDA.
    model = nn.Sequential(Surrogate(cpu_conv, cuda_conv.cuda()))

    expected_message = (
        r'should have all parameters on a single device, please use .to\(\)'
        ' to place the module on a single device'
    )
    with pytest.raises(ValueError, match=expected_message):
        Pipe(model)

@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_verify_nested_modules(setup_rpc):
    """Nested Sequentials on two GPUs must yield output on ``cuda:1`` with shape (10, 2)."""
    stage_on_gpu0 = nn.Sequential(
        nn.Linear(32, 16).cuda(0),
        nn.Linear(16, 8).cuda(0),
    )
    stage_on_gpu1 = nn.Sequential(
        nn.Linear(8, 4).cuda(1),
        nn.Linear(4, 2).cuda(1),
    )

    pipe = Pipe(nn.Sequential(stage_on_gpu0, stage_on_gpu1))
    batch = torch.rand(10, 32).cuda(0)
    out = pipe(batch)

    assert out.local_value().device == torch.device("cuda:1")
    assert out.local_value().size() == torch.Size([10, 2])

def test_verify_module_duplicate_parameters_on_same_device(setup_rpc):
    """Two distinct wrappers sharing one inner module must be accepted by ``Pipe``."""

    class Surrogate(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    shared_conv = nn.Conv2d(3, 3, 1)
    wrappers = nn.Sequential(Surrogate(shared_conv), Surrogate(shared_conv))

    # Construction alone is the check: no exception is expected.
    Pipe(wrappers)


def test_forward_lockstep(setup_rpc):
    """Verify the order in which two partitions process three micro-batches."""
    timeline = []

    class DelayedLog(nn.Module):
        def __init__(self, j, seconds):
            super().__init__()
            self.i = 0  # number of micro-batches logged so far
            self.j = j  # identifier recorded with every log entry
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)

            entry = (self.i, self.j)
            timeline.append(entry)
            self.i += 1

            return x

    fast_stage = DelayedLog(0, seconds=0)
    slow_stage = DelayedLog(1, seconds=0.1)
    pipe = Pipe(nn.Sequential(fast_stage, slow_stage), chunks=3)
    pipe(torch.rand(3, 1))

    # Expected timeline: (Logs are recorded at !)
    #
    # Partition #0: 0! 1!   2!
    # Partition #1:    000! 111! 222!
    #
    expected = [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
    assert timeline == expected

@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@skip_if_no_cuda
def test_multiple_inputs(checkpoint, setup_rpc):
    """Three inputs flow through a two-output stage into a two-input stage."""

    class Module1(nn.Module):
        def forward(self, a, b, c):
            total = a + b + c
            product = a * b * c
            return total, product

    class Module2(nn.Module):
        def forward(self, a, b):
            return a + b

    stages = nn.Sequential(Module1().cuda(0), Module2().cuda(0))
    pipe = Pipe(stages, chunks=2, checkpoint=checkpoint)

    t = torch.rand(10)
    result = pipe(t, t, t).local_value()
    expected = (t + t + t) + (t * t * t)
    assert torch.equal(result, expected)

@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_inputs_wrong_device(setup_rpc):
    """Inputs on a device other than the first partition's must raise ``ValueError``."""

    class Module1(nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(5))

        def forward(self, a, b):
            return a + b + self.param, b

    # Place the inputs on cuda:1 while the first partition lives on cuda:0;
    # Pipe is expected to reject them rather than run the forward pass.
    first_input = torch.rand(10).cuda(1)
    second_input = torch.rand(10).cuda(1)

    stages = nn.Sequential(Module1().cuda(0), Module1().cuda(1))
    pipe = Pipe(stages, chunks=2)

    expected_message = 'All inputs should be on the same device as the first partition'
    with pytest.raises(ValueError, match=expected_message):
        pipe(first_input, second_input)

@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_with_device_wrapper(setup_rpc):
    """``WithDevice`` must determine the device reported and used for a module."""
    fc1 = nn.Linear(16, 8).cuda(0)
    fc2 = nn.Linear(8, 4).cuda(1)
    dropout = nn.Dropout()

    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')

    # Dropout assigned to cuda:1 after fc1 (cuda:0) and fc2 (cuda:1).
    pipe = Pipe(nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')), chunks=8)
    output = pipe(torch.rand(16, 16).cuda(0)).local_value()
    assert cuda1 == output.device
    assert [cuda0, cuda1] == pipe.devices

    # Dropout assigned to cuda:1 directly after fc1 (cuda:0).
    pipe = Pipe(nn.Sequential(fc1, WithDevice(dropout, 'cuda:1')), chunks=8)
    output = pipe(torch.rand(16, 16).cuda(0)).local_value()
    assert cuda1 == output.device
    assert [cuda0, cuda1] == pipe.devices

    # fc2 was created on cuda:1 but is wrapped with cuda:0; afterwards its
    # weight is expected to be on cuda:0.
    pipe = Pipe(nn.Sequential(fc1, WithDevice(fc2, 'cuda:0')), chunks=8)
    output = pipe(torch.rand(16, 16).cuda(0)).local_value()
    assert cuda0 == output.device
    assert [cuda0] == pipe.devices
    assert cuda0 == fc2.weight.device