File: test_compiled_optimizers.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (871 lines) | stat: -rw-r--r-- 31,109 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
# Owner(s): ["module: inductor"]

import sys
import unittest
import weakref
from contextlib import ExitStack
from copy import deepcopy
from typing import NamedTuple

import torch
import torch._inductor
import torch._inductor.cudagraph_trees
import torch.optim.lr_scheduler
from torch._inductor import config
from torch._inductor.test_case import TestCase
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipCUDAIf,
    skipXPUIf,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
)
from torch.testing._internal.common_utils import parametrize
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CPU,
    HAS_GPU,
    has_triton,
)
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu


# Constructor keyword arguments for every LR scheduler class exercised by the
# tests in this file, keyed by scheduler class. create_scheduler() looks a
# class up here and forwards the mapped kwargs to its constructor.
# Note: we use atypical values to amplify error
LR_SCHEDULER_TO_KWARGS = {
    LambdaLR: {"lr_lambda": lambda x: 10},
    MultiplicativeLR: {"lr_lambda": lambda x: 10},
    StepLR: {"step_size": 1, "gamma": 100},
    MultiStepLR: {"milestones": [1, 2], "gamma": 100},
    ExponentialLR: {"gamma": 100},
    CosineAnnealingLR: {"T_max": 7},
    # These schedulers have memory leaks in eager
    # https://github.com/pytorch/pytorch/issues/126131
    # SequentialLR: {"schedulers": None, "milestones": [1, 2]},
    # ChainedScheduler: {"schedulers": None},
    CyclicLR: {"base_lr": 0.001, "max_lr": 0.02, "cycle_momentum": False},
    CosineAnnealingWarmRestarts: {"T_0": 1},
    OneCycleLR: {
        "max_lr": 0.02,
        "cycle_momentum": False,
        "steps_per_epoch": 1,
        "epochs": 10,
    },
    ConstantLR: {"factor": 0.001},
    LinearLR: {},
    ReduceLROnPlateau: {"factor": 0.99, "patience": 1},
    PolynomialLR: {},
}


def create_scheduler(scheduler, optim):
    """Instantiate ``scheduler`` for ``optim`` from LR_SCHEDULER_TO_KWARGS.

    Args:
        scheduler: LR scheduler class; must be a key of
            LR_SCHEDULER_TO_KWARGS.
        optim: optimizer instance the scheduler is attached to.

    Returns:
        The constructed scheduler instance.

    Raises:
        KeyError: if ``scheduler`` has no entry in LR_SCHEDULER_TO_KWARGS.
    """
    # Work on a shallow copy: the "schedulers" entry below is filled in with
    # schedulers bound to this particular optimizer. Writing them into the
    # shared module-level table would change it for every later call and
    # keep those schedulers (and the optimizer they reference) alive.
    kwargs = dict(LR_SCHEDULER_TO_KWARGS[scheduler])
    if "schedulers" in kwargs:
        kwargs["schedulers"] = [
            create_scheduler(torch.optim.lr_scheduler.ConstantLR, optim)
            for _ in range(2)
        ] + [create_scheduler(torch.optim.lr_scheduler.LambdaLR, optim)]

    # ChainedScheduler is built from its child schedulers only; every other
    # scheduler takes the optimizer as its first positional argument.
    if scheduler == ChainedScheduler:
        return scheduler(**kwargs)
    return scheduler(optim, **kwargs)


class KernelCounts(NamedTuple):
    """Expected number of generated inductor kernels for one optimizer.

    build_opt_kwarg_db() uses ``multitensor`` when the ``foreach`` kwarg is
    truthy and the device is the GPU type, and ``singletensor`` otherwise.
    """

    # expected count when the foreach implementation runs on the GPU device
    multitensor: int
    # expected count in every other configuration
    singletensor: int


