1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
from typing import cast
import torchvision_models as models
from utils import check_for_functorch, extract_weights, GetterReturnType, load_weights
import torch
from torch import Tensor
# Result of ``check_for_functorch()``; every getter below consults this flag
# before importing functorch and calling ``replace_all_batch_norm_modules_``
# on its model.
has_functorch = check_for_functorch()
def get_resnet18(device: torch.device) -> GetterReturnType:
    """Build the ResNet-18 benchmark closure.

    A ``models.resnet18`` is constructed with ``pretrained=False`` and moved
    to ``device`` together with one fixed random batch (32 images of shape
    3x224x224 and 32 integer targets).

    Returns:
        ``(forward, params)`` where ``params`` is the first value handed back
        by ``extract_weights(model)`` and ``forward(*new_params)`` passes the
        given tensors to ``load_weights`` before returning the cross-entropy
        loss of the model on the fixed batch.
    """
    batch_size = 32
    model = models.resnet18(pretrained=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    params, names = extract_weights(model)

    images = torch.rand(batch_size, 3, 224, 224, device=device)
    # Uniform samples in [0, 1) scaled by 10 and truncated: targets in 0..9.
    targets = (torch.rand(batch_size, device=device) * 10).long()

    def forward(*new_params: Tensor) -> Tensor:
        # Hand the candidate parameters to the model before evaluating it.
        load_weights(model, names, new_params)
        return loss_fn(model(images), targets)

    return forward, params
def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    """Build the FCN-ResNet50 benchmark closure.

    A ``models.fcn_resnet50`` is constructed with ``pretrained=False`` and
    ``pretrained_backbone=False``, switched to eval mode and moved to
    ``device`` together with one fixed random batch (8 images of shape
    3x480x480 and a matching 21-channel random target).

    Returns:
        ``(forward, params)`` where ``params`` is the first value handed back
        by ``extract_weights(model)`` and ``forward(*new_params)`` passes the
        given tensors to ``load_weights`` before returning the MSE loss
        between the model's ``"out"`` entry and the fixed target.
    """
    batch_size = 8
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    # Eval mode turns dropout off so repeated evaluations stay consistent.
    model.eval()
    model.to(device)
    loss_fn = torch.nn.MSELoss()
    params, names = extract_weights(model)

    images = torch.rand(batch_size, 3, 480, 480, device=device)
    # The model predicts 21 classes, hence 21 target channels.
    targets = torch.rand(batch_size, 21, 480, 480, device=device)

    def forward(*new_params: Tensor) -> Tensor:
        # Hand the candidate parameters to the model before evaluating it.
        load_weights(model, names, new_params)
        prediction = model(images)["out"]
        return loss_fn(prediction, targets)

    return forward, params
def get_detr(device: torch.device) -> GetterReturnType:
    """Build the DETR benchmark closure.

    A ``models.DETR`` and a ``models.SetCriterion`` (with a
    ``models.HungarianMatcher``) are created and moved to ``device``, along
    with a fixed random batch of 2 images (3x800x1200) and, per image, a
    random set of 5 to 9 targets made of class ids and corner-ordered boxes.

    Returns:
        ``(forward, params)`` where ``params`` is the first value handed back
        by ``extract_weights(model)`` and ``forward(*new_params)`` passes the
        given tensors to ``load_weights`` before returning the sum of the
        criterion's losses weighted by ``criterion.weight_dict``.
    """
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(
        num_classes=num_classes,
        hidden_dim=hidden_dim,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
    )

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    losses = ["labels", "boxes", "cardinality"]
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {
        "loss_ce": 1,
        "loss_bbox": bbox_loss_coef,
        "loss_giou": giou_loss_coef,
    }
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(
        num_classes=num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=eos_coef,
        losses=losses,
    )

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for _ in range(N):
        n_targets: int = int(torch.randint(5, 10, size=()).item())
        label = torch.randint(5, 10, size=(n_targets,), device=device)
        boxes = torch.randint(100, 800, size=(n_targets, 4), device=device)
        # Order each box so columns (0, 1) hold the smaller corner and
        # columns (2, 3) the larger one. This is done with element-wise
        # min/max rather than `b[t, 0], b[t, 2] = b[t, 2], b[t, 0]`:
        # indexed elements are views, so that tuple swap would read the
        # already-overwritten value and leave both entries equal.
        first_corner, second_corner = boxes[:, :2], boxes[:, 2:]
        boxes = torch.cat(
            [
                torch.minimum(first_corner, second_corner),
                torch.maximum(first_corner, second_corner),
            ],
            dim=1,
        )
        labels.append({"labels": label, "boxes": boxes.float()})

    def forward(*new_params: Tensor) -> Tensor:
        # Hand the candidate parameters to the model before evaluating it.
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        weight_dict = criterion.weight_dict
        # Only losses that have an entry in weight_dict contribute.
        final_loss = cast(
            Tensor,
            sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict),
        )
        return final_loss

    return forward, params
|