# Owner(s): ["module: optimizer"]

import warnings
import math
import unittest
import functools
import itertools
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim import SGD
from torch import sparse
from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \
    MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
    _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler, PolynomialLR, \
    EPOCH_DEPRECATION_WARNING
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
    parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


def rosenbrock(tensor):
    """Evaluate the two-variable Rosenbrock function.

    ``tensor`` must unpack into exactly two values ``(x, y)``. The result is
    ``(1 - x)^2 + 100 * (y - x^2)^2``, which is zero at ``(1, 1)``.
    """
    x, y = tensor
    offset_term = (1 - x) ** 2
    valley_term = 100 * (y - x ** 2) ** 2
    return offset_term + valley_term


def drosenbrock(tensor):
    """Return the analytic gradient of the Rosenbrock function as a new tensor.

    ``tensor`` must unpack into exactly two values ``(x, y)``; the result holds
    ``(d/dx, d/dy)`` in that order.
    """
    x, y = tensor
    # Both partial derivatives share the factor (y - x^2).
    residual = y - x ** 2
    d_dx = -400 * x * residual - 2 * (1 - x)
    d_dy = 200 * residual
    return torch.tensor((d_dx, d_dy))


class TestOptim(TestCase):
    exact_dtype = True

    def _test_rosenbrock_sparse(self, constructor, scheduler_constructors=None,
                                sparse_only=False, maximize=False):
        """Optimize the Rosenbrock function using uncoalesced sparse gradients.

        Starting from ``(1.5, 1.5)``, runs 2000 steps of cyclic coordinate
        descent: each step supplies only the x or only the y component of the
        gradient, encoded as an uncoalesced sparse tensor.

        Args:
            constructor: callable taking a list of parameters and returning the
                optimizer under test.
            scheduler_constructors: optional list of callables, each taking the
                optimizer and returning an LR scheduler that is stepped once
                per iteration.
            sparse_only: when False, a second optimizer built by ``constructor``
                is fed the densified form of the same gradients and its
                parameters must equal the sparse run's after every step.
            maximize: when True, the final objective must be at least the
                starting objective; otherwise the final distance to ``(1, 1)``
                must be at most the starting distance.
        """
        if scheduler_constructors is None:
            scheduler_constructors = []
        params_t = torch.tensor([1.5, 1.5])

        params = Parameter(params_t)
        optimizer = constructor([params])
        schedulers = [
            scheduler_constructor(optimizer)
            for scheduler_constructor in scheduler_constructors
        ]

        if not sparse_only:
            params_c = Parameter(params_t.clone())
            optimizer_c = constructor([params_c])

        solution = torch.tensor([1, 1])
        with torch.no_grad():
            initial_dist = params.dist(solution)
        # ``Parameter(params_t)`` shares storage with ``params_t``, so in-place
        # optimizer updates are visible through ``params_t`` as well. Record the
        # starting objective now; evaluating ``rosenbrock(params_t)`` after the
        # loop would just compare the optimized value with itself.
        initial_value = rosenbrock(params_t)

        def closure(params, sparse_grad, w):
            # Depending on w, provide only the x or y gradient
            optimizer.zero_grad()
            loss = rosenbrock(params)
            loss.backward()
            grad = drosenbrock(params.data)
            # NB: We torture test the optimizer by returning an
            # uncoalesced sparse tensor
            if w:
                i = torch.LongTensor([[0, 0]])
                x = grad[0]
                v = torch.tensor([x / 4., x - x / 4.])
            else:
                i = torch.LongTensor([[1, 1]])
                y = grad[1]
                v = torch.tensor([y - y / 4., y / 4.])
            x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype)
            with torch.no_grad():
                # Overwrite whatever autograd produced with the hand-built gradient.
                if sparse_grad:
                    params.grad = x
                else:
                    params.grad = x.to_dense()
            return loss

        for i in range(2000):
            # Do cyclic coordinate descent
            w = i % 2
            optimizer.step(functools.partial(closure, params, True, w))
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(rosenbrock(params))
                else:
                    scheduler.step()
            if not sparse_only:
                optimizer_c.step(functools.partial(closure, params_c, False, w))
                self.assertEqual(params, params_c)

        if not maximize:
            self.assertLessEqual(params.data.dist(solution), initial_dist)
        else:
            self.assertGreaterEqual(rosenbrock(params), initial_value)

    def _test_basic_cases_template(self, weight_tensor, bias_tensor, input_tensor, constructor,
                                   scheduler_constructors, constructor_accepts_maximize=True, constructor_accepts_foreach=False):
        maximize_options = set([False, constructor_accepts_maximize])
        foreach_options = set([False, constructor_accepts_foreach])

        four_arg_constructor = constructor
        if constructor_accepts_maximize and constructor_accepts_foreach:
            pass
        elif constructor_accepts_maximize:
            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(foreach)
                return constructor(weight, bias, maximize)
        elif constructor_accepts_foreach:
            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize)
                return constructor(weight, bias, foreach)
        else:
            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize or foreach)
                return constructor(weight, bias)

        for maximize, foreach in itertools.product(maximize_options, foreach_options):
            with torch.no_grad():
                weight = Parameter(weight_tensor.clone().detach())
                bias = Parameter(bias_tensor.clone().detach())
                input = input_tensor.clone().detach().requires_grad_()
            optimizer = four_arg_constructor(weight, bias, maximize, foreach)
            schedulers = []
            for scheduler_constructor in scheduler_constructors:
                schedulers.append(scheduler_constructor(optimizer))

            # to check if the optimizer can be printed as a string
            optimizer.__repr__()

            def fn():
                optimizer.zero_grad()
                y = weight.mv(input)
                if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                    y = y.cuda(bias.get_device())
                loss = (y + bias).pow(2).sum()
                loss.backward()
                return loss

            initial_value = fn().item()
            for _ in range(200):
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        val_loss = fn()
                        scheduler.step(val_loss)
                    else:
                        scheduler.step()
                optimizer.step(fn)
            if maximize:
                self.assertGreater(fn().item(), initial_value)
            else:
                self.assertLess(fn().item(), initial_value)

    def _test_state_dict(self, weight, bias, input, constructor):
        """Exercise ``state_dict()`` / ``load_state_dict()`` round-trips of an optimizer.

        Args:
            weight: tensor wrapped in a ``Parameter``; used as ``weight.mv(input)``.
            bias: tensor wrapped in a ``Parameter`` and added to that product.
            input: tensor multiplied by ``weight``; a detached clone is used.
            constructor: callable ``(weight, bias) -> optimizer``.

        NOTE(review): when CUDA is unavailable this method returns before the
        ``deepcopy`` public-attribute check at the very end, so that check only
        runs on machines with CUDA.
        """
        weight = Parameter(weight)
        bias = Parameter(bias)
        with torch.no_grad():
            input = input.clone().detach().requires_grad_()

        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            # ``input_cuda`` is only assigned further down, in the CUDA section;
            # the conditional expression reads it only when ``weight`` is on CUDA.
            i = input_cuda if weight.is_cuda else input
            loss = (weight.mv(i) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)
        # Clone the weights and construct new optimizer for them
        with torch.no_grad():
            weight_c = Parameter(weight.clone().detach())
            bias_c = Parameter(bias.clone().detach())
        optimizer_c = constructor(weight_c, bias_c)
        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
        # Load state dict
        # Two independent deep copies: one is loaded, the other is kept as a
        # reference to detect whether loading mutated its argument.
        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_c.load_state_dict(state_dict_c)
        # Run both optimizations in parallel
        for _ in range(20):
            optimizer.step(fn)
            optimizer_c.step(fn_c)
            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)
        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)
        # Make sure state dict is deterministic with equal but not identical parameters
        self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
        # Make sure repeated parameters have identical representation in state dict
        optimizer_c.param_groups.extend(optimizer_c.param_groups)
        self.assertEqual(optimizer.state_dict()['param_groups'][-1],
                         optimizer_c.state_dict()['param_groups'][-1])

        # Make sure that optimizers that support maximize can load older models
        state_dict = optimizer.state_dict()
        if 'maximize' in state_dict['param_groups'][0]:
            # Simulate a checkpoint that has no 'maximize' entry.
            for group in state_dict['param_groups']:
                del group['maximize']
            optimizer.load_state_dict(state_dict)
            # Make sure we can still step
            optimizer.step()
        # Make sure that optimizers that support foreach can load older models
        state_dict = optimizer.state_dict()
        if 'foreach' in state_dict['param_groups'][0]:
            # Simulate a checkpoint that has no 'foreach' entry.
            for group in state_dict['param_groups']:
                del group['foreach']
            optimizer.load_state_dict(state_dict)
            # Make sure we can still step
            optimizer.step()

        # Make sure that loading optimizers with step not wrapped in tensor can work
        state_dict = optimizer.state_dict()
        if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']):
            # Replace every tensor-valued step counter with a plain Python number.
            for state in state_dict['state'].values():
                state['step'] = state['step'].item()
            optimizer.load_state_dict(state_dict)
            optimizer.step()

        # Check that state dict can be loaded even when we cast parameters
        # to a different type and move to a different device.
        if not torch.cuda.is_available():
            return

        with torch.no_grad():
            input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda")
            weight_cuda = Parameter(weight.clone().detach().to(dtype=torch.float32, device="cuda"))
            bias_cuda = Parameter(bias.clone().detach().to(dtype=torch.float32, device="cuda"))
        optimizer_cuda = constructor(weight_cuda, bias_cuda)
        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)

        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_cuda.load_state_dict(state_dict_c)

        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)

        # Make sure that device of state['step'] is still CPU
        new_state_dict = optimizer_cuda.state_dict()
        if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']):
            for state in new_state_dict['state'].values():
                self.assertEqual(state['step'].device.type, 'cpu')

        # The CPU optimizer and the CUDA optimizer must stay in lockstep.
        for _i in range(20):
            optimizer.step(fn)
            optimizer_cuda.step(fn_cuda)
            self.assertEqual(weight, weight_cuda)
            self.assertEqual(bias, bias_cuda)

        # validate deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return set(k for k in obj.__dict__ if not k.startswith('_'))
        self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

    def _test_basic_cases(self, constructor, scheduler_constructors=None,
                          ignore_multidevice=False, constructor_accepts_maximize=False, constructor_accepts_foreach=False):
        if scheduler_constructors is None:
            scheduler_constructors = []

        def make_two_arg_constructor(constructor, maximize: bool = False, foreach: bool = False):
            if constructor_accepts_maximize and constructor_accepts_foreach:
                return lambda weight, bias: constructor(weight, bias, maximize, foreach)
            if constructor_accepts_maximize:
                return lambda weight, bias: constructor(weight, bias, maximize)
            if constructor_accepts_foreach:
                return lambda weight, bias: constructor(weight, bias, foreach)
            return constructor

        for maximize, foreach in itertools.product(
            set([False, constructor_accepts_maximize]),
            set([False, constructor_accepts_foreach]),
        ):
            self._test_state_dict(
                torch.randn(10, 5),
                torch.randn(10),
                torch.randn(5),
                make_two_arg_constructor(constructor, maximize, foreach),
            )
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # Multi-GPU
        if not torch.cuda.device_count() > 1 or ignore_multidevice:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )

    def _test_complex_optimizer(self, optimizer_constructor):
        complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
        real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
        complex_opt = optimizer_constructor(complex_param)
        real_opt = optimizer_constructor(real_param)

        for _ in range(3):
            complex_param.grad = torch.randn_like(complex_param)
            real_param.grad = torch.view_as_real(complex_param.grad)
            complex_opt.step()
            real_opt.step()

            self.assertEqual(torch.view_as_real(complex_param), real_param)

    def _test_complex_2d(self, optimizer_constructor, f=None):
        if f is None:
            f = rosenbrock
        a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
        a1_real = a1.real.clone().detach()
        a1_imag = a1.imag.clone().detach()
        a1_real.requires_grad_()
        a1_imag.requires_grad_()
        optim1 = optimizer_constructor([a1])
        optim2 = optimizer_constructor([a1_real, a1_imag])

        for _ in range(10):
            optim1.zero_grad()
            optim2.zero_grad()
            a2 = torch.complex(a1_real, a1_imag)
            f(a1).backward()
            f(a2).backward()

            self.assertEqual(a1.grad.real, a1_real.grad)
            self.assertEqual(a1.grad.imag, a1_imag.grad)

            optim1.step()
            optim2.step()
            self.assertEqual(a1.real, a1_real)
            self.assertEqual(a1.imag, a1_imag)

    def _build_params_dict(self, weight, bias, **kwargs):
        return [{'params': [weight]}, dict(params=[bias], **kwargs)]

    def _build_params_dict_single(self, weight, bias, **kwargs):
        return [dict(params=bias, **kwargs)]

    def test_sgd(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD(
                self._build_params_dict_single(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD(
                self._build_params_dict_single(weight, bias, lr=1e-2), maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach:
            optim.SGD([weight, bias], lr=1e-3, momentum=0.5, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach:
            optim.SGD([weight, bias], lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach:
            optim.SGD([weight, bias], nesterov=True, lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            constructor_accepts_maximize=True, constructor_accepts_foreach=True,
        )
        with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
            optim.SGD(None, lr=1e-2, momentum=-0.5)

    def test_sgd_sparse(self):
        for foreach in (False, True):
            self._test_rosenbrock_sparse(
                lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach)
            )
            self._test_rosenbrock_sparse(
                lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
                [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
            )

    def test_sgd_complex(self):
        for foreach in (False, True):
            self._test_complex_optimizer(
                lambda param: optim.SGD([param], lr=0.001, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: optim.SGD([param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: optim.SGD([param], lr=0.001, nesterov=True, momentum=1, weight_decay=1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: optim.SGD([param], lr=0.001, momentum=1, dampening=0.5, weight_decay=1, foreach=foreach)
            )

    def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
        if not torch.cuda.is_available():
            return
        assert flag in ("foreach", "fused")

        kIterations = 4
        device = 'cuda'
        for optimizer_constructor, params in optimizer_pairs_with_flags:
            res, state = [], []
            for foreach in (False, True):
                input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(torch.nn.Linear(2, 3),
                                            torch.nn.Sigmoid(),
                                            torch.nn.Linear(3, 1),
                                            torch.nn.Sigmoid())
                model.to(dtype=torch.float64, device=device)
                params_with_foreach = deepcopy(params)
                params_with_foreach["foreach"] = foreach
                optimizer = optimizer_constructor(model.parameters(), **params_with_foreach)

                for _ in range(kIterations):
                    optimizer.zero_grad()
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                    # Test that step behaves as expected (a no-op) when grads are set to None
                    if iter == 0:
                        optimizer.zero_grad(set_to_none=True)

                    optimizer.step()

                state.append(optimizer.state)
                res.append(model.parameters())

            st_state = state[0]
            mt_state = state[1]
            for st_p, mt_p in zip(res[0], res[1]):
                self.assertEqual(st_p, mt_p, atol=5e-5, rtol=0)

                # check that optimizer states are the same
                st_p_state = st_state[st_p]
                mt_p_state = mt_state[mt_p]

                for k in st_p_state:
                    actual = mt_p_state[k]
                    # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
                    # `step` Tensor is 1D while usually it's 0D.
                    if k == "step" and isinstance(actual, torch.Tensor) and actual.ndim == 1:
                        actual = actual[0]
                    self.assertEqual(st_p_state[k], actual, atol=5e-5, rtol=0)

    def test_multi_tensor_optimizers(self):
        optimizer_pairs_with_flags = [
            (optim.Adam, dict(weight_decay=1., amsgrad=True)),
            (optim.Adam, dict(weight_decay=1., amsgrad=False)),
            (optim.Adam, dict(weight_decay=0., amsgrad=True)),
            (optim.Adam, dict(weight_decay=0., amsgrad=False)),
            (optim.AdamW, dict(weight_decay=1., amsgrad=True)),
            (optim.AdamW, dict(weight_decay=1., amsgrad=False)),
            (optim.AdamW, dict(weight_decay=0., amsgrad=True)),
            (optim.AdamW, dict(weight_decay=0., amsgrad=False)),
            (optim.NAdam, dict(weight_decay=0., momentum_decay=6e-3)),
            (optim.NAdam, dict(weight_decay=1., momentum_decay=6e-3)),
            (optim.NAdam, dict(weight_decay=0., momentum_decay=4e-3)),
            (optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
            (optim.SGD, dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)),
            (optim.SGD, dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)),
            (optim.RAdam, dict(weight_decay=0)),
            (optim.RAdam, dict(weight_decay=1)),
            (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
            (optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
            (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
            (optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
            (optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
            (optim.ASGD, dict(weight_decay=0)),
            (optim.ASGD, dict(weight_decay=1)),
            (optim.Adamax, dict(weight_decay=0)),
            (optim.Adamax, dict(weight_decay=1)),
            (optim.Adadelta, dict(weight_decay=0)),
            (optim.Adadelta, dict(weight_decay=1)),
            (optim.Adagrad, dict(weight_decay=0)),
            (optim.Adagrad, dict(weight_decay=1)),
        ]
        self._test_derived_optimizers(optimizer_pairs_with_flags, "foreach")

    def test_fused_optimizers(self):
        optimizer_pairs_with_flags = [
            (optim.Adam, dict(weight_decay=1., amsgrad=False)),
            (optim.Adam, dict(weight_decay=1., amsgrad=True)),
            (optim.Adam, dict(weight_decay=0., amsgrad=False)),
            (optim.Adam, dict(weight_decay=0., amsgrad=True)),
        ]
        self._test_derived_optimizers(optimizer_pairs_with_flags, "fused")

    def test_adam(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                [weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach),
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
                lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach),
            [lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, maximize=maximize, foreach=foreach),
            [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_complex_2d(optim.Adam)
        self._test_complex_2d(functools.partial(optim.Adam, foreach=True))

        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
            optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))

        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
            optim.Adam(None, lr=1e-2, weight_decay=-1)

    def test_adamw(self):
        """Run AdamW through the basic-case and complex-2d helpers, then check validation.

        Four constructor variants are passed to ``_test_basic_cases``; the last
        statement checks that a negative ``weight_decay`` raises ``ValueError``.
        """
        shared_flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def plain(weight, bias, maximize, foreach):
            return optim.AdamW([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach)

        def grouped(weight, bias, maximize, foreach):
            # The params dict is built with lr=1e-2 while the optimizer default is 1e-3.
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.AdamW(param_groups, lr=1e-3, maximize=maximize, foreach=foreach)

        def decayed(weight, bias, maximize, foreach):
            return optim.AdamW(
                [weight, bias], lr=1e-3, weight_decay=1, maximize=maximize, foreach=foreach)

        def decayed_amsgrad(weight, bias, maximize, foreach):
            return optim.AdamW(
                [weight, bias], lr=1e-3, weight_decay=1, amsgrad=True,
                maximize=maximize, foreach=foreach)

        for build_optimizer in (plain, grouped, decayed, decayed_amsgrad):
            self._test_basic_cases(build_optimizer, **shared_flags)

        for complex_ctor in (optim.AdamW, functools.partial(optim.AdamW, foreach=True)):
            self._test_complex_2d(complex_ctor)

        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
            optim.AdamW(None, lr=1e-2, weight_decay=-1)

    def test_sparse_adam(self):
        """Run SparseAdam on the sparse Rosenbrock helper and check constructor validation."""
        # Sparse-only run with the default (minimizing) direction.
        self._test_rosenbrock_sparse(
            lambda params: optim.SparseAdam(params, lr=4e-2),
            scheduler_constructors=[],
            sparse_only=True,
        )
        # Same, but with maximize enabled on both the optimizer and the helper.
        self._test_rosenbrock_sparse(
            lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True),
            scheduler_constructors=[],
            sparse_only=True,
            maximize=True,
        )

        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
            optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))

        # A sparse-layout parameter must be rejected both as a bare list entry
        # and inside a param-group dict.
        dense_required = "SparseAdam requires dense parameter tensors"
        with self.assertRaisesRegex(ValueError, dense_required):
            optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)])
        with self.assertRaisesRegex(ValueError, dense_required):
            optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])

    # ROCm precision is too low to pass this test
    def test_adadelta(self):
        """Run Adadelta through ``_test_basic_cases`` and check ``rho`` validation.

        NOTE(review): the comment above this method mentions ROCm precision, but
        no ROCm skip decorator is applied here -- confirm whether that is intended.
        """
        # Handles https://github.com/pytorch/pytorch/issues/69698
        self.rel_tol = 4e-3
        flags = {
            "constructor_accepts_maximize": True,
            "constructor_accepts_foreach": True,
        }

        def default_adadelta(weight, bias, maximize, foreach):
            return optim.Adadelta([weight, bias], maximize=maximize, foreach=foreach)

        def grouped_rho(weight, bias, maximize, foreach):
            param_groups = self._build_params_dict(weight, bias, rho=0.95)
            return optim.Adadelta(param_groups, maximize=maximize, foreach=foreach)

        def with_weight_decay(weight, bias, maximize, foreach):
            return optim.Adadelta(
                [weight, bias], weight_decay=1, maximize=maximize, foreach=foreach)

        lr_schedulers = [
            lambda opt: StepLR(opt, gamma=0.9, step_size=10),
            lambda opt: ReduceLROnPlateau(opt),
        ]

        self._test_basic_cases(default_adadelta, **flags)
        self._test_basic_cases(grouped_rho, **flags)
        self._test_basic_cases(grouped_rho, lr_schedulers, **flags)
        self._test_basic_cases(with_weight_decay, **flags)

        with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
            optim.Adadelta(None, lr=1e-2, rho=1.1)

    def test_adadelta_complex(self):
        """Run Adadelta through ``_test_complex_optimizer`` with three hyper-parameter sets."""
        # Handles https://github.com/pytorch/pytorch/issues/69698
        self.rel_tol = 2e-2

        def adadelta_for(**hyperparams):
            # Returns a single-argument constructor with the hyper-parameters bound.
            return lambda weight: optim.Adadelta([weight], **hyperparams)

        hyperparam_sets = (
            {},
            {"rho": 0.95},
            {"rho": 0.95, "weight_decay": 1},
        )
        for hyperparams in hyperparam_sets:
            self._test_complex_optimizer(adadelta_for(**hyperparams))

    def test_nadam(self):
        """Run NAdam through ``_test_basic_cases`` and check beta / momentum_decay validation."""

        def plain(weight, bias, foreach):
            return optim.NAdam([weight, bias], lr=1e-3, foreach=foreach)

        def grouped(weight, bias, foreach):
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.NAdam(param_groups, lr=1e-3, foreach=foreach)

        def decayed(weight, bias, foreach):
            return optim.NAdam(
                [weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3, foreach=foreach)

        for build_optimizer in (plain, grouped, decayed):
            self._test_basic_cases(build_optimizer, constructor_accepts_foreach=True)

        # The decayed variant is run once more with an exponential LR schedule attached.
        self._test_basic_cases(
            decayed,
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_foreach=True,
        )

        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
            optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
        with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
            optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)

    def test_adagrad(self):
        """Run Adagrad through ``_test_basic_cases`` (with and without schedulers) and check lr_decay validation."""
        flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def plain(weight, bias, maximize, foreach):
            return optim.Adagrad([weight, bias], lr=1e-1, maximize=maximize, foreach=foreach)

        def with_initial_accumulator(weight, bias, maximize, foreach):
            return optim.Adagrad(
                [weight, bias], lr=1e-1, initial_accumulator_value=0.1,
                maximize=maximize, foreach=foreach,
            )

        def grouped(weight, bias, maximize, foreach):
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.Adagrad(param_groups, lr=1e-1, maximize=maximize, foreach=foreach)

        self._test_basic_cases(plain, **flags)
        self._test_basic_cases(with_initial_accumulator, **flags)
        self._test_basic_cases(grouped, **flags)

        # The grouped variant is repeated with one, then two, LR schedulers.
        self._test_basic_cases(
            grouped,
            [lambda opt: ReduceLROnPlateau(opt)],
            **flags,
        )
        self._test_basic_cases(
            grouped,
            [lambda opt: ReduceLROnPlateau(opt),
             lambda opt: ExponentialLR(opt, gamma=0.99)],
            **flags,
        )

        with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
            optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)

    def test_adagrad_sparse(self):
        """Run Adagrad on the sparse Rosenbrock helper for both ``foreach`` settings."""

        def adagrad_with(use_foreach):
            # Single-argument constructor with lr=0.1 and the foreach flag bound.
            return lambda params: optim.Adagrad(params, lr=1e-1, foreach=use_foreach)

        lr_schedulers = [
            lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
            lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
        ]

        for foreach in (False, True):
            self._test_rosenbrock_sparse(adagrad_with(foreach))
            self._test_rosenbrock_sparse(adagrad_with(foreach), lr_schedulers)

    def test_adagrad_complex(self):
        """Run Adagrad through ``_test_complex_optimizer`` for both ``foreach`` settings."""

        def adagrad_for(use_foreach, **hyperparams):
            return lambda param: optim.Adagrad(
                [param], lr=1e-1, foreach=use_foreach, **hyperparams)

        hyperparam_sets = ({}, {"initial_accumulator_value": 0.1})
        for foreach in (False, True):
            for hyperparams in hyperparam_sets:
                self._test_complex_optimizer(adagrad_for(foreach, **hyperparams))

    def test_adamax(self):
        """Run Adamax through the basic-case and complex-2d helpers, then check beta validation."""
        flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def plain(weight, bias, maximize, foreach):
            return optim.Adamax([weight, bias], lr=1e-1, maximize=maximize, foreach=foreach)

        def grouped(weight, bias, maximize, foreach):
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.Adamax(param_groups, lr=1e-1, maximize=maximize, foreach=foreach)

        def decayed(weight, bias, maximize, foreach):
            return optim.Adamax(
                [weight, bias], lr=1e-1, weight_decay=1, maximize=maximize, foreach=foreach)

        for build_optimizer in (plain, grouped, decayed):
            self._test_basic_cases(build_optimizer, **flags)

        for complex_ctor in (optim.Adamax, functools.partial(optim.Adamax, foreach=True)):
            self._test_complex_2d(complex_ctor)

        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
            optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))

    def test_radam(self):
        """Run RAdam through ``_test_basic_cases`` and check beta / weight_decay validation."""

        def plain(weight, bias, foreach):
            return optim.RAdam([weight, bias], lr=1e-3, foreach=foreach)

        def grouped(weight, bias, foreach):
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.RAdam(param_groups, lr=1e-3, foreach=foreach)

        def decayed(weight, bias, foreach):
            return optim.RAdam([weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach)

        for build_optimizer in (plain, grouped, decayed):
            self._test_basic_cases(build_optimizer, constructor_accepts_foreach=True)

        # The plain variant is run once more with two LR schedulers attached.
        self._test_basic_cases(
            plain,
            [lambda opt: ExponentialLR(opt, gamma=0.9), lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_foreach=True,
        )

        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
            optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0))

        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
            optim.RAdam(None, lr=1e-2, weight_decay=-1)

    def test_rmsprop(self):
        """Run RMSprop through the basic, complex-2d and complex helpers for both ``foreach`` settings.

        The basic-case constructors take their own ``foreach`` argument, so
        they do not depend on the loop variable; only the complex-parameter
        constructors and the validation check use it.
        """
        flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def flat(weight, bias, maximize, foreach):
            return optim.RMSprop([weight, bias], lr=1e-2, maximize=maximize, foreach=foreach)

        def grouped(**hyperparams):
            # Constructor over a params dict built with lr=1e-3 (optimizer default lr=1e-2).
            def construct(weight, bias, maximize, foreach):
                return optim.RMSprop(
                    self._build_params_dict(weight, bias, lr=1e-3),
                    lr=1e-2, maximize=maximize, foreach=foreach, **hyperparams)
            return construct

        basic_constructors = (
            flat,
            grouped(),
            grouped(centered=True),
            grouped(centered=True, momentum=0.1),
            grouped(momentum=0.1),
            grouped(momentum=0.1, weight_decay=1),
        )

        def on_param(use_foreach, **hyperparams):
            # Passes the helper's argument straight through as the params iterable.
            return lambda param: optim.RMSprop(param, foreach=use_foreach, **hyperparams)

        def on_param_list(use_foreach, **hyperparams):
            # Wraps the helper's argument in a single-element list.
            return lambda param: optim.RMSprop([param], foreach=use_foreach, **hyperparams)

        complex_variants = (
            {},
            {"centered": True},
            {"momentum": 0.1},
            {"maximize": True},
        )

        for foreach in (False, True):
            for build_optimizer in basic_constructors:
                self._test_basic_cases(build_optimizer, **flags)
            for hyperparams in complex_variants:
                self._test_complex_2d(on_param(foreach, **hyperparams))
            for hyperparams in complex_variants:
                self._test_complex_optimizer(on_param_list(foreach, **hyperparams))
            with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
                optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach)

    def test_asgd(self):
        """Run ASGD through the basic and complex helpers for both ``foreach`` settings."""
        flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def flat(weight, bias, maximize, foreach):
            return optim.ASGD(
                [weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach)

        def grouped(**hyperparams):
            # Constructor over a params dict built with lr=1e-2 (optimizer default lr=1e-3).
            def construct(weight, bias, maximize, foreach):
                return optim.ASGD(
                    self._build_params_dict(weight, bias, lr=1e-2),
                    lr=1e-3, maximize=maximize, foreach=foreach, **hyperparams)
            return construct

        basic_constructors = (flat, grouped(t0=100), grouped(weight_decay=1))

        def on_param_list(use_foreach, **hyperparams):
            return lambda params: optim.ASGD([params], foreach=use_foreach, **hyperparams)

        complex_variants = (
            {},
            {"maximize": True},
            {"maximize": True, "weight_decay": 0.9},
            {"maximize": False, "weight_decay": 0.9},
            {"weight_decay": 0.9},
        )

        for foreach in (False, True):
            for build_optimizer in basic_constructors:
                self._test_basic_cases(build_optimizer, **flags)
            # Ref: https://github.com/pytorch/pytorch/issues/84560
            # self._test_complex_2d(optimizer)
            for hyperparams in complex_variants:
                self._test_complex_optimizer(on_param_list(foreach, **hyperparams))
            with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
                optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)

    @skipIfRocm
    def test_rprop(self):
        """Run Rprop through the basic, complex-2d and complex helpers for both ``foreach`` settings."""
        flags = dict(
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        def flat(weight, bias, maximize, foreach):
            return optim.Rprop([weight, bias], lr=2e-4, maximize=maximize, foreach=foreach)

        def grouped(weight, bias, maximize, foreach):
            param_groups = self._build_params_dict(weight, bias, lr=1e-2)
            return optim.Rprop(param_groups, lr=2e-4, maximize=maximize, foreach=foreach)

        def on_param(use_foreach):
            return lambda param: optim.Rprop(param, foreach=use_foreach)

        def on_param_list(use_foreach, **hyperparams):
            return lambda param: optim.Rprop(
                [param], lr=0.001, foreach=use_foreach, **hyperparams)

        for foreach in (False, True):
            for build_optimizer in (flat, grouped):
                self._test_basic_cases(build_optimizer, **flags)
            self._test_complex_2d(on_param(foreach))
            for hyperparams in ({}, {"maximize": True}):
                self._test_complex_optimizer(on_param_list(foreach, **hyperparams))
            with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
                optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach)

    def test_lbfgs(self):
        """Run LBFGS through ``_test_basic_cases`` with and without strong-Wolfe line search."""

        def lbfgs_for(**hyperparams):
            return lambda weight, bias: optim.LBFGS([weight, bias], **hyperparams)

        for hyperparams in ({}, {"line_search_fn": "strong_wolfe"}):
            self._test_basic_cases(lbfgs_for(**hyperparams), ignore_multidevice=True)

    @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
    def test_lbfgs_return_type(self):
        """Check that ``LBFGS.step`` returns the same type for +inf and -inf ``tolerance_grad``."""
        tensors = [torch.randn(10, 5), torch.randn(10)]

        def constant_closure():
            return torch.tensor([10])

        # Both optimizers are built before either one is stepped.
        optimizers = [
            optim.LBFGS(tensors, 0.01, tolerance_grad=tolerance)
            for tolerance in (math.inf, -math.inf)
        ]
        first_result, second_result = [opt.step(constant_closure) for opt in optimizers]
        self.assertEqual(type(first_result), type(second_result))

    def test_invalid_param_type(self):
        """Passing a bare Parameter (not wrapped in an iterable) to SGD must raise TypeError."""
        lone_param = Parameter(torch.randn(5, 5))
        self.assertRaises(TypeError, optim.SGD, lone_param, lr=3)

    def test_duplicate_params_in_param_group(self):
        """Listing the same Parameter twice must emit exactly one duplicate-parameter warning."""
        shared = Parameter(torch.randn(5, 5))
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            optim.SGD([shared] * 2, lr=0.1)
            self.assertEqual(len(caught), 1)
            first_message = str(caught[0].message)
            self.assertIn('a parameter group with duplicate parameters', first_message)

    def test_no_grad_for_all_params(self):
        """``step()`` must run for each listed optimizer when no parameter has a gradient."""
        params = [torch.randn(5, 5, requires_grad=False) for _ in range(2)]

        optimizer_classes = (
            optim.Adadelta,
            optim.AdamW,
            optim.Adam,
            optim.Adagrad,
            optim.Adamax,
            optim.RMSprop,
            optim.SGD,
            optim.SparseAdam,
            optim.ASGD,
        )
        for optimizer_cls in optimizer_classes:
            # make sure step can still run even if
            # all params have no grad
            optimizer_cls(params, lr=0.1).step()

    # make sure that `state_steps` is correctly either updated or not updated when `found_inf`.
    def test_functional_fused_adam_with_foundinf(self):
        """Call the functional fused Adam with ``found_inf`` set to one and check ``state_steps``.

        Skipped when CUDA is unavailable. After the call, every entry of
        ``state_steps`` is asserted to still equal a tensor of ones, for both
        ``amsgrad`` settings.
        """
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required.")

        from torch.optim import adam

        num_tensors = 5

        def ones_list():
            return [torch.ones((1,), device="cuda") for _ in range(num_tensors)]

        def float_one():
            return torch.ones((1,), dtype=torch.float32, device="cuda")

        replicate = torch.cuda.amp.grad_scaler._MultiDeviceReplicator

        for amsgrad in (False, True):
            params = ones_list()
            grads = ones_list()
            exp_avgs = ones_list()
            exp_avg_sqs = ones_list()
            max_exp_avg_sqs = ones_list() if amsgrad else []
            state_steps = [float_one() for _ in range(num_tensors)]
            grad_scale = replicate(float_one())
            found_inf = replicate(float_one())

            hyperparams = dict(
                foreach=False,
                capturable=False,
                fused=True,
                amsgrad=amsgrad,
                beta1=0.9,
                beta2=0.99,
                lr=1e-2,
                weight_decay=.0,
                eps=1e-8,
                maximize=False,
                grad_scale=grad_scale,
                found_inf=found_inf,
            )
            adam.adam(
                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
                **hyperparams,
            )

            expected_steps = [float_one() for _ in range(num_tensors)]
            self.assertEqual(state_steps, expected_steps)


class SchedulerTestNet(torch.nn.Module):
    """Two stacked 1x1 single-channel convolutions with a ReLU in between."""

    def __init__(self):
        super().__init__()
        # conv1 is created before conv2.
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        self.conv2 = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x)))``."""
        hidden = self.conv1(x)
        activated = F.relu(hidden)
        return self.conv2(activated)


class LambdaLRTestObject:
    """Callable that multiplies a stored value by the epoch it is called with.

    Two instances compare equal when their attribute dictionaries are equal;
    comparison with any other type yields ``False``.
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, epoch):
        factor = self.value
        return factor * epoch

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__