# Some optimizer configurations generate a different number of kernels than
# the per-optimizer default in KERNEL_COUNTS. This maps the generated test
# name (as built by build_opt_kwarg_db, including the device suffix) to the
# kernel count expected for that specific test; an entry here takes
# precedence over KERNEL_COUNTS.
KERNEL_COUNT_OVERRIDES = {
    "test_rmsprop_foreach_weight_decay_cpu": 12,
    "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
    "test_adamw_amsgrad_capturable_foreach_cuda": 3,
    "test_adamw_amsgrad_capturable_foreach_xpu": 3,
    "test_adamw_amsgrad_capturable_cuda": 6,
    "test_adamw_amsgrad_capturable_xpu": 6,
    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
    "test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6,
    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adam_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
    "test_adam_tensor_lr_tensor_betas_capturable_cuda": 6,
    "test_adam_amsgrad_capturable_cuda": 6,
    "test_adam_amsgrad_capturable_xpu": 6,
    "test_adadelta_tensor_lr_capturable_cuda": 6,
    "test_adadelta_tensor_lr_capturable_xpu": 6,
    "test_rmsprop_tensor_lr_capturable_cuda": 6,
    "test_rmsprop_tensor_lr_capturable_xpu": 6,
    "test_adadelta_foreach_weight_decay_maximize_cpu": 12,
    "test_adadelta_foreach_rho_weight_decay_cpu": 12,
    "test_adadelta_foreach_weight_decay_cpu": 12,
    "test_sgd_foreach_momentum_weight_decay_cpu": 16,
    "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
    "test_sgd_momentum_dampening_foreach_cuda": 5,
    "test_sgd_momentum_dampening_foreach_xpu": 5,
    "test_sgd_momentum_foreach_cuda": 5,
    "test_sgd_momentum_foreach_xpu": 5,
    "test_sgd_weight_decay_maximize_cuda": 4,
    "test_sgd_weight_decay_maximize_xpu": 4,
    "test_sgd_weight_decay_maximize_cpu": 4,
    "test_sgd_weight_decay_cpu": 4,
    "test_sgd_weight_decay_cuda": 4,
    "test_sgd_weight_decay_xpu": 4,
    "test_sgd_momentum_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_weight_decay_foreach_xpu": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
    "test_sgd_cuda": 4,
    "test_sgd_cpu": 4,
    "test_sgd_xpu": 4,
    "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
    "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_maximize_foreach_xpu": 2,
    "test_adagrad_tensor_lr_cpu": 6,
    "test_adagrad_tensor_lr_cuda": 6,
    "test_adagrad_tensor_lr_xpu": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
    "test_sgd_tensor_lr_cpu": 2,
    "test_sgd_tensor_lr_cuda": 2,
    "test_sgd_tensor_lr_xpu": 2,
}

# Default expected kernel counts per optimizer class.
# also tracks currently supported optimizers: build_opt_kwarg_db() skips any
# optimizer class that is not a key of this dict.
KERNEL_COUNTS = {
    Adam: KernelCounts(multitensor=2, singletensor=8),
    AdamW: KernelCounts(multitensor=2, singletensor=8),
    NAdam: KernelCounts(multitensor=2, singletensor=8),
    Rprop: KernelCounts(multitensor=2, singletensor=8),
    RMSprop: KernelCounts(multitensor=2, singletensor=8),
    Adadelta: KernelCounts(multitensor=2, singletensor=8),
    Adagrad: KernelCounts(multitensor=2, singletensor=8),
    SGD: KernelCounts(multitensor=1, singletensor=8),
    ASGD: KernelCounts(multitensor=2, singletensor=8),
    RAdam: KernelCounts(multitensor=2, singletensor=8),
    Adamax: KernelCounts(multitensor=2, singletensor=8),
}


