# mypy: ignore-errors
# Owner(s): ["oncall: distributed"]
from typing import Tuple
import torch
import torch.nn as nn


class UnitModule(nn.Module):
    """Linear -> (ReLU, Linear, ReLU) -> Linear block used as a building unit."""

    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.seq(self.l1(x)))


class CompositeModel(nn.Module):
    """Composes two ``UnitModule``s between an input and an output linear layer."""

    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.u2(self.u1(self.l1(x))))


class UnitParamModule(nn.Module):
    """Like ``UnitModule`` but ends with a matrix multiply against a direct parameter."""

    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        return torch.mm(self.seq(self.l(x)), self.p)


class CompositeParamModel(nn.Module):
    """Composes two ``UnitModule``s and additionally owns a direct parameter and a
    persistent buffer."""

    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.p = nn.Parameter(torch.randn((100, 100), device=device))
        self.register_buffer(
            "buffer", torch.randn((100, 100), device=device), persistent=True
        )

    def forward(self, x):
        a = self.u2(self.u1(self.l(x)))
        b = self.p
        return torch.mm(a, b)


class FakeSequential(nn.Module):
    # Define this class to achieve a desired nested wrapping using the module
    # wrap policy with `nn.Sequential`
    def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
        super().__init__()
        self._module_sequence = list(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self._module_sequence:
            x = module(x)
        return x
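

# A minimal sketch of how `FakeSequential` interacts with a module wrap policy
# (assuming `ModuleWrapPolicy` from `torch.distributed.fsdp.wrap`; the snippet
# is illustrative and not part of these helpers):
#
#     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
#     policy = ModuleWrapPolicy({nn.Sequential})
#
# Such a policy wraps only true `nn.Sequential` instances, so swapping some of
# them for `FakeSequential` (which runs the same layers in `forward` but is not
# an `nn.Sequential`) controls which submodules receive their own wrapper.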


class NestedSequentialModel(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # This nested structure exercises traversal order to catch differences
        # between valid traversals (e.g. BFS and DFS variations).
        self.seq1 = nn.Sequential(
            nn.Linear(1, 1, device=device),
            FakeSequential(
                nn.Linear(1, 1, device=device),
                nn.ReLU(),
                FakeSequential(
                    nn.Linear(1, 1, device=device),
                ),
                nn.ReLU(),
            ),
            nn.Linear(1, 2, device=device),
        )
        self.lin = nn.Linear(2, 2, device=device)
        self.seq2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2, 3, device=device),
            FakeSequential(
                nn.Linear(3, 2, bias=False, device=device),
                nn.Linear(2, 4, bias=False, device=device),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.lin(self.seq1(x)))
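

if __name__ == "__main__":
    # Minimal CPU smoke test, not part of the original helpers: it only checks
    # that the models above run forward end to end with the expected shapes and
    # assumes no distributed initialization is needed for a plain forward pass.
    device = torch.device("cpu")

    composite = CompositeParamModel(device)
    out = composite(torch.randn(8, 100))
    assert out.shape == (8, 100)

    nested = NestedSequentialModel(device)
    out = nested(torch.randn(8, 1))
    assert out.shape == (8, 4)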