#!/usr/bin/env python3
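"""
Benchmark runner for TIMM (pytorch-image-models) models.

Model names and their recorded batch sizes are read from timm_models_list.txt;
refresh_model_names() below regenerates that list. The harness itself
(argument parsing, measurement, accuracy checking) lives in common.py.
"""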
import importlib
import logging
import os
import re
import subprocess
import sys
import warnings
try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main
import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs
# Enable FX graph caching by default; an explicit TORCHINDUCTOR_FX_GRAPH_CACHE
# in the environment takes precedence over this setting.
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True
def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
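# Each line of the list file is "<model_name> <batch_size>".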
with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)
# TODO - Figure out the reason for the cold-start memory spike. Until then, the
# recorded batch sizes of these models are divided by the divisors below (with a
# floor of 1) in load_model().
BATCH_SIZE_DIVISORS = {
"beit_base_patch16_224": 2,
"convit_base": 2,
"convmixer_768_32": 2,
"convnext_base": 2,
"cspdarknet53": 2,
"deit_base_distilled_patch16_224": 2,
"gluon_xception65": 2,
"mobilevit_s": 2,
"pnasnet5large": 2,
"poolformer_m36": 2,
"resnest101e": 2,
"swin_base_patch4_window7_224": 2,
"swsl_resnext101_32x16d": 2,
"vit_base_patch16_224": 2,
"volo_d1_224": 2,
"jx_nest_base": 4,
}
REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
    "convnext_base",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}

def refresh_model_names():
    """Regenerate timm_models_list.txt with one representative model per family."""
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to the pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue
                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    # Families already covered by the docs take priority; drop them from the
    # full list (pop, not del, since a docs family may have no pretrained entry).
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")
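# Note: refresh_model_names() is a maintenance helper and is not called at
# runtime. A sketch of regenerating the list, assuming a pytorch-image-models
# checkout sits next to this repo (per the glob above):
#
#     python -c "from timm_models import refresh_model_names; refresh_model_names()"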
class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }
    # download_retry_decorator (from common.py) retries transient download
    # failures before giving up.
    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model
    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        # Synthesize an input batch: uniform integers in [0, 256), standardized
        # to zero mean and unit variance.
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size
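    # The returned tuple (device, model_name, model, example_inputs, batch_size)
    # is presumably unpacked by the harness in common.py.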
    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR
    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause a
            # relatively large numerical difference, so we need a larger
            # tolerance. See https://github.com/pytorch/pytorch/issues/120545
            # for context.
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine
    def _gen_target(self, batch_size, device):
        # Random class labels in [0, num_classes) for the synthetic batch.
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values for these models need to be scaled down further.
        return reduce_to_scalar_loss(pred) / 1000.0
    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()
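# Example invocations, assuming the benchmark harness in common.py; flag names
# are inferred from the args.* attributes referenced above (e.g. args.training,
# args.accuracy, args.filter) and may differ in your checkout:
#
#     python timm_models.py --performance --training --backend=inductor
#     python timm_models.py --accuracy --filter adv_inception_v3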