def build_opt_kwarg_db():
    """Enumerate the generated optimizer test configurations.

    Walks ``optim_db`` for every optimizer class listed in KERNEL_COUNTS, on
    both "cpu" and GPU_TYPE, and derives a test name from the optimizer class
    name, the kwargs of each input configuration and the device.

    Returns:
        A list of ``(optim_cls, test_name, kwargs, scheduler_cls)`` tuples.
        ``kwargs`` additionally carries "device" and "kernel_count".
        Configurations with a tensor ``lr`` yield one entry per scheduler in
        LR_SCHEDULER_TO_KWARGS; all others yield a single entry whose
        ``scheduler_cls`` is ``None``.
    """
    entries = []
    supported = [info for info in optim_db if info.optim_cls in KERNEL_COUNTS]

    for optim_info in supported:
        optim_cls = optim_info.optim_cls
        default_counts = KERNEL_COUNTS[optim_cls]

        for device in ["cpu", GPU_TYPE]:
            all_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
                device, None, optim_info, skip=("differentiable", "fused")
            )
            for optim_inputs in all_inputs:
                kwargs = dict(optim_inputs.kwargs)

                # Assemble the test name piece by piece, in kwargs order.
                pieces = [f"test_{optim_cls.__name__.lower()}"]
                has_tensor_lr = False
                for key, val in kwargs.items():
                    if key == "lr":
                        if isinstance(val, torch.Tensor):
                            has_tensor_lr = True
                            pieces.append("tensor_lr")
                    elif key == "betas":
                        if isinstance(val[0], torch.Tensor):
                            pieces.append("tensor_betas")
                    elif not isinstance(val, bool) or val:
                        # every other option shows up in the name unless it
                        # is a boolean flag that is switched off
                        pieces.append(key)
                pieces.append(f"{device}")
                name = "_".join(pieces)

                kwargs["device"] = device
                if name in KERNEL_COUNT_OVERRIDES:
                    kwargs["kernel_count"] = KERNEL_COUNT_OVERRIDES[name]
                elif kwargs.get("foreach", False) and device == GPU_TYPE:
                    kwargs["kernel_count"] = default_counts.multitensor
                else:
                    kwargs["kernel_count"] = default_counts.singletensor

                if kwargs["kernel_count"] is None or kwargs.get("fused", False):
                    continue

                if not has_tensor_lr:
                    entries.append((optim_cls, name, kwargs, None))
                    continue

                # tensor lr: pair the configuration with every scheduler
                for scheduler_cls in LR_SCHEDULER_TO_KWARGS:
                    scheduler_suffix = scheduler_cls.__name__.lower()
                    entries.append(
                        (
                            optim_cls,
                            f"{name}_{scheduler_suffix}",
                            kwargs,
                            scheduler_cls,
                        )
                    )

    return entries


# Built once at import time; iterated at the bottom of this file to attach
# one generated test per entry to CompiledOptimizerTests.
COMPILED_OPT_KWARG_DB = build_opt_kwarg_db()

# Shorthand alias for the torch.ops.aten namespace.
aten = torch.ops.aten


# check_model / check_model_gpu come from test_torchinductor: try the
# package-relative import first and fall back to a top-level import.
try:
    try:
        from .test_torchinductor import check_model, check_model_gpu
    except ImportError:
        from test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
            check_model,
            check_model_gpu,
        )
except (unittest.SkipTest, ImportError) as e:
    # Report why the helpers are unavailable. When this file is run as a
    # script, exit with status 0; when it is imported, re-raise.
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def call_scheduler(scheduler):
    """Advance ``scheduler`` by one step.

    ReduceLROnPlateau is stepped with a constant metric of 1.0; every other
    scheduler is stepped without arguments.
    """
    wants_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    # we won't reduce the metric over two iters anyway, so a fixed value is fine
    step_args = (1.0,) if wants_metric else ()
    scheduler.step(*step_args)


def compile_opt(opt_compiled, closure=None, fullgraph=True):
    """Return a compiled zero-argument callable that runs one optimizer step.

    Args:
        opt_compiled: optimizer instance whose ``step`` is compiled. As a side
            effect its ``_opt_called`` attribute is set to ``True``.
        closure: optional closure passed to ``step`` on every call.
        fullgraph: forwarded to ``torch.compile``.

    Returns:
        ``torch.compile(fn, backend="inductor", fullgraph=fullgraph)`` where
        ``fn`` calls the doubly-unwrapped ``step`` on ``opt_compiled``.
    """
    # run the patcher so that step has the expected structure
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # unwrap step TWICE to avoid a deliberate graph break due to
    # a limitation of functionalization/no_grad detection
    # see the [Note on graph break] in optimizer.py
    # This ignores the outer _use_grad_if_differentiable wrapper
    # and instead manually disables grad before calling step, which is fine
    # for now as dynamo does not support differentiable optimizers anyway
    step_fn = opt_compiled.step.__wrapped__.__wrapped__

    # This ensures we don't receive spam of warnings from LR Scheduler
    opt_compiled._opt_called = True

    # step_fn is the plain function, so the optimizer is passed explicitly
    # as the first argument.
    if closure is not None:

        def fn():
            step_fn(opt_compiled, closure)

    else:

        def fn():
            step_fn(opt_compiled)

    return torch.compile(fn, backend="inductor", fullgraph=fullgraph)


def check_optim(
    self,
    optim_cls,
    params_eager,
    params_compiled,
    state_eager,
    state_compiled,
    atol=None,
    rtol=None,
):
    """Assert that eager and compiled optimizer results match.

    Args:
        self: test case providing ``assertEqual(a, b, atol=..., rtol=...)``.
        optim_cls: optimizer class under test; selects the default tolerances.
        params_eager: iterable of parameters updated by the eager optimizer.
        params_compiled: iterable of parameters updated by the compiled
            optimizer, in the same order as ``params_eager``.
        state_eager: mapping from eager parameter to its optimizer state.
        state_compiled: mapping from compiled parameter to its optimizer state.
        atol: absolute tolerance; when ``None`` the per-optimizer default
            (possibly ``None``) is used.
        rtol: relative tolerance; when ``None`` the per-optimizer default
            (possibly ``None``) is used.
    """
    # Materialize once: the arguments may be one-shot iterators (for example
    # ``model.parameters()``) and are traversed twice below.
    params_eager = list(params_eager)
    params_compiled = list(params_compiled)

    default_atol = None
    default_rtol = None

    # Note on tolerances:
    # test_correctness_Adadelta_cuda_float32
    # Mismatched elements: 10 / 100 (10.0%)
    # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
    # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
    # This is due to floating point ordering error + usage of sqrt
    if optim_cls is Adadelta:
        default_rtol = 5.5e-4
        default_atol = 5e-5

    # inductor/test_compiled_optimizers.py::CompiledOptimizerTests::test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_cuda_lambdalr
    # Mismatched elements: 100 / 100 (100.0%)
    # Greatest absolute difference: 1.4960765838623047e-05 at index (2, 0) (up to 1e-05 allowed)
    # Greatest relative difference: 1.686977884673979e-05 at index (2, 0) (up to 1.3e-06 allowed)
    if optim_cls is NAdam:
        default_atol = 1.5e-5
        default_rtol = 1.7e-5

    # Tolerances passed by the caller win over the per-optimizer defaults
    # (previously they were discarded unconditionally).
    if atol is None:
        atol = default_atol
    if rtol is None:
        rtol = default_rtol

    self.assertEqual(params_eager, params_compiled, atol=atol, rtol=rtol)

    for p_eager, p_compiled in zip(params_eager, params_compiled):
        self.assertEqual(
            state_eager[p_eager],
            state_compiled[p_compiled],
            atol=atol,
            rtol=rtol,
        )


def make_test(
    optim_cls,
    closure=None,
    scheduler_cls=None,
    kernel_count=2,
    device="cuda",
    **kwargs,
):
    """Build a test method comparing a compiled optimizer step against eager.

    Args:
        optim_cls: optimizer class under test.
        closure: optional closure forwarded to the compiled step.
        scheduler_cls: optional LR scheduler class stepped after every
            optimizer step on both the eager and the compiled side.
        kernel_count: expected ``generated_kernel_count`` after the run.
        device: device the models and tensor hyperparameters are placed on.
        **kwargs: keyword arguments for the optimizer constructor.

    Returns:
        The test function, wrapped in ``requires_gpu`` when ``device`` is
        GPU_TYPE.
    """

    @config.patch("score_fusion_memory_threshold", 1)
    def test_fn(self):
        with ExitStack() as stack:
            # https://github.com/pytorch/pytorch/issues/118715 for capturable Adagrad support
            # https://github.com/pytorch/pytorch/issues/118018 for capturable SGD support
            run_cudagraphs = device == "cuda" and optim_cls not in (Adagrad, SGD)
            if run_cudagraphs:
                stack.enter_context(config.patch({"triton.cudagraphs": True}))

            # The compiled optimizer gets its own copy of the hyperparameters;
            # tensor-valued ones are moved to the test device on both sides.
            eager_opts = kwargs
            compiled_opts = deepcopy(kwargs)
            if isinstance(eager_opts.get("lr", None), torch.Tensor):
                for opts in (eager_opts, compiled_opts):
                    opts["lr"] = opts["lr"].to(device)

            if "betas" in eager_opts and isinstance(
                eager_opts["betas"][0], torch.Tensor
            ):
                for opts in (eager_opts, compiled_opts):
                    beta1, beta2 = opts["betas"][0], opts["betas"][1]
                    opts["betas"] = (beta1.to(device), beta2.to(device))

            torch._dynamo.reset()
            torch._inductor.metrics.reset()

            # Populate gradients on the eager model, then on a deep copy of it.
            ones = torch.ones([10, 10], device=device)
            model_eager = torch.nn.Sequential(
                *(torch.nn.Linear(10, 10, device=device) for _ in range(2))
            )
            model_eager(ones).sum().backward()

            ones = torch.ones([10, 10], device=device)
            model_compiled = deepcopy(model_eager)
            model_compiled(ones).sum().backward()

            opt_eager = optim_cls(model_eager.parameters(), **eager_opts)
            opt_compiled = optim_cls(model_compiled.parameters(), **compiled_opts)
            compiled_step = compile_opt(opt_compiled, closure=closure)

            if scheduler_cls:
                scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                # some schedulers only change after at least an epoch has passed
                for sched in (scheduler_compiled, scheduler_eager):
                    sched.last_epoch = 1

            with torch.set_grad_enabled(False):
                for _ in range(2):
                    compiled_step()
                    opt_eager.step()
                    if scheduler_cls:
                        call_scheduler(scheduler_eager)
                        call_scheduler(scheduler_compiled)

            check_optim(
                self,
                optim_cls,
                model_eager.parameters(),
                model_compiled.parameters(),
                opt_eager.state,
                opt_compiled.state,
            )

            if run_cudagraphs:
                self.check_cudagraphs_ran()

            if self.check_kernel_count:
                # currently, we compile the step and the rest of the computation
                # separately because the step is a single element tensor
                # hence, the usual kernel count is 2
                generated = torch._inductor.metrics.generated_kernel_count
                self.assertEqual(generated, kernel_count)

    return requires_gpu(test_fn) if device == GPU_TYPE else test_fn


def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
    """Build a GPU test checking that a state change triggers one recompile.

    Args:
        optim_cls: optimizer class under test.
        closure: accepted for signature parity; it is not forwarded to the
            compiled step.
        kernel_count: expected kernel count of a single compilation; the test
            asserts twice this value after the forced recompile.
        **kwargs: keyword arguments for the optimizer constructor.

    Returns:
        The test function, decorated with ``requires_gpu``.
    """

    @config.patch("score_fusion_memory_threshold", 1)
    @requires_gpu
    def test_fn(self):
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

        ones = torch.ones([10, 10], device=GPU_TYPE)
        layers = (torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2))
        model = torch.nn.Sequential(*layers)
        model(ones).sum().backward()

        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = compile_opt(opt_compiled)

        with torch.set_grad_enabled(False):
            # check no recompile here
            for _ in range(4):
                compiled_step()

            # perturb state to force recompile
            if optim_cls in (Adagrad, SGD):
                # Adagrad doesn't reinitialize state on each step
                # SGD has an empty state
                opt_compiled.param_groups[0]["lr"] = 0.02
            elif optim_cls is Adam:
                # ensure we are guarding on the data_ptr of states
                first_param = opt_compiled.param_groups[0]["params"][0]
                old_exp_avg = opt_compiled.state[first_param]["exp_avg"]
                opt_compiled.state[first_param]["exp_avg"] = torch.zeros_like(
                    old_exp_avg
                )
            else:
                opt_compiled.state.clear()

            compiled_step()

        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            # hence, the usual kernel count is 2
            # multiply by 2 to account for the recompile
            self.assertEqual(
                torch._inductor.metrics.generated_kernel_count,
                2 * kernel_count,
            )

    return test_fn


class CompiledOptimizerParityTests(TestCase):
    """Parity tests: ``torch.compile``'d optimizer steps must match eager."""

    @skipCUDAIf(not has_triton(), "torch.compile with cuda requires triton")
    @skipXPUIf(not has_triton(), "torch.compile with xpu requires triton")
    @optims(optim_db, dtypes=[torch.float32])
    @parametrize("use_closure", [True, False])
    def test_correctness(self, device, dtype, optim_info, use_closure):
        """Run eager and compiled optimizers side by side and compare results.

        For every input configuration of ``optim_info`` (and, when ``lr`` is a
        tensor, for every scheduler in LR_SCHEDULER_TO_KWARGS) two identical
        models are stepped ``num_steps`` times, one with the eager optimizer
        and one with a compiled step, and their parameters and optimizer
        state are compared with check_optim().
        """
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        # Nothing to test: this optimizer cannot step without a closure.
        if optim_info.step_requires_closure and not use_closure:
            return

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # Schedulers are only exercised for configurations with a tensor lr.
            use_scheduler = isinstance(kwargs.get("lr", None), torch.Tensor)
            scheduler_classes = (
                list(LR_SCHEDULER_TO_KWARGS.keys()) if use_scheduler else [None]
            )

            for scheduler_cls in scheduler_classes:
                torch._dynamo.reset()
                torch._inductor.metrics.reset()
                input = torch.ones([10, 10], device=device)
                model_eager = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
                )
                model_eager(input).sum().backward()
                model_compiled = deepcopy(model_eager)
                model_compiled(input).sum().backward()

                # SparseAdam is given sparse gradients on both models.
                if optim_cls is SparseAdam:
                    for param in model_eager.parameters():
                        param.grad = param.grad.to_sparse()
                    for param in model_compiled.parameters():
                        param.grad = param.grad.to_sparse()

                # Each optimizer gets its own deep copy of the kwargs.
                opt_compiled = optim_cls(
                    model_compiled.parameters(), **deepcopy(kwargs)
                )
                opt_eager = optim_cls(model_eager.parameters(), **deepcopy(kwargs))
                if scheduler_cls:
                    scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                    scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                    # some schedulers only change after at least an epoch has passed
                    scheduler_compiled.last_epoch = 1
                    scheduler_eager.last_epoch = 1

                num_steps = 2
                if use_closure:

                    # Compiled side: the closure recomputes loss and grads
                    # inside the compiled function.
                    @torch.compile()
                    def fn():
                        def closure():
                            loss = model_compiled(input).sum()
                            loss.backward()
                            if optim_info.only_supports_sparse_grads:
                                for param in model_compiled.parameters():
                                    param.grad = param.grad.to_sparse()
                            return loss

                        opt_compiled.step(closure)
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    # Eager side: the same closure against the eager model.
                    def closure_eager():
                        loss = model_eager(input).sum()
                        loss.backward()
                        if optim_info.only_supports_sparse_grads:
                            for param in model_eager.parameters():
                                param.grad = param.grad.to_sparse()

                        return loss

                    for _ in range(num_steps):
                        opt_eager.step(closure_eager)
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)
                else:

                    @torch.compile()
                    def fn():
                        opt_compiled.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    for _ in range(num_steps):
                        opt_eager.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)

                # Run the compiled side for the same number of steps.
                for _ in range(num_steps):
                    fn()

                check_optim(
                    self,
                    optim_cls,
                    model_eager.parameters(),
                    model_compiled.parameters(),
                    opt_eager.state,
                    opt_compiled.state,
                )


