1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
|
import torch
from torch import Tensor
import torchvision_models as models
from utils import check_for_functorch, extract_weights, load_weights, GetterReturnType
from typing import cast
has_functorch = check_for_functorch()
def get_resnet18(device: torch.device) -> GetterReturnType:
N = 32
model = models.resnet18(pretrained=False)
if has_functorch:
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(model)
criterion = torch.nn.CrossEntropyLoss()
model.to(device)
params, names = extract_weights(model)
inputs = torch.rand([N, 3, 224, 224], device=device)
labels = torch.rand(N, device=device).mul(10).long()
def forward(*new_params: Tensor) -> Tensor:
load_weights(model, names, new_params)
out = model(inputs)
loss = criterion(out, labels)
return loss
return forward, params
def get_fcn_resnet(device: torch.device) -> GetterReturnType:
N = 8
criterion = torch.nn.MSELoss()
model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)
if has_functorch:
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(model)
# disable dropout for consistency checking
model.eval()
model.to(device)
params, names = extract_weights(model)
inputs = torch.rand([N, 3, 480, 480], device=device)
# Given model has 21 classes
labels = torch.rand([N, 21, 480, 480], device=device)
def forward(*new_params: Tensor) -> Tensor:
load_weights(model, names, new_params)
out = model(inputs)['out']
loss = criterion(out, labels)
return loss
return forward, params
def get_detr(device: torch.device) -> GetterReturnType:
# All values below are from CLI defaults in https://github.com/facebookresearch/detr
N = 2
num_classes = 91
hidden_dim = 256
nheads = 8
num_encoder_layers = 6
num_decoder_layers = 6
model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, nheads=nheads,
num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
if has_functorch:
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(model)
losses = ['labels', 'boxes', 'cardinality']
eos_coef = 0.1
bbox_loss_coef = 5
giou_loss_coef = 2
weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef, 'loss_giou': giou_loss_coef}
matcher = models.HungarianMatcher(1, 5, 2)
criterion = models.SetCriterion(num_classes=num_classes, matcher=matcher, weight_dict=weight_dict,
eos_coef=eos_coef, losses=losses)
model = model.to(device)
criterion = criterion.to(device)
params, names = extract_weights(model)
inputs = torch.rand(N, 3, 800, 1200, device=device)
labels = []
for idx in range(N):
targets = {}
n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
label = torch.randint(5, 10, size=(n_targets,), device=device)
targets["labels"] = label
boxes = torch.randint(100, 800, size=(n_targets, 4), device=device)
for t in range(n_targets):
if boxes[t, 0] > boxes[t, 2]:
boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]
if boxes[t, 1] > boxes[t, 3]:
boxes[t, 1], boxes[t, 3] = boxes[t, 3], boxes[t, 1]
targets["boxes"] = boxes.float()
labels.append(targets)
def forward(*new_params: Tensor) -> Tensor:
load_weights(model, names, new_params)
out = model(inputs)
loss = criterion(out, labels)
weight_dict = criterion.weight_dict
final_loss = cast(Tensor, sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict))
return final_loss
return forward, params
|