# Owner(s): ["oncall: distributed"]
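"""
Tests that the functional optimizers looked up via ``functional_optim_map`` apply the
same per-parameter updates as their ``torch.optim`` counterparts, and that custom
functional optimizers can be registered through ``register_functional_optim``.
"""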
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import SGD, Adam, AdamW
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.distributed.optim.utils import functional_optim_map, register_functional_optim


class MyModule(torch.nn.Module):
    # Small deterministic two-layer network; the fixed seed ensures the regular
    # and functional optimizer copies start from identical weights.
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        return self.lin2(F.relu(self.lin1(t1)))


# Dummy class to showcase custom optimizer registration with the functional wrapper.
class MyDummyFnOptimizer:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        # A real implementation would update `param` in place here; this dummy
        # optimizer exists only to exercise registration.
        with torch.no_grad():
            raise RuntimeError(
                "MyDummyFnOptimizer does not support step_param() as of now"
            )

    def step(self, gradients: List[Optional[Tensor]]):
        # A real implementation would update all parameters here; this dummy
        # optimizer exists only to exercise registration.
        with torch.no_grad():
            raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")


class TestFunctionalOptimParity(TestCase):
    def _validate_parameters(self, params_1, params_2):
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)
    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        module_optim = MyModule()
        module_functional = MyModule()
        optim_params = module_optim.parameters()
        functional_params = module_functional.parameters()
        optim = optim_cls(optim_params, *args, **kwargs)
        functional_optim_cls = functional_optim_map.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        optim_functional = functional_optim_cls(
            [], *args, **kwargs, _allow_empty_param_list=True
        )
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Save old parameters to verify optimizer modifies them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + Backward
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # Optimizer step
            optim.step()
            # Functional optimizer step_param
            for param in module_functional.parameters():
                grad = param.grad
                optim_functional.step_param(param, grad)
            # Validate parameters are equal
            for optim_param, functional_param in zip(
                module_optim.parameters(), module_functional.parameters()
            ):
                self.assertEqual(optim_param, functional_param)
            # Validate parameters are modified.
            for i, (optim_param, functional_param) in enumerate(
                zip(module_optim.parameters(), module_functional.parameters())
            ):
                self.assertNotEqual(old_module_optim_params[i], optim_param)
                self.assertNotEqual(old_module_functional_params[i], functional_param)
    def _test_functional_optim_registration(self):
        fn_map_key = "MyDummyFnOptimizer"
        fn_optim = MyDummyFnOptimizer
        register_functional_optim(fn_map_key, fn_optim)
        functional_optim_cls = functional_optim_map.get(fn_map_key, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not registered for {fn_map_key}")

    def test_functional_optim_registration(self):
        self._test_functional_optim_registration()
    def test_functional_optim_parity_sgd(self):
        self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)

    def test_functional_optim_parity_adam(self):
        self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6)

    def test_functional_optim_parity_adam_w(self):
        self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6)


if __name__ == "__main__":
    run_tests()