class CompiledOptimizerTests(TestCase):
    """Inductor tests for compiled optimizer steps.

    Besides the methods defined here, one test per COMPILED_OPT_KWARG_DB
    entry is attached to this class at the bottom of the file.
    """

    check_model_gpu = check_model_gpu
    check_model_cpu = check_model
    # When true, generated tests assert on the number of generated kernels.
    check_kernel_count = True

    def setUp(self):
        """Reset dynamo and the inductor metrics before every test."""
        super().setUp()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def tearDown(self):
        """Reset dynamo and the inductor metrics after every test."""
        super().tearDown()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def check_cudagraphs_ran(self):
        """Assert that the cudagraph tree manager of device 0 exists.

        Also asserts that the next graph id it hands out is 1.
        NOTE(review): presumably this means exactly one graph was recorded;
        confirm against cudagraph_trees.
        """
        # We run the zeroth device currently
        manager = torch._inductor.cudagraph_trees.get_container(0).tree_manager
        self.assertIsNotNone(manager)
        self.assertEqual(manager.new_graph_id().id, 1)

    test_adam_recompile = make_recompile_test(Adam, lr=0.01)
    test_adamw_recompile = make_recompile_test(AdamW, lr=0.01)
    test_adamax_recompile = make_recompile_test(Adamax, lr=0.01)
    test_nadam_recompile = make_recompile_test(NAdam, lr=0.01)
    test_rprop_recompile = make_recompile_test(Rprop, lr=0.01, kernel_count=2)
    test_rmsprop_recompile = make_recompile_test(RMSprop, lr=0.01)
    test_adadelta_recompile = make_recompile_test(Adadelta, lr=0.01)
    test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01)
    test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01)
    test_asgd_recompile_single = make_recompile_test(
        ASGD, kernel_count=8, lr=0.01, foreach=False
    )
    test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True)
    test_sgd_recompile_single = make_recompile_test(
        SGD, kernel_count=4, lr=0.01, foreach=False
    )
    test_sgd_recompile_foreach = make_recompile_test(
        SGD, kernel_count=1, lr=0.01, foreach=True
    )

    @requires_gpu
    def test_static_address_finalizer(self):
        """A parameter used by a compiled step is freed once it goes out of scope.

        The garbage collector is disabled for the duration of the check, so
        the weak reference can only die through reference counting.
        """
        import gc

        # Remember the collector state and always restore it: a failing
        # assertion must not leave GC disabled for the rest of the process.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            p_ref = None

            def fn():
                nonlocal p_ref
                mod = torch.nn.Linear(10, 10, device=GPU_TYPE, bias=False)
                for p in mod.parameters():
                    p.grad = torch.rand_like(p)

                opt = torch.optim.Adam(mod.parameters(), lr=0.1)

                def fn():
                    opt.step()

                with torch.set_grad_enabled(False):
                    step_fn_compiled = torch.compile(fn)
                    step_fn_compiled()
                p_ref = weakref.ref(p)
                # still alive while the module and optimizer are in scope
                self.assertTrue(p_ref() is not None)

            fn()

            # everything created in fn() is out of scope now
            self.assertTrue(p_ref() is None)
        finally:
            if gc_was_enabled:
                gc.enable()

    def test_guard_on_none_grads(self):
        """Step counts match eager when one iteration leaves all grads None."""

        def training_loop():
            input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(3, 2)

            model = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            )

            params = list(model.parameters())
            optimizer = torch.optim.Adam(params)
            step_list = []

            for i in range(6):
                optimizer.zero_grad()
                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 3:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                step_list.append(optimizer.state[params[0]]["step"])

            return step_list

        compiled_training_loop = torch._dynamo.optimize("eager")(training_loop)
        actual_steps = compiled_training_loop()
        expected_steps = training_loop()
        self.assertEqual(actual_steps, expected_steps)

    # Basic shampoo test to verify we support compiling the various ops without error
    @requires_gpu
    def test_basic_shampoo(self):
        """Compiled and eager runs of a simplified shampoo update must agree."""
        param_buf = torch.rand((1024, 128))
        param_buf_c = param_buf.detach().clone()

        # Parameters are transposed views into the two halves of one buffer.
        params_c = [param_buf_c[0:512, :].t(), param_buf_c[512:, :].t()]
        params = [param_buf[0:512, :].t(), param_buf[512:, :].t()]

        for p, p_c in zip(params, params_c):
            p.grad = torch.rand_like(p)
            p_c.grad = p.grad.detach().clone()

        # note this skips the root inverse because this has a lot of internal dependencies
        # we also don't compile it regardless
        @torch.no_grad()
        def shampoo_functional_basic(params):
            step = 1
            weight_decay = 0.1
            grads = [p.grad for p in params]
            beta1 = 0.9
            beta2 = 1.0
            epsilon = 1e-10
            preconditioners = [torch.zeros_like(p) for p in params]
            lr = 0.01

            # pt2 region 1
            # weight decay
            torch._foreach_add_(grads, params, alpha=weight_decay)

            # update preconditioners
            torch._foreach_addcmul_(preconditioners, grads, grads, value=1.0)

            torch._foreach_mul_(grads, beta1)
            torch._foreach_add_(
                grads,
                grads,
                alpha=1 - beta1,
            )
            bias_correction1 = 1.0 - beta1**step
            grad_list = torch._foreach_div(grads, bias_correction1)

            # pt2 region 2
            # precondition (with shampoo branch), with no grafting
            bias_correction2 = 1.0 - beta2**step
            bias_corrected_preconditioner_list = torch._foreach_div(
                preconditioners, bias_correction2
            )
            torch._foreach_sqrt_(bias_corrected_preconditioner_list)
            torch._foreach_add_(bias_corrected_preconditioner_list, epsilon)
            search_directions = torch._foreach_div(
                grad_list, bias_corrected_preconditioner_list
            )

            torch._foreach_add_(
                search_directions,
                params,
                alpha=weight_decay,
            )

            torch._foreach_mul_(search_directions, -lr)
            # pt2 region 3 update params
            torch._foreach_add_(params, search_directions)

            return params, preconditioners, grads

        compiled_fn = torch.compile(shampoo_functional_basic)

        self.assertEqual(compiled_fn(params_c), shampoo_functional_basic(params))

    @requires_gpu
    def test_closure_graph_break(self):
        """AdamW stepped with a closure matches between dynamo and plain Python."""
        param = torch.rand(
            2, 3, dtype=torch.float32, device=GPU_TYPE, requires_grad=True
        )
        param_c = param.detach().clone().requires_grad_(True)

        # Both closures install the same constant gradient on their parameter.
        def closure():
            param.grad = torch.ones_like(param) * 2
            return param.grad

        def closure_c():
            param_c.grad = torch.ones_like(param_c) * 2
            return param_c.grad

        optimizer = torch.optim.AdamW([param])
        optimizer_c = torch.optim.AdamW([param_c])

        def loop(opt, c):
            opt.step(c)

        compiled_loop = torch._dynamo.optimize("eager")(loop)

        compiled_loop(optimizer, closure)
        loop(optimizer_c, closure_c)

        self.assertEqual(param, param_c)

    def test_get_value_on_static_address(self):
        """Compiled ``_get_value`` returns a value equal to its static-address input."""
        from torch._dynamo.decorators import mark_static_address
        from torch.optim.optimizer import _get_value

        compiled = torch.compile(_get_value)

        x = torch.ones(2, 2)
        mark_static_address(x)

        ret_val = compiled(x)

        self.assertEqual(ret_val, x)

    # compile a large foreach op and verify
    # that the time taken is within an expected range
    @requires_gpu
    def test_compile_time_smoketest(self):
        """The first call of a compiled 100-tensor foreach add takes under 90 seconds."""
        import time

        xs = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]
        ys = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]

        @torch.compile
        def fn(xs, ys):
            return torch._foreach_add(xs, ys)

        # The timed call is the first one, so it includes compilation.
        start = time.perf_counter()
        fn(xs, ys)
        end = time.perf_counter()

        self.assertLess(end - start, 90)

    @requires_cuda
    def test_S429861(self):
        """Compile and run the ``s429861_repro.forward`` reproducer."""
        # Just verify we can compile this function without error
        try:
            from . import s429861_repro
        except ImportError:
            import s429861_repro  # @manual

        forward = s429861_repro.forward

        import torch._dynamo
        import torch._inductor
        from torch._dynamo.debug_utils import aot_graph_input_parser
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache():
            # Inputs are built by aot_graph_input_parser from ``forward``.
            kwargs = aot_graph_input_parser(forward)
            torch.compile(forward)(**kwargs)

# Attach one generated test per database entry to CompiledOptimizerTests.
# kwargs carries the optimizer options plus "device" and "kernel_count",
# which make_test() consumes as named parameters.
for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
    setattr(
        CompiledOptimizerTests,
        name,
        make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs),
    )

# Instantiate the device-specific variants of the parity tests into this
# module's globals (XPU allowed, CPU excluded).
instantiate_device_type_tests(
    CompiledOptimizerParityTests, globals(), allow_xpu=True, except_for="cpu"
)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only run the suite when HAS_CPU or HAS_GPU is set.
    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")