class TestLRScheduler(TestCase):
    exact_dtype = True

    def setUp(self):
        """Create a fresh ``SchedulerTestNet`` and an SGD optimizer with two param groups.

        The first group (conv1) uses the optimizer default lr of 0.05; the
        second group (conv2) overrides it with 0.5.
        """
        super().setUp()
        self.net = SchedulerTestNet()
        param_groups = [
            {'params': self.net.conv1.parameters()},
            {'params': self.net.conv2.parameters(), 'lr': 0.5},
        ]
        self.opt = SGD(param_groups, lr=0.05)

    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
        """Assert that ``w`` holds exactly ``num_warnings`` epoch deprecation warnings.

        This swallows the deprecation warning produced by calling
        `scheduler.step(epoch)` with a non-`None` `epoch`. Passing `epoch` is
        deprecated, so this helper will need to be removed/updated once the
        schedulers no longer accept the parameter at all.
        """
        self.assertEqual(len(w), num_warnings)
        for recorded in w:
            message_args = recorded.message.args
            # Each warning must carry exactly one argument: the deprecation text.
            self.assertEqual(len(message_args), 1)
            self.assertEqual(message_args[0], EPOCH_DEPRECATION_WARNING)

    def test_error_when_getlr_has_epoch(self):
        """Constructing a scheduler whose ``get_lr`` requires an extra argument must raise TypeError."""

        class StepArgScheduler(torch.optim.lr_scheduler._LRScheduler):
            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
                self.init_lr = [group['lr'] for group in optimizer.param_groups]
                self.gamma = gamma
                self.milestones = milestones
                super().__init__(optimizer, last_epoch)

            # The required ``step`` parameter is the point of this test.
            def get_lr(self, step):
                global_step = self.last_epoch
                passed = [
                    index + 1
                    for index, milestone in enumerate(self.milestones)
                    if global_step >= milestone
                ]
                gamma_power = passed[-1] if passed else 0
                return [base * (self.gamma ** gamma_power) for base in self.init_lr]

        optimizer = torch.optim.SGD([torch.rand(1)], lr=1)

        with self.assertRaises(TypeError):
            StepArgScheduler(optimizer, gamma=1, milestones=[10, 20])

    def test_no_cyclic_references(self):
        """Check that creating and deleting a scheduler leaves no references to its optimizer.

        After the scheduler is deleted, the optimizer must have no gc-tracked
        referrers (apart from the current frame on Python < 3.7), and deleting
        the optimizer must leave nothing for the garbage collector to collect.
        """
        import gc
        param = Parameter(torch.empty(10))
        # NOTE: this local name shadows the module-level ``optim`` alias inside this method.
        optim = SGD([param], lr=0.5)
        scheduler = LambdaLR(optim, lambda epoch: 1.0)
        del scheduler

        # Prior to Python 3.7, local variables in a function will be referred by the current frame.
        import sys
        if sys.version_info < (3, 7):
            import inspect
            referrers = gc.get_referrers(optim)
            self.assertTrue(
                len(referrers) == 1 and referrers[0] is inspect.currentframe(),
                "Optimizer should contain no cyclic references (except current frame)")
            # Drop the list so it does not itself keep objects alive.
            del referrers
        else:
            self.assertTrue(
                len(gc.get_referrers(optim)) == 0,
                "Optimizer should contain no cyclic references")

        # Clear any pre-existing garbage first so the count checked below
        # only reflects what deleting the optimizer leaves behind.
        gc.collect()
        del optim
        self.assertEqual(
            gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__")

    def test_old_pattern_warning(self):
        """Stepping the scheduler before the optimizer must emit the how-to-adjust-learning-rate warning."""
        num_epochs = 35
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        with self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate'):
            for _ in range(num_epochs):
                # Old (discouraged) order: scheduler first, optimizer second.
                lr_schedule.step()
                self.opt.step()

    def test_old_pattern_warning_with_arg(self):
        """Scheduler-before-optimizer stepping must emit the how-to-adjust-learning-rate warning.

        NOTE(review): despite the name, no argument is passed to
        ``scheduler.step()`` here -- confirm whether that is intended.
        """
        num_epochs = 35
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        with self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate'):
            for _ in range(num_epochs):
                lr_schedule.step()
                self.opt.step()

    def test_old_pattern_warning_resuming(self):
        """Same as the old-pattern check, but with a scheduler created at ``last_epoch=10``."""
        num_epochs = 35
        # Every param group gets an 'initial_lr' entry before the scheduler is created.
        for group in self.opt.param_groups:
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1, last_epoch=10)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        with self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate'):
            for _ in range(num_epochs):
                lr_schedule.step()
                self.opt.step()

    def test_old_pattern_warning_resuming_with_arg(self):
        """Old-pattern warning check with a scheduler created at ``last_epoch=10``.

        NOTE(review): despite the name, no argument is passed to
        ``scheduler.step()`` here -- confirm whether that is intended.
        """
        num_epochs = 35
        for group in self.opt.param_groups:
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1, last_epoch=10)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        with self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate'):
            for _ in range(num_epochs):
                lr_schedule.step()
                self.opt.step()

    def test_old_pattern_warning_with_overridden_optim_step(self):
        """Old-pattern warning must still be emitted when ``optimizer.step`` is replaced by a wrapper."""
        num_epochs = 35
        for group in self.opt.param_groups:
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1, last_epoch=10)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden
        import types

        original_step = self.opt.step

        def forwarding_step(o, *args, **kwargs):
            # Thin wrapper: delegates to the step captured above and returns its result.
            return original_step(*args, **kwargs)

        self.opt.step = types.MethodType(forwarding_step, self.opt)

        with self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate'):
            for _ in range(num_epochs):
                lr_schedule.step()
                self.opt.step()

    def test_new_pattern_no_warning(self):
        """Stepping the optimizer before the scheduler must not record any warning."""
        num_epochs = 35
        no_warning_msg = "No warning should be raised"

        with warnings.catch_warnings(record=True) as caught_on_create:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1)
            self.assertTrue(len(caught_on_create) == 0, no_warning_msg)

        with warnings.catch_warnings(record=True) as caught_on_step:
            warnings.simplefilter("always")  # allow any warning to be raised
            for _ in range(num_epochs):
                # Recommended order: optimizer first, scheduler second.
                self.opt.step()
                lr_schedule.step()
            self.assertTrue(len(caught_on_step) == 0, no_warning_msg)

    def test_new_pattern_no_warning_with_arg(self):
        """Optimizer-then-scheduler stepping must not record any warning.

        NOTE(review): despite the name, no argument is passed to
        ``scheduler.step()`` here -- confirm whether that is intended.
        """
        num_epochs = 35
        no_warning_msg = "No warning should be raised"

        with warnings.catch_warnings(record=True) as caught_on_create:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1)
            self.assertTrue(len(caught_on_create) == 0, no_warning_msg)

        with warnings.catch_warnings(record=True) as caught_on_step:
            warnings.simplefilter("always")  # allow any warning to be raised
            for _ in range(num_epochs):
                self.opt.step()
                lr_schedule.step()
            self.assertTrue(len(caught_on_step) == 0, no_warning_msg)

    def test_new_pattern_no_warning_with_overridden_optim_step(self):
        """With ``optimizer.step`` replaced after scheduler creation, the 'has been overridden' warning is expected."""
        num_epochs = 35
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # allow any warning to be raised
            lr_schedule = StepLR(self.opt, step_size=3, gamma=0.1)
            self.assertTrue(len(caught) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden
        import types

        original_step = self.opt.step

        def forwarding_step(o, *args, **kwargs):
            # Thin wrapper: delegates to the step captured above and returns its result.
            return original_step(*args, **kwargs)

        self.opt.step = types.MethodType(forwarding_step, self.opt)

        with self.assertWarnsRegex(UserWarning, r'`optimizer.step\(\)` has been overridden'):
            for _ in range(num_epochs):
                self.opt.step()
                lr_schedule.step()

    def _test_lr_is_constant_for_constant_epoch(self, scheduler):
        """Step ``scheduler`` ten times with ``epoch=2`` and assert the first group's lr never changes.

        Each ``scheduler.step(2)`` call is also expected to emit exactly one
        epoch deprecation warning.
        """
        observed_lrs = []

        for _ in range(10):
            scheduler.optimizer.step()
            with warnings.catch_warnings(record=True) as caught:
                scheduler.step(2)
                self._check_warning_is_epoch_deprecation_warning(caught)

            first_group = self.opt.param_groups[0]
            observed_lrs.append(first_group['lr'])

        # Equal min and max means every recorded value was the same.
        self.assertEqual(min(observed_lrs), max(observed_lrs))

    def test_step_lr_is_constant_for_constant_epoch(self):
        """StepLR (step_size=2) must keep lr constant when stepped with a fixed epoch."""
        self._test_lr_is_constant_for_constant_epoch(StepLR(self.opt, step_size=2))

    def test_exponential_lr_is_constant_for_constant_epoch(self):
        """ExponentialLR (gamma=0.9) must keep lr constant when stepped with a fixed epoch."""
        exponential = ExponentialLR(self.opt, gamma=0.9)
        self._test_lr_is_constant_for_constant_epoch(scheduler=exponential)

    def test_constantlr_is_constant_for_constant_epoch(self):
        """ConstantLR with default arguments must keep lr constant when stepped with a fixed epoch."""
        self._test_lr_is_constant_for_constant_epoch(ConstantLR(self.opt))

    def test_linear_linearlr_is_constant_for_constant_epoch(self):
        """LinearLR with default arguments must keep lr constant when stepped with a fixed epoch."""
        self._test_lr_is_constant_for_constant_epoch(LinearLR(self.opt))

    def test_polynomial_lr_is_constant_for_constant_epoch(self):
        """PolynomialLR (power=0.9) must keep lr constant when stepped with a fixed epoch."""
        polynomial = PolynomialLR(self.opt, power=0.9)
        self._test_lr_is_constant_for_constant_epoch(scheduler=polynomial)

    def test_step_lr(self):
        """Check StepLR(gamma=0.1, step_size=3) against the expected per-epoch targets."""
        # Expected targets for the first param group:
        # lr = 0.05     if epoch < 3
        # lr = 0.005    if 3 <= epoch < 6
        # lr = 0.0005   if 6 <= epoch < 9
        # lr = 0.00005  if epoch >= 9
        epochs = 10
        plateaus = (0.05, 0.005, 0.0005, 0.00005)
        single_targets = [lr for lr in plateaus for _ in range(3)]
        # Second group's targets are the first group's scaled by ``epochs``.
        scaled_targets = [lr * epochs for lr in single_targets]
        self._test(
            StepLR(self.opt, gamma=0.1, step_size=3),
            [single_targets, scaled_targets],
            epochs,
        )

    def test_get_last_lr_step_lr(self):
        """Check ``get_last_lr`` for StepLR(step_size=3, gamma=0.1) on a single-group SGD optimizer."""
        epochs = 10
        weight = torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))
        optimizer = torch.optim.SGD([weight], 0.1)
        # Three epochs at each of 0.1, 0.01, 0.001, then one at 0.0001.
        stages = ((0.1, 3), (0.01, 3), (0.001, 3), (0.0001, 1))
        targets = [[lr for lr, repeats in stages for _ in range(repeats)]]
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_get_last_lr_multi_step_lr(self):
        """Check ``get_last_lr`` for MultiStepLR with milestones [2, 5, 9] and gamma=0.1."""
        # Expected targets for the first param group:
        # lr = 0.05     if epoch < 2
        # lr = 0.005    if 2 <= epoch < 5
        # lr = 0.0005   if 5 <= epoch < 9
        # lr = 0.00005  if 9 <= epoch
        epochs = 10
        stages = ((0.05, 2), (0.005, 3), (0.0005, 4), (0.00005, 1))
        single_targets = [lr for lr, repeats in stages for _ in range(repeats)]
        # Second group's targets are the first group's scaled by ``epochs``.
        scaled_targets = [lr * epochs for lr in single_targets]
        scheduler = MultiStepLR(self.opt, milestones=[2, 5, 9], gamma=0.1)
        self._test_get_last_lr(scheduler, [single_targets, scaled_targets], epochs)

    def test_multi_step_lr(self):
        # Expected schedule for milestones [2, 5, 9] with gamma=0.1:
        # lr = 0.05      if epoch < 2
        # lr = 0.005     if 2 <= epoch < 5
        # lr = 0.0005    if 5 <= epoch < 9
        # lr = 0.00005   if epoch >= 9
        # The second target list is the first one scaled by `epochs` (10).
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test(scheduler, targets, epochs)

    def test_multi_step_lr_with_epoch(self):
        # Same expectations as the plain MultiStepLR test, checked through
        # self._test_with_epoch instead of self._test:
        # lr = 0.05      if epoch < 2
        # lr = 0.005     if 2 <= epoch < 5
        # lr = 0.0005    if 5 <= epoch < 9
        # lr = 0.00005   if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_with_epoch(scheduler, targets, epochs)

    def test_get_last_lr_constantlr(self):
        # Expected schedule for ConstantLR(factor=1/2, total_iters=5):
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        # The second target list is the first one scaled by `epochs` (10).
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_get_last_lr_linearlr(self):
        # Expected schedule for start_factor=1/4, end_factor=3/5, total_iters=4,
        # built from a base value of 0.05:
        # lr = 0.0125     if epoch == 0
        # lr = 0.016875   if epoch == 1
        # lr = 0.02125    if epoch == 2
        # lr = 0.025625   if epoch == 3
        # lr = 0.03       if 4 <= epoch
        # The second target list is the first one scaled by `epochs` (10).
        epochs = 10
        start_factor = 1.0 / 4
        end_factor = 3. / 5
        iters = 4
        interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)]
        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_constantlr(self):
        # Expected schedule for ConstantLR(factor=1/2, total_iters=5):
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        # The second target list is the first one scaled by `epochs` (10).
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test(scheduler, targets, epochs)

    def test_linearlr(self):
        # Expected schedule for start_factor=1/2, total_iters=4 (end factor 1):
        # lr = 0.025     if epoch == 0
        # lr = 0.03125   if epoch == 1
        # lr = 0.0375    if epoch == 2
        # lr = 0.04375   if epoch == 3
        # lr = 0.05      if 4 <= epoch
        # The second target list is the first one scaled by `epochs` (10).
        epochs = 10
        start_factor = 1.0 / 2
        iters = 4
        interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)]
        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test(scheduler, targets, epochs)

    def test_constantlr_with_epoch(self):
        # Same expectations as the plain ConstantLR test, checked through
        # self._test_with_epoch instead of self._test:
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test_with_epoch(scheduler, targets, epochs)

    def test_linearlr_with_epoch(self):
        # Expected schedule for start_factor=1/2, total_iters=4, checked through
        # self._test_with_epoch:
        # lr = 0.025     if epoch == 0
        # lr = 0.03125   if epoch == 1
        # lr = 0.0375    if epoch == 2
        # lr = 0.04375   if epoch == 3
        # lr = 0.05      if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 2
        # Only used to build the targets; the scheduler is left on its own
        # end_factor default.
        end_factor = 1.
        iters = 4
        interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)]
        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test_with_epoch(scheduler, targets, epochs)

    def test_exp_lr(self):
        # Expected values decay geometrically from 0.05 with ratio 0.9 per epoch;
        # the second target list is the first one scaled by `epochs`.
        epochs = 10
        decay_curve = [0.05 * (0.9 ** x) for x in range(epochs)]
        scaled_curve = [x * epochs for x in decay_curve]
        self._test(ExponentialLR(self.opt, gamma=0.9),
                   [decay_curve, scaled_curve],
                   epochs)

    def test_poly_lr(self):
        # Polynomial decay from 0.05 over `total_iters` steps, then 0.0 for the
        # remaining epochs; the second target list is the first scaled by `epochs`.
        epochs = 10
        total_iters = 5
        power = 0.9
        decaying = [(1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)]
        flat_tail = [0.0] * (epochs - total_iters)
        first_group = decaying + flat_tail
        second_group = [x * epochs for x in first_group]
        poly = PolynomialLR(self.opt, power=power, total_iters=total_iters)
        self._test(poly, [first_group, second_group], epochs)

    def test_cos_anneal_lr(self):
        # Expected values follow a half cosine from 0.05 towards eta_min over
        # T_max == epochs steps; the second list is the first scaled by `epochs`.
        eta_min = 1e-10
        epochs = 10
        cosine_curve = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        expected = [cosine_curve, [x * epochs for x in cosine_curve]]
        annealer = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        self._test(annealer, expected, epochs)

    def test_closed_form_step_lr(self):
        # Two identically configured StepLR instances on the same optimizer are
        # handed to the closed-form comparison helper for 20 epochs.
        config = dict(gamma=0.1, step_size=3)
        stepped = StepLR(self.opt, **config)
        closed_form = StepLR(self.opt, **config)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_closed_form_linearlr(self):
        # Two identically configured LinearLR instances on the same optimizer are
        # handed to the closed-form comparison helper for 20 epochs.
        config = dict(start_factor=1.0 / 3, end_factor=0.7, total_iters=4)
        stepped = LinearLR(self.opt, **config)
        closed_form = LinearLR(self.opt, **config)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_closed_form_constantlr(self):
        # Two identically configured ConstantLR instances on the same optimizer
        # are handed to the closed-form comparison helper for 20 epochs.
        config = dict(factor=1.0 / 3, total_iters=4)
        stepped = ConstantLR(self.opt, **config)
        closed_form = ConstantLR(self.opt, **config)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_closed_form_multi_step_lr(self):
        # Two identically configured MultiStepLR instances on the same optimizer
        # are handed to the closed-form comparison helper for 20 epochs.
        stepped = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        closed_form = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        total_epochs = 20
        self._test_against_closed_form(stepped, closed_form, total_epochs)

    def test_closed_form_exp_lr(self):
        # Two identically configured ExponentialLR instances on the same optimizer
        # are handed to the closed-form comparison helper for 20 epochs.
        gamma = 0.9
        stepped = ExponentialLR(self.opt, gamma=gamma)
        closed_form = ExponentialLR(self.opt, gamma=gamma)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_closed_form_poly_lr(self):
        # Two identically configured PolynomialLR instances on the same optimizer
        # are handed to the closed-form comparison helper for 20 epochs.
        power = 0.9
        stepped = PolynomialLR(self.opt, power=power)
        closed_form = PolynomialLR(self.opt, power=power)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_closed_form_cos_anneal_lr(self):
        # Two identically configured CosineAnnealingLR instances (T_max=5) are
        # compared over 20 epochs, i.e. beyond a single T_max span.
        config = dict(T_max=5, eta_min=1e-10)
        stepped = CosineAnnealingLR(self.opt, **config)
        closed_form = CosineAnnealingLR(self.opt, **config)
        self._test_against_closed_form(stepped, closed_form, 20)

    def test_cos_anneal_lr_continue(self):
        # A CosineAnnealingLR created with last_epoch=0 on an optimizer that has
        # already been stepped once must report the same learning rates as the
        # original scheduler after its first step.
        eta_min = 0.1
        T_max = 5
        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
        self.opt.step()
        scheduler.step()
        original_lrs = scheduler._last_lr
        new_scheduler = CosineAnnealingLR(
            self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0)
        new_lrs = new_scheduler._last_lr
        # assert_close supersedes the deprecated torch.testing.assert_allclose;
        # rtol and atol are both given explicitly, as assert_close requires.
        torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5)

    def test_reduce_lr_on_plateau1(self):
        # All groups start at lr=0.5; the single target list stays at 0.5 while
        # the metric shrinks by 0.0167 per epoch ('min' mode, 'abs' threshold 0.01).
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 10
        expected = [[0.5] * 20]
        losses = [10 - i * 0.0167 for i in range(20)]
        plateau = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
                                    threshold=0.01, patience=5, cooldown=5)
        self._test_reduce_lr_on_plateau(plateau, expected, losses, epochs)

    def test_reduce_lr_on_plateau2(self):
        # All groups start at lr=0.5; the metric shrinks by 0.0165 per epoch under
        # an 'abs' threshold of 0.1, and the targets drop tenfold three times.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 22
        expected = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
        losses = [10 - i * 0.0165 for i in range(22)]
        plateau = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                                    mode='min', threshold=0.1)
        self._test_reduce_lr_on_plateau(plateau, expected, losses, epochs)

    def test_reduce_lr_on_plateau3(self):
        # 'max' mode with patience=5 and cooldown=5: the metric rises once
        # (-0.8 -> -0.234) and is flat afterwards.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 22
        expected = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
        scores = [-0.8] * 2 + [-0.234] * 20
        plateau = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
                                    threshold_mode='abs')
        self._test_reduce_lr_on_plateau(plateau, expected, scores, epochs)

    def test_reduce_lr_on_plateau4(self):
        # 'max' mode with a relative threshold of 0.1 and patience=3; the metric
        # grows by 2.5% per epoch and the target lr stays at 0.5 throughout.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 20
        expected = [[0.5] * 20]
        scores = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25
        plateau = ReduceLROnPlateau(self.opt, mode='max', patience=3,
                                    threshold_mode='rel', threshold=0.1)
        self._test_reduce_lr_on_plateau(plateau, expected, scores, epochs)

    def test_reduce_lr_on_plateau5(self):
        # 'max' mode with a relative threshold of 0.1, patience=5, cooldown=5;
        # the metric grows by only 0.5% per epoch and the targets drop twice.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 20
        expected = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
        scores = [1.5 * (1.005 ** i) for i in range(20)]
        plateau = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel',
                                    threshold=0.1, patience=5, cooldown=5)
        self._test_reduce_lr_on_plateau(plateau, expected, scores, epochs)

    def test_reduce_lr_on_plateau6(self):
        # 'min' mode with a relative threshold of 0.1; the metric shrinks by 15%
        # per epoch and the target lr stays at 0.5 throughout.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 20
        expected = [[0.5] * 20]
        losses = [1.5 * (0.85 ** i) for i in range(20)]
        plateau = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
                                    threshold=0.1)
        self._test_reduce_lr_on_plateau(plateau, expected, losses, epochs)

    def test_reduce_lr_on_plateau7(self):
        # 'min' mode with a relative threshold of 0.1, patience=5, cooldown=5;
        # the metric is 1 for seven epochs, then 0.6 once, then 0.5.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 20
        expected = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
        losses = [1] * 7 + [0.6] + [0.5] * 12
        plateau = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
                                    threshold=0.1, patience=5, cooldown=5)
        self._test_reduce_lr_on_plateau(plateau, expected, losses, epochs)

    def test_reduce_lr_on_plateau8(self):
        # Per-group min_lr of [0.4, 0.3]: after six epochs at 0.5 the targets sit
        # at those two floor values for the remaining epochs.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 20
        warmup = [0.5] * 6
        expected = [warmup + [0.4] * 14, warmup + [0.3] * 14]
        scores = [1.5 * (1.005 ** i) for i in range(20)]
        plateau = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3],
                                    threshold=0.1, patience=5, cooldown=5)
        self._test_reduce_lr_on_plateau(plateau, expected, scores, epochs)

    def test_sequentiallr1(self):
        # ExponentialLR (gamma=0.8) for the first 3 epochs, then StepLR
        # (step_size=4, gamma=0.1) after the milestone.
        epochs = 19
        exp_phase = [0.05, 0.04, 0.032]
        step_phase = ([0.05] * 4
                      + [0.05 * 0.1] * 4
                      + [0.05 * 0.01] * 4
                      + [0.05 * 0.001] * 4)
        expected = [exp_phase + step_phase]
        stages = [
            ExponentialLR(self.opt, gamma=0.8),
            StepLR(self.opt, gamma=0.1, step_size=4),
        ]
        sequence = SequentialLR(self.opt, schedulers=stages, milestones=[3])
        self._test(sequence, expected, epochs)

    def test_sequentiallr2(self):
        # ConstantLR (factor=0.1) for the first 3 epochs, then ExponentialLR
        # (gamma=0.9) after the milestone.
        epochs = 13
        constant_phase = [0.005, 0.005, 0.005]
        exp_phase = [0.05 * 0.9 ** x for x in range(10)]
        expected = [constant_phase + exp_phase]
        stages = [
            ConstantLR(self.opt, factor=0.1, total_iters=3),
            ExponentialLR(self.opt, gamma=0.9),
        ]
        sequence = SequentialLR(self.opt, schedulers=stages, milestones=[3])
        self._test(sequence, expected, epochs)

    def test_sequentiallr3(self):
        # Three stages split at milestones [3, 6]: ConstantLR, ExponentialLR,
        # then StepLR.
        epochs = 12
        constant_phase = [0.005, 0.005, 0.005]
        exp_phase = [0.05, 0.04, 0.032]
        step_phase = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
        expected = [constant_phase + exp_phase + step_phase]
        stages = [
            ConstantLR(self.opt, factor=0.1, total_iters=3),
            ExponentialLR(self.opt, gamma=0.8),
            StepLR(self.opt, gamma=0.1, step_size=2),
        ]
        sequence = SequentialLR(self.opt, schedulers=stages, milestones=[3, 6])
        self._test(sequence, expected, epochs)

    def test_sequentiallr4(self):
        # Building a SequentialLR from several schedulers must leave the
        # optimizer's initial learning rate untouched.
        optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)
        lr_before = optimizer.param_groups[0]["lr"]

        first = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1)
        second = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1)
        schedulers = [first, second]
        # The sequence object itself is not inspected; only its construction matters.
        scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=[10])

        lr_after = optimizer.param_groups[0]["lr"]
        self.assertEqual(lr_before, lr_after)

    def test_get_last_lr_sequentiallr(self):
        # Three-stage SequentialLR (ConstantLR -> ExponentialLR -> StepLR) checked
        # through the get_last_lr helper; the second target list is 10x the first.
        epochs = 12
        stages = [
            ConstantLR(self.opt, factor=0.1, total_iters=3),
            ExponentialLR(self.opt, gamma=0.8),
            StepLR(self.opt, gamma=0.1, step_size=2),
        ]
        sequence = SequentialLR(self.opt, schedulers=stages, milestones=[3, 6])
        first_group = (
            [0.005] * 3
            + [0.05, 0.04, 0.032]
            + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
        )
        second_group = [x * 10 for x in first_group]
        self._test_get_last_lr(sequence, [first_group, second_group], epochs)

    def test_chained_lr2_get_last_lr_before_step(self):
        # Before any step, the chain's get_last_lr() must equal that of the last
        # scheduler in the chain.
        warmup = LinearLR(self.opt, start_factor=0.4, total_iters=3)
        decay = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
        chain = ChainedScheduler([warmup, decay])
        self.assertEqual(chain.get_last_lr(), decay.get_last_lr())

    def test_chained_lr1(self):
        # A chain holding a single StepLR (step_size=3, gamma=0.1).
        epochs = 10
        expected = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
        members = [StepLR(self.opt, gamma=0.1, step_size=3)]
        chain = ChainedScheduler(members)
        self._test([chain], expected, epochs)
        # The chain reports the same last lr as its final member.
        self.assertEqual(chain.get_last_lr(), members[-1].get_last_lr())

    def test_chained_lr2(self):
        # A chain holding a single LinearLR (start_factor=0.4, total_iters=3).
        epochs = 10
        expected = [[0.02, 0.03, 0.04] + [0.05] * 9]
        members = [LinearLR(self.opt, start_factor=0.4, total_iters=3)]
        chain = ChainedScheduler(members)
        self._test([chain], expected, epochs)
        # The chain reports the same last lr as its final member.
        self.assertEqual(chain.get_last_lr(), members[-1].get_last_lr())

    def test_chained_lr3(self):
        # LinearLR warm-up chained with MultiStepLR (milestones [4, 8, 10]).
        epochs = 10
        expected = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3]
        members = [
            LinearLR(self.opt, start_factor=0.4, total_iters=3),
            MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1),
        ]
        chain = ChainedScheduler(members)
        self._test([chain], expected, epochs)
        # The chain reports the same last lr as its final member.
        self.assertEqual(chain.get_last_lr(), members[-1].get_last_lr())

    def test_chained_lr4(self):
        # ExponentialLR, ConstantLR (factor 0.2 for 4 iters) and StepLR chained;
        # the expected list is assembled in four segments.
        epochs = 9
        seg_a = [0.05 * 0.2 * 0.9 ** x for x in range(3)]
        seg_b = [0.05 * 0.2 * 0.9 ** 3 * 0.1]
        seg_c = [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)]
        seg_d = [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]
        expected = [seg_a + seg_b + seg_c + seg_d]
        members = [
            ExponentialLR(self.opt, gamma=0.9),
            ConstantLR(self.opt, factor=0.2, total_iters=4),
            StepLR(self.opt, gamma=0.1, step_size=3),
        ]
        chain = ChainedScheduler(members)
        self._test([chain], expected, epochs)
        # The chain reports the same last lr as its final member.
        self.assertEqual(chain.get_last_lr(), members[-1].get_last_lr())

    def test_chained_lr5(self):
        # PolynomialLR chained with a ConstantLR factor; expected values are the
        # polynomial curve (for base values 0.05 and 0.5) scaled by that factor.
        epochs = 10
        power = 0.9
        total_iters = 5
        const_factor = 0.1

        def poly_lr(lr: float):
            # Polynomial decay for `total_iters` steps, then zeros up to `epochs`.
            decaying = [
                (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters)
            ]
            return decaying + [0.0] * (epochs - total_iters)

        first_group = [x * const_factor for x in poly_lr(lr=0.05)]
        second_group = [x * const_factor for x in poly_lr(0.5)]
        members = [
            PolynomialLR(self.opt, power=power, total_iters=total_iters),
            ConstantLR(self.opt, factor=const_factor),
        ]
        chain = ChainedScheduler(members)
        self._test(chain, [first_group, second_group], epochs)
        # The chain reports the same last lr as its final member.
        self.assertEqual(chain.get_last_lr(), members[-1].get_last_lr())

    def test_compound_step_and_multistep_lr(self):
        # StepLR and MultiStepLR applied together to the same optimizer; their
        # tenfold decays compound in the expected values.
        epochs = 10
        pair = [
            StepLR(self.opt, gamma=0.1, step_size=3),
            MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
        ]
        expected = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
        self._test(pair, expected, epochs)

    def test_compound_step_and_exp_lr(self):
        # StepLR (step_size=3) combined with ExponentialLR (gamma=0.9): every
        # block of three epochs has a 10x smaller prefactor on the 0.9**x curve.
        epochs = 10
        first_group = (
            [0.05 * (0.9 ** x) for x in range(3)]
            + [0.005 * (0.9 ** x) for x in range(3, 6)]
            + [0.0005 * (0.9 ** x) for x in range(6, 9)]
            + [0.00005 * (0.9 ** x) for x in range(9, 12)]
        )
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            StepLR(self.opt, gamma=0.1, step_size=3),
            ExponentialLR(self.opt, gamma=0.9),
        ]
        self._test(pair, expected, epochs)

    def test_compound_exp_and_multistep_lr(self):
        # MultiStepLR (milestones [2, 5, 9]) combined with ExponentialLR
        # (gamma=0.9): the prefactor drops 10x at each milestone.
        epochs = 10
        first_group = (
            [0.05 * (0.9 ** x) for x in range(2)]
            + [0.005 * (0.9 ** x) for x in range(2, 5)]
            + [0.0005 * (0.9 ** x) for x in range(5, 9)]
            + [0.00005 * (0.9 ** x) for x in range(9, 11)]
        )
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
            ExponentialLR(self.opt, gamma=0.9),
        ]
        self._test(pair, expected, epochs)

    def test_compound_exp_and_linearlr(self):
        # LinearLR (0.4 -> 0.9 over 4 iters) combined with ExponentialLR
        # (gamma=0.9): the exponential curve is scaled by the linear factor.
        epochs = 10
        iters = 4
        start_factor = 0.4
        end_factor = 0.9
        first_group = [0.05 * (0.9 ** x) for x in range(11)]
        for i in range(11):
            if i < iters:
                # Still inside the linear ramp.
                first_group[i] *= start_factor + i / iters * (end_factor - start_factor)
            else:
                # Ramp finished: constant end factor.
                first_group[i] *= end_factor
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters),
            ExponentialLR(self.opt, gamma=0.9),
        ]
        self._test(pair, expected, epochs)

    def test_compound_step_and_constantlr(self):
        # StepLR (step_size=3, gamma=0.1) combined with ConstantLR, which scales
        # the lr by `factor` for the first `iters` epochs. The expected values and
        # the scheduler share the same `factor` / `iters` locals.
        epochs = 10
        iters = 4
        factor = 0.4
        single_targets = [0.05 * factor] * 3 + [0.005 * factor] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3
        # The second target list is the first one scaled by `epochs`.
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers = [
            StepLR(self.opt, gamma=0.1, step_size=3),
            ConstantLR(self.opt, factor=factor, total_iters=iters),
        ]
        self._test(schedulers, targets, epochs)

    def test_compound_linearlr_and_multistep_lr(self):
        # MultiStepLR (milestones [2, 5, 9]) combined with a LinearLR warm-up
        # (start_factor=0.4 over 4 iters) that scales the first four targets.
        epochs = 10
        iters = 4
        start_factor = 0.4
        first_group = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
        for step in range(iters):
            first_group[step] *= start_factor + step / iters * (1 - start_factor)
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
            LinearLR(self.opt, start_factor=start_factor, total_iters=iters),
        ]
        self._test(pair, expected, epochs)

    def test_compound_cosanneal_and_step_lr(self):
        # Cosine annealing combined with StepLR (step_size=3): the cosine curve
        # is scaled by 0.1 ** (epoch // 3).
        epochs = 10
        eta_min = 1e-10
        cosine_curve = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        first_group = [x * 0.1 ** (i // 3) for i, x in enumerate(cosine_curve)]
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            StepLR(self.opt, gamma=0.1, step_size=3),
        ]
        self._test(pair, expected, epochs)

    def test_compound_cosanneal_and_multistep_lr(self):
        # Cosine annealing combined with MultiStepLR (milestones [2, 5, 9]): the
        # cosine curve is scaled by a piecewise-constant multiplier.
        epochs = 10
        eta_min = 1e-10
        cosine_curve = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        step_scale = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
        first_group = [x * y for x, y in zip(cosine_curve, step_scale)]
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
        ]
        self._test(pair, expected, epochs)

    def test_compound_cosanneal_and_linearlr(self):
        # LinearLR warm-up (start_factor=0.4 over 4 iters) combined with cosine
        # annealing: the first four cosine values are scaled by the linear factor.
        epochs = 10
        eta_min = 1e-10
        iters = 4
        start_factor = 0.4
        first_group = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        for step in range(iters):
            first_group[step] *= start_factor + step / iters * (1 - start_factor)
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            LinearLR(self.opt, start_factor=start_factor, total_iters=iters),
            CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
        ]
        self._test(pair, expected, epochs)

    def test_compound_cosanneal_and_exp_lr(self):
        # Cosine annealing combined with ExponentialLR (gamma=0.1): the cosine
        # curve is scaled by 0.1 ** epoch.
        epochs = 10
        eta_min = 1e-10
        cosine_curve = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        exp_scale = [0.1 ** i for i in range(epochs)]
        first_group = [x * y for x, y in zip(cosine_curve, exp_scale)]
        expected = [first_group, [x * epochs for x in first_group]]
        pair = [
            CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            ExponentialLR(self.opt, gamma=0.1),
        ]
        self._test(pair, expected, epochs)

    def test_compound_reduce_lr_on_plateau1(self):
        # ReduceLROnPlateau ('min', 'abs' threshold 0.01) combined with StepLR
        # (step_size=3), with every group starting at lr=0.5.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 10
        step_scale = [0.1 ** (i // 3) for i in range(20)]
        flat = [0.5] * 20
        single_targets = [x * y for x, y in zip(step_scale, flat)]
        targets = [single_targets]
        # NOTE(review): `targets` has exactly one element here, so `targets[1:]`
        # is an empty list -- the slice drops the whole target list rather than
        # its first entry. If the intent was to skip the first expected value,
        # this presumably should be `[single_targets[1:]]`; confirm against the
        # helper before changing, since the values above are not validated here.
        targets = targets[1:]  # test runs step before checking lr
        losses = [10 - i * 0.0167 for i in range(20)]
        pair = [
            ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
                              threshold=0.01, patience=5, cooldown=5),
            StepLR(self.opt, gamma=0.1, step_size=3),
        ]
        self._test_reduce_lr_on_plateau(pair, targets, losses, epochs)

    def test_compound_reduce_lr_on_plateau2(self):
        # ReduceLROnPlateau ('min', 'abs' threshold 0.1) combined with MultiStepLR
        # (milestones [3, 8, 12]), with every group starting at lr=0.5.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 22
        plateau_curve = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
        step_scale = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
        single_targets = [x * y for x, y in zip(plateau_curve, step_scale)]
        targets = [single_targets]
        # NOTE(review): `targets` has exactly one element here, so `targets[1:]`
        # is an empty list -- the slice drops the whole target list rather than
        # its first entry. If the intent was to skip the first expected value,
        # this presumably should be `[single_targets[1:]]`; confirm against the
        # helper before changing, since the values above are not validated here.
        targets = targets[1:]  # test runs step before checking lr
        losses = [10 - i * 0.0165 for i in range(22)]
        pair = [
            ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                              mode='min', threshold=0.1),
            MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]),
        ]
        self._test_reduce_lr_on_plateau(pair, targets, losses, epochs)

    def test_compound_reduce_lr_on_plateau3(self):
        # ReduceLROnPlateau ('max', 'abs') combined with ExponentialLR (gamma=0.1),
        # with every group starting at lr=0.5.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 22
        plateau_curve = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
        exp_scale = [0.1 ** i for i in range(epochs)]
        single_targets = [x * y for x, y in zip(exp_scale, plateau_curve)]
        targets = [single_targets]
        # NOTE(review): `targets` has exactly one element here, so `targets[1:]`
        # is an empty list -- the slice drops the whole target list rather than
        # its first entry. If the intent was to skip the first expected value,
        # this presumably should be `[single_targets[1:]]`; confirm against the
        # helper before changing, since the values above are not validated here.
        targets = targets[1:]  # test runs step before checking lr
        scores = [-0.8] * 2 + [-0.234] * 20
        pair = [
            ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
                              threshold_mode='abs'),
            ExponentialLR(self.opt, gamma=0.1),
        ]
        self._test_reduce_lr_on_plateau(pair, targets, scores, epochs)

    def test_compound_reduce_lr_on_plateau4(self):
        # ReduceLROnPlateau ('max', 'rel' threshold 0.1, patience=3) combined with
        # CosineAnnealingLR, with every group starting at lr=0.05.
        for param_group in self.opt.param_groups:
            param_group['lr'] = 0.05
        # `epochs` doubles as the cosine T_max and as the number of epochs run;
        # the earlier `epochs = 20` assignment was dead and has been dropped.
        epochs = 10
        eta_min = 1e-10
        single_targets = [eta_min + (0.05 - eta_min) *
                          (1 + math.cos(math.pi * x / epochs)) / 2
                          for x in range(epochs)]
        targets = [single_targets]
        # NOTE(review): `targets` has exactly one element here, so `targets[1:]`
        # is an empty list -- the slice drops the whole target list rather than
        # its first entry. If the intent was to skip the first expected value,
        # this presumably should be `[single_targets[1:]]`; confirm against the
        # helper before changing, since the values above are not validated here.
        targets = targets[1:]  # test runs step before checking lr
        metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25
        schedulers = [None, None]
        schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=3,
                                          threshold_mode='rel', threshold=0.1)
        schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_compound_reduce_lr_on_plateau5(self):
        # ReduceLROnPlateau ('min', 'abs' threshold 0.1) combined with a LinearLR
        # warm-up (start_factor=0.4 over 4 iters), every group starting at lr=0.5.
        for group in self.opt.param_groups:
            group['lr'] = 0.5
        epochs = 22
        iters = 4
        start_factor = 0.4
        plateau_curve = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
        warmup_scale = [1] * 22
        for step in range(iters):
            warmup_scale[step] *= start_factor + step / iters * (1 - start_factor)
        single_targets = [x * y for x, y in zip(plateau_curve, warmup_scale)]
        targets = [single_targets]
        # NOTE(review): `targets` has exactly one element here, so `targets[1:]`
        # is an empty list -- the slice drops the whole target list rather than
        # its first entry. If the intent was to skip the first expected value,
        # this presumably should be `[single_targets[1:]]`; confirm against the
        # helper before changing, since the values above are not validated here.
        targets = targets[1:]  # test runs step before checking lr
        losses = [10 - i * 0.0165 for i in range(22)]
        pair = [
            ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                              mode='min', threshold=0.1),
            LinearLR(self.opt, start_factor=start_factor, total_iters=iters),
        ]
        self._test_reduce_lr_on_plateau(pair, targets, losses, epochs)

    def test_cycle_lr_invalid_mode(self):
        # Constructing CyclicLR with an unrecognised `mode` string is expected to
        # raise ValueError; the instance is never used, so it is not bound.
        with self.assertRaises(ValueError):
            CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")

    def test_cycle_lr_triangular_mode_one_lr(self):
        # Triangular cycle between 1 and 5 with step_size_up=4; the momentum
        # targets mirror the lr targets. Both target lists are used twice.
        lr_wave = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        momentum_wave = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
        cyclic = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
                          cycle_momentum=True, base_momentum=1, max_momentum=5,
                          mode='triangular')
        self._test_cycle_lr(cyclic,
                            [lr_wave, lr_wave],
                            [momentum_wave, momentum_wave],
                            len(lr_wave))

    def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
        # With cycle_momentum=False the momentum targets are the optimizer's
        # default momentum repeated for every step.
        lr_wave = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        steps = len(lr_wave)
        fixed_momentum = [self.opt.defaults['momentum']] * steps
        cyclic = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
                          cycle_momentum=False, mode='triangular')
        self._test_cycle_lr(cyclic,
                            [lr_wave, lr_wave],
                            [fixed_momentum, fixed_momentum],
                            steps)

    def test_cycle_lr_triangular2_mode_one_lr(self):
        # 'triangular2' mode: the target amplitude halves on each successive
        # cycle (peaks at 5, 3, 2); momentum targets mirror the lr targets.
        lr_wave = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5,
                   1, 1.25, 1.50, 1.75, 2.00, 1.75]
        momentum_wave = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0,
                         3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
        cyclic = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
                          cycle_momentum=True, base_momentum=1, max_momentum=5,
                          mode='triangular2')
        self._test_cycle_lr(cyclic,
                            [lr_wave, lr_wave],
                            [momentum_wave, momentum_wave],
                            len(lr_wave))

    def test_cycle_lr_exp_range_mode_one_lr(self):
        # 'exp_range' mode: the triangular position `x` is damped by gamma**i at
        # step i. Momentum uses the same bounds as lr, mirrored.
        base_lr, max_lr = 1, 5
        gamma = 0.9
        diff_lr = max_lr - base_lr
        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
        lr_wave = []
        momentum_wave = []
        for i, x in enumerate(xs):
            lr_wave.append(base_lr + x * diff_lr * gamma**i)
            momentum_wave.append(max_lr - x * diff_lr * gamma**i)
        cyclic = CyclicLR(self.opt, base_lr=base_lr,
                          max_lr=max_lr, step_size_up=4,
                          cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr,
                          mode='exp_range', gamma=gamma)
        self._test_cycle_lr(cyclic,
                            [lr_wave, lr_wave],
                            [momentum_wave, momentum_wave],
                            len(lr_wave))

    def test_cycle_lr_triangular_mode(self):
        # Triangular cycle with per-group bounds: the second group's lr and
        # momentum targets are the first group's shifted up by 1.
        lr_first = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        lr_second = [x + 1 for x in lr_first]
        momentum_first = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
        momentum_second = [x + 1 for x in momentum_first]
        cyclic = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4,
                          cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6],
                          mode='triangular')
        self._test_cycle_lr(cyclic,
                            [lr_first, lr_second],
                            [momentum_first, momentum_second],
                            len(lr_first))

    def test_cycle_lr_triangular2_mode(self):
        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1,
                       1.25, 1.50, 1.75, 2.00, 1.75]
        lr_target_2 = [x + 2 for x in lr_target_1]
        lr_targets = [lr_target_1, lr_target_2]
        momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5,
                             3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
        momentum_target_2 = [x + 2 for x in momentum_target_1]
        momentum_targets = [momentum_target_1, momentum_target_2]
        scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4,
                             cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7],
                             mode='triangular2')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))

    def test_cycle_lr_exp_range_mode(self):
        base_lr_1, max_lr_1 = 1, 5
        base_lr_2, max_lr_2 = 5, 12

        diff_lr_1 = max_lr_1 - base_lr_1
        diff_lr_2 = max_lr_2 - base_lr_2

        gamma = 0.9
        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
        lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
        lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
        lr_targets = [lr_target_1, lr_target_2]
        momentum_target_1 = [max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
        momentum_target_2 = [max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
        momentum_targets = [momentum_target_1, momentum_target_2]
        scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2],
                             max_lr=[max_lr_1, max_lr_2], step_size_up=4,
                             cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2],
                             max_momentum=[max_lr_1, max_lr_2],
                             mode='exp_range', gamma=gamma)
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))

    def test_cycle_lr_triangular_mode_step_size_up_down(self):
        lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0]
        lr_targets = [lr_target, lr_target]
        momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0]
        momentum_targets = [momentum_target, momentum_target]

        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5,
                             step_size_up=4,
                             step_size_down=6,
                             cycle_momentum=True,
                             base_momentum=1, max_momentum=5,
                             mode='triangular')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_triangular2_mode_step_size_up_down(self):
        lr_base_target = ([
            1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3,
            7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6,
            8.0 / 6, 7.0 / 6
        ])
        momentum_base_target = ([
            5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3,
            11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3,
            29.0 / 6
        ])
        deltas = [2 * i for i in range(0, 2)]
        base_lrs = [1 + delta for delta in deltas]
        max_lrs = [5 + delta for delta in deltas]
        lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
        momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas]
        scheduler = CyclicLR(
            self.opt,
            base_lr=base_lrs,
            max_lr=max_lrs,
            step_size_up=2,
            step_size_down=6,
            cycle_momentum=True,
            base_momentum=base_lrs,
            max_momentum=max_lrs,
            mode='triangular2')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target))

    def test_cycle_lr_exp_range_mode_step_size_up_down(self):
        base_lr, max_lr = 1, 5
        diff_lr = max_lr - base_lr
        gamma = 0.9
        xs = ([
            0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6,
            4.0 / 6
        ])
        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
        lr_targets = [lr_target, lr_target]
        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr,
                             step_size_up=2, step_size_down=6,
                             cycle_momentum=True, base_momentum=base_lr,
                             max_momentum=max_lr,
                             mode='exp_range', gamma=gamma)
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_with_momentumless_optimizer(self):
        # Note [Temporarily set optimizer to Adam]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # The TestLRScheduler object carries around an SGD optimizer to avoid having to
        # instantiate one for every test. This gets in the way for our very specific case
        # in which we need to use Adam (or really any optimizer that doesn't use momentum)
        # in order to test that the momentum bug in CyclicLR is fixed (the bug is described
        # in more detail in https://github.com/pytorch/pytorch/issues/19003 ).
        old_opt = self.opt
        self.opt = optim.Adam(
            [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
            lr=0.05)

        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        lr_targets = [lr_target, lr_target]
        momentum_target = [None] * len(lr_target)
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
                             cycle_momentum=False, mode='triangular')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

        self.opt = old_opt  # set optimizer back to SGD

    def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
        with self.assertRaises(ValueError):
            adam_opt = optim.Adam(self.net.parameters())
            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)

    def test_cycle_lr_removed_after_out_of_scope(self):
        import gc
        import weakref
        gc.disable()

        def test():
            adam_opt = optim.Adam(self.net.parameters())
            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
            return weakref.ref(scheduler)

        ref = test()
        assert ref() is None
        gc.enable()

    def test_onecycle_lr_invalid_anneal_strategy(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS")

    def test_onecycle_lr_invalid_pct_start(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)

    def test_onecycle_lr_cannot_calculate_total_steps(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(self.opt, max_lr=1e-3)

    def test_onecycle_lr_linear_annealing(self):
        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
                               total_steps=10, anneal_strategy='linear')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_onecycle_lr_linear_annealing_three_phases(self):
        lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
        momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(self.opt, max_lr=25, div_factor=25,
                               base_momentum=1, max_momentum=22,
                               total_steps=10, anneal_strategy='linear',
                               pct_start=0.4, final_div_factor=4,
                               three_phase=True)
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_onecycle_lr_cosine_annealing(self):
        def annealing_cos(start, end, pct):
            cos_out = math.cos(math.pi * pct) + 1
            return end + (start - end) / 2.0 * cos_out
        lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0),
                     annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0),
                     annealing_cos(25, 0.5, 6 / 7.0), 0.5]
        momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0),
                           annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0),
                           annealing_cos(1, 22, 6 / 7.0), 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
                               total_steps=10)
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_cycle_lr_with_adam(self):
        old_opt = self.opt
        self.opt = optim.Adam(
            [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
            lr=0.05)

        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
                               total_steps=10, anneal_strategy='linear')
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
        self.opt = old_opt  # set optimizer back to SGD

    def test_lambda_lr(self):
        epochs = 10
        self.opt.param_groups[0]['lr'] = 0.05
        self.opt.param_groups[1]['lr'] = 0.4
        targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]]
        scheduler = LambdaLR(self.opt,
                             lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
        self._test(scheduler, targets, epochs)

    def test_multiplicative_lr(self):
        epochs = 10
        self.opt.param_groups[0]['lr'] = 0.05
        self.opt.param_groups[1]['lr'] = 0.4
        targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]]
        scheduler = MultiplicativeLR(self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8])
        self._test(scheduler, targets, epochs)

    @parametrize("T_mult", [1, 2, 4])
    def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
        iters = 100
        eta_min = 1e-10
        T_i = 10
        T_cur = 0
        targets = [[0.05], [0.5]]
        scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min)
        for _ in range(1, iters, 1):
            T_cur += 1
            if T_cur >= T_i:
                T_cur = T_cur - T_i
                T_i = int(T_mult) * T_i
            targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
            targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
        self._test(scheduler, targets, iters)

    def test_CosineAnnealingWarmRestarts_lr2(self):
        iters = 30
        eta_min = 1e-10
        T_mults = [1, 2, 4]
        for T_mult in T_mults:
            T_i = 10
            T_cur = 0
            targets = [[0.05], [0.5]]
            scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min)
            for _ in torch.arange(0.1, iters, 0.1):
                T_cur = round(T_cur + 0.1, 1)
                if T_cur >= T_i:
                    T_cur = T_cur - T_i
                    T_i = int(T_mult) * T_i
                targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
                targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
            self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters)

    def test_CosineAnnealingWarmRestarts_lr3(self):
        epochs_for_T_mults = [[0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13],
                              [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3],
                              [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50]]
        T_curs_for_T_mults = [[1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3],
                              [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3],
                              [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10]]
        T_is_for_T_mults = [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
                            [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10],
                            [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90]]
        eta_min = 1e-10
        T_mults = [1, 2, 3]
        for epochs, T_mult, T_curs, T_is in zip(epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults):
            targets = [[0.05], [0.5]]
            scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min)
            for T_cur, T_i in zip(T_curs, T_is):
                targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
                targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
            self._test_interleaved_CosineAnnealingWarmRestarts(scheduler, targets, epochs)

    def test_swalr_no_anneal(self):
        epochs, swa_start, swa_lr = 10, 5, 0.01
        initial_lrs = [group['lr'] for group in self.opt.param_groups]
        targets = [[lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
                   for lr in initial_lrs]
        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)

    def test_swalr_cosine_anneal_after_multiplicative(self):
        # same swa_lr for different param_groups
        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)

        def anneal_coef(t):
            if t + 1 >= anneal_epochs:
                return 0.
            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2

        initial_lrs = [group['lr'] for group in self.opt.param_groups]
        targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)]
                              for lr in initial_lrs]
        swa_epochs = epochs - swa_start - 1
        targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)]
                   for lrs in targets_before_swa]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def test_swalr_linear_anneal_after_multiplicative(self):
        # separate swa_lr for different param_groups
        epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs,
                              anneal_strategy="linear", swa_lr=swa_lrs)

        def anneal_coef(t):
            if t + 1 >= anneal_epochs:
                return 0.
            return 1 - (t + 1) / anneal_epochs

        initial_lrs = [group['lr'] for group in self.opt.param_groups]
        targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)]
                              for lr in initial_lrs]
        swa_epochs = epochs - swa_start - 1
        targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)]
                   for lrs, swa_lr in zip(targets_before_swa, swa_lrs)]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[epoch], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0)
            if epoch >= swa_start:
                self.opt.step()
                swa_scheduler.step()
            elif scheduler is not None:
                self.opt.step()
                scheduler.step()

    def test_swalr_hypers(self):
        # Test that SWALR raises errors for incorrect hyper-parameters
        with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
            swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.)

        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.)
        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.)
        with self.assertRaisesRegex(ValueError, "swa_lr must"):
            swa_scheduler = SWALR(self.opt, swa_lr=[1., 0.1, 0.01])

    def test_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1))

    def test_multi_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]))

    def test_exp_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: ExponentialLR(self.opt, gamma=0.1),
            lambda: ExponentialLR(self.opt, gamma=0.01))

    def test_cosine_lr_state_dict(self):
        epochs = 10
        eta_min = 1e-10
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
            epochs=epochs)

    def test_reduce_lr_on_plateau_state_dict(self):
        scheduler = ReduceLROnPlateau(self.opt, mode='min', factor=0.1, patience=2)
        for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
            scheduler.step(score)
        scheduler_copy = ReduceLROnPlateau(self.opt, mode='max', factor=0.5, patience=10)
        scheduler_copy.load_state_dict(scheduler.state_dict())
        for key in scheduler.__dict__.keys():
            if key not in {'optimizer', 'is_better'}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_fn(self):
        scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x)
        state = scheduler.state_dict()
        self.assertIsNone(state['lr_lambdas'][0])

        scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x)
        scheduler_copy.load_state_dict(state)
        for key in scheduler.__dict__.keys():
            if key not in {'optimizer', 'lr_lambdas'}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_obj(self):
        # A LambdaLR built from a callable object must keep a non-None entry
        # for it in its state dict; loading the state into a scheduler built
        # with a different object must reproduce every attribute except
        # 'optimizer' (so 'lr_lambdas' is compared too).
        original = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(10))
        state = original.state_dict()
        self.assertIsNotNone(state['lr_lambdas'][0])

        restored = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(-1))
        restored.load_state_dict(state)
        not_compared = {'optimizer'}
        restored_attrs = restored.__dict__
        for name, value in original.__dict__.items():
            if name in not_compared:
                continue
            self.assertEqual(value, restored_attrs[name])

    def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100))

    def test_swa_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
            lambda: SWALR(self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.))

    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
        scheduler = constr()
        for _ in range(epochs):
            scheduler.optimizer.step()
            scheduler.step()
        scheduler_copy = constr2()
        scheduler_copy.load_state_dict(scheduler.state_dict())
        for key in scheduler.__dict__.keys():
            if key != 'optimizer':
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())

    def _test_get_last_lr(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, _LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            result = [scheduler.get_last_lr() for scheduler in schedulers]
            [optimizer.step() for optimizer in optimizers]
            [scheduler.step() for scheduler in schedulers]
            target = [[t[epoch] for t in targets]] * len(schedulers)
            for t, r in zip(target, result):
                self.assertEqual(target, result,
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, t, r), atol=1e-5, rtol=0)

    def _test_with_epoch(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, _LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            [optimizer.step() for optimizer in optimizers]
            with warnings.catch_warnings(record=True) as w:
                [scheduler.step(epoch) for scheduler in schedulers]  # step before assert: skip initial lr
                self._check_warning_is_epoch_deprecation_warning(w, num_warnings=len(schedulers))
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[epoch], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0)

    def _test(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, _LRScheduler):
            schedulers = [schedulers]
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[epoch], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0)
            [scheduler.step() for scheduler in schedulers]

    def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10):
        for index, epoch in enumerate(torch.arange(0, epochs, 0.1)):
            epoch = round(epoch.item(), 1)
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[index], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[index], param_group['lr']), atol=1e-5, rtol=0)

    def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs):
        for index, epoch in enumerate(epochs):
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[index], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[index], param_group['lr']), atol=1e-5, rtol=0)

    def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10):
        self.setUp()
        targets = []
        for epoch in range(epochs):
            closed_form_scheduler.optimizer.step()
            with warnings.catch_warnings(record=True) as w:
                closed_form_scheduler.step(epoch)
                self._check_warning_is_epoch_deprecation_warning(w)
            targets.append([group['lr'] for group in self.opt.param_groups])
        self.setUp()
        for epoch in range(epochs):
            self.opt.step()
            scheduler.step()
            for i, param_group in enumerate(self.opt.param_groups):
                self.assertEqual(targets[epoch][i], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, targets[epoch][i], param_group['lr']), atol=1e-5, rtol=0)

    def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, verbose=False):
        if isinstance(schedulers, _LRScheduler) or isinstance(schedulers, ReduceLROnPlateau):
            schedulers = [schedulers]
        for epoch in range(epochs):
            self.opt.step()
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(metrics[epoch])
                else:
                    scheduler.step()
            if verbose:
                print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr']))
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(target[epoch], param_group['lr'],
                                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                     epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0)

    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False):
        for batch_num in range(batch_iterations):
            if verbose:
                if 'momentum' in self.opt.param_groups[0].keys():
                    print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
                                                               self.opt.param_groups[0]['momentum']))
                elif use_beta1 and 'betas' in self.opt.param_groups[0].keys():
                    print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'],
                                                            self.opt.param_groups[0]['betas'][0]))
                else:
                    print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr']))

            for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets):
                self.assertEqual(
                    lr_target[batch_num], param_group['lr'],
                    msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
                        batch_num, lr_target[batch_num], param_group['lr']), atol=1e-5, rtol=0)

                if use_beta1 and 'betas' in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num], param_group['betas'][0],
                        msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format(
                            batch_num, momentum_target[batch_num], param_group['betas'][0]), atol=1e-5, rtol=0)
                elif 'momentum' in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num], param_group['momentum'],
                        msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
                            batch_num, momentum_target[batch_num], param_group['momentum']), atol=1e-5, rtol=0)
            self.opt.step()
            scheduler.step()

    def test_cosine_then_cyclic(self):
        # https://github.com/pytorch/pytorch/issues/21965

        max_lr = 0.3
        base_lr = 0.1
        optim_lr = 0.5

        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr)
        lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1)
        lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
            optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
        )

        for i in range(40):
            optimizer.step()
            if i <= lr_scheduler_1.T_max:
                lr_scheduler_1.step()
            else:
                lr_scheduler_2.step()
            last_lr = optimizer.param_groups[0]["lr"]

        self.assertLessEqual(last_lr, max_lr)


class SWATestDNN(torch.nn.Module):
    """Small fully connected fixture: one Linear layer followed by BatchNorm1d.

    ``input_features`` is the width of the input; the hidden width is fixed
    at 100 and exposed as ``n_features``.
    """

    def __init__(self, input_features):
        super().__init__()
        self.n_features = 100
        self.fc1 = torch.nn.Linear(input_features, self.n_features)
        self.bn = torch.nn.BatchNorm1d(self.n_features)

    def compute_preactivation(self, x):
        """Return the Linear layer's output, i.e. the input to batch norm."""
        return self.fc1(x)

    def forward(self, x):
        """Apply the Linear layer and then batch norm."""
        return self.bn(self.fc1(x))


class SWATestCNN(torch.nn.Module):
    """Small convolutional fixture: one 3x3 Conv2d followed by BatchNorm2d.

    ``input_channels`` is the number of input channels; the number of output
    channels is fixed at 10 and exposed as ``n_features``. The batch norm
    layer is created with momentum=0.3.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.n_features = 10
        self.conv1 = torch.nn.Conv2d(input_channels, self.n_features, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

    def compute_preactivation(self, x):
        """Return the convolution's output, i.e. the input to batch norm."""
        return self.conv1(x)

    def forward(self, x):
        """Apply the convolution and then batch norm."""
        return self.bn(self.conv1(x))


class TestSWAUtils(TestCase):

    def _test_averaged_model(self, net_device, swa_device):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Conv2d(5, 2, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 10)
        ).to(net_device)

        averaged_dnn = AveragedModel(dnn, device=swa_device)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def test_averaged_model_all_devices(self):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu)
            self._test_averaged_model(cpu, cuda)
            self._test_averaged_model(cuda, cuda)

    def test_averaged_model_mixed_device(self):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()
        averaged_dnn = AveragedModel(dnn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for i in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_exponential(self):
        # Test AveragedModel with EMA as avg_fn
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.Linear(5, 10)
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p
        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append((p_avg * alpha +
                                                   p * (1 - alpha)).clone())
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)

    def test_averaged_model_exponential_buffers(self):
        # Test AveragedModel with EMA as avg_fn and use_buffers as True.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10)
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p
        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
        dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
        averaged_params = [torch.zeros_like(param) for param in dnn_params
                           if param.size() != torch.Size([])]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append((p_avg * alpha +
                                                   p * (1 - alpha)).clone())
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(
                averaged_params, itertools.chain(averaged_dnn.module.parameters(), averaged_dnn.module.buffers())):
            self.assertEqual(p_avg, p_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):

        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        def _reset_bn(module):
            if issubclass(module.__class__,
                          torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)
        # reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
        # using the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

    def test_update_bn_dnn(self):
        """Test update_bn for a fully-connected network with BatchNorm1d.

        Runs on CPU, and additionally on CUDA when it is available. After each
        run the model must still be in training mode.
        """
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)

        cuda_options = [False, True] if torch.cuda.is_available() else [False]
        for use_cuda in cuda_options:
            dnn = SWATestDNN(input_features=input_features)
            dnn.train()
            if use_cuda:
                dnn = dnn.cuda()
            self._test_update_bn(dnn, dl_x, dl_xy, use_cuda)
            # Assert the training mode for every model that was exercised, not
            # only for whichever model happened to be created last.
            self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        """Test update_bn for a convolutional network with BatchNorm2d.

        Runs on CPU, and additionally on CUDA when it is available. After each
        run the model must still be in training mode.
        """
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)

        cuda_options = [False, True] if torch.cuda.is_available() else [False]
        for use_cuda in cuda_options:
            dnn = SWATestCNN(input_channels=input_channels)
            dnn.train()
            if use_cuda:
                dnn = dnn.cuda()
            self._test_update_bn(dnn, dl_x, dl_xy, use_cuda)
            # Assert the training mode for every model that was exercised, not
            # only for whichever model happened to be created last.
            self.assertTrue(dnn.training)

    def test_bn_update_eval_momentum(self):
        """Check that update_bn preserves eval mode and the batch norm's momentum."""
        n_samples, n_channels = 100, 3
        side = 5
        images = torch.rand(n_samples, n_channels, side, side)
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(images), batch_size=5, shuffle=True)

        model = SWATestCNN(input_channels=n_channels)
        model.eval()
        update_bn(loader, model)

        # check that eval mode is preserved
        self.assertFalse(model.training)
        # check that momentum is preserved
        self.assertEqual(model.bn.momentum, 0.3)


# Register the parametrized variants of TestLRScheduler's tests. This call has
# to come after the class definition (see `parametrize` and
# `instantiate_parametrized_tests` in torch.testing._internal.common_utils).
instantiate_parametrized_tests(TestLRScheduler)


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    """Run one optimizer step on copies of the inputs and return the resulting tensors.

    ``ignored`` receives the values of ``opt_differentiable_state`` a second
    time. They are not read here; passing them lets ``gradcheck`` track the
    state tensors as function inputs, since it cannot unpack the values held
    inside the ``opt_differentiable_state`` dict.

    Returns a tuple holding the stepped parameter followed by every cloned
    state tensor that requires grad.
    """
    param = p.clone()
    param.grad = grad

    # Clone tensor entries so that the step never touches the caller's tensors.
    cloned_state = {}
    for key, value in opt_differentiable_state.items():
        cloned_state[key] = value.clone() if isinstance(value, torch.Tensor) else value

    optimizer = opt_class([param], **kwargs)
    optimizer.state[param].update(cloned_state)
    optimizer.step()

    tracked = [
        value for value in cloned_state.values()
        if isinstance(value, torch.Tensor) and value.requires_grad
    ]
    return (param, *tracked)


class TestDifferentiableOptimizer(TestCase):
    """Run ``gradcheck`` on ``_diff_fn`` for optimizers built with ``differentiable=True``."""

    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {
            'momentum_buffer': torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        optimizer_kwargs = {'lr': 0.9, 'differentiable': True}
        gradcheck(_diff_fn, (p, grad, state, torch.optim.SGD, optimizer_kwargs, *state.values()))

    def test_adam(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {
            # `step` is not a continuous variable (even though we define it as a float)
            # and so it shouldn't require gradients.
            'step': torch.tensor(10., requires_grad=False, dtype=torch.float64),
            'exp_avg': torch.rand(10, requires_grad=True, dtype=torch.float64),
            'exp_avg_sq': torch.rand(10, requires_grad=True, dtype=torch.float64),
            'max_exp_avg_sq': torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        optimizer_kwargs = {'lr': 0.9, 'differentiable': True, 'amsgrad': True}
        gradcheck(_diff_fn, (p, grad, state, torch.optim.Adam, optimizer_kwargs, *state.values()))

    def test_rmsprop(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {
            'step': 0,
            'square_avg': torch.rand(10, requires_grad=True, dtype=torch.float64),
            'momentum_buffer': torch.rand(10, requires_grad=True, dtype=torch.float64),
            # This can cause issues with large values and nan due to sqrt ops
            'grad_avg': 1e-2 * torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        optimizer_kwargs = {
            'lr': 0.9,
            'maximize': True,
            'momentum': 0.9,
            'differentiable': True,
            'centered': True,
            'weight_decay': 0.1,
        }
        gradcheck(_diff_fn, (p, grad, state, torch.optim.RMSprop, optimizer_kwargs, *state.values()))


# Hand control to run_tests() when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()
