File: timm_models.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (423 lines) | stat: -rwxr-xr-x 12,200 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    """Install *package* with pip, using the interpreter running this script.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)


# Import timm, installing it from GitHub first if it is not available.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # The import system caches directory contents of sys.path entries; a
    # package installed after start-up may stay invisible until the caches
    # are invalidated.
    importlib.invalidate_caches()
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

# Map each benchmarked model name to its recorded batch size. Entries are read
# from timm_models_list.txt next to this script, one "<model_name> <batch_size>"
# pair per line.
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename, encoding="utf-8") as fh:
    for line in fh:
        fields = line.split()
        if not fields:
            # Skip blank lines (e.g. a trailing empty line) instead of failing
            # to unpack them.
            continue
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)


# Divide the recorded batch size from timm_models_list.txt by this factor
# (never going below 1) when loading the model in TimmRunner.load_model.
# TODO - Figure out the reason of cold start memory spike
BATCH_SIZE_DIVISORS = {
    **dict.fromkeys(
        (
            "beit_base_patch16_224",
            "convit_base",
            "convmixer_768_32",
            "convnext_base",
            "cspdarknet53",
            "deit_base_distilled_patch16_224",
            "gluon_xception65",
            "mobilevit_s",
            "pnasnet5large",
            "poolformer_m36",
            "resnest101e",
            "swin_base_patch4_window7_224",
            "swsl_resnext101_32x16d",
            "vit_base_patch16_224",
            "volo_d1_224",
        ),
        2,
    ),
    "jx_nest_base": 4,
}

# Training accuracy tolerance raised to 4e-2 (see get_tolerance_and_cosine_flag).
REQUIRE_HIGHER_TOLERANCE = {
    "convnext_base",
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

# Training accuracy tolerance raised to 8e-2; checked before the set above.
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "beit_base_patch16_224",
    "cspdarknet53",
    "levit_128",
    "sebotnet33ts_256",
}

# These models need higher tolerance in MaxAutotune mode
# (8e-2 during training when inductor's max_autotune is enabled).
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

# Tolerance raised to 8e-2 when running with --freezing.
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

# Models whose loss is divided by 1000 (TimmRunner.scaled_compute_loss).
SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

# Reported through TimmRunner.force_amp_for_fp16_bf16_models.
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

# Reported through skip_accuracy_check_as_eager_non_deterministic, only for
# accuracy runs in training mode.
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

# Queried by TimmRunner.use_larger_multiplier_for_smaller_tensor.
REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "cspdarknet53",
    "inception_v3",
    "mobilenetv3_large_100",
}


def refresh_model_names():
    """Regenerate ``timm_models_list.txt`` with one representative model per family.

    Model names used in the pytorch-image-models docs (``docs/models/*.md``
    of a checkout expected at ``../pytorch-image-models`` relative to the
    working directory) are grouped by family, and one model is chosen per
    docs family. Every other family of pretrained timm models
    (``list_models(pretrained=True)``, excluding ``*in21k`` variants)
    contributes one model as well. The sorted names are written one per line
    to ``benchmarks/timm_models_list.txt`` if a ``benchmarks`` directory
    exists in the working directory, otherwise to ``timm_models_list.txt``.

    NOTE(review): only model names are written, while the module-level loader
    parses ``"<model_name> <batch_size>"`` lines, so batch sizes must be
    added before the regenerated list can be loaded.
    """
    import glob

    from timm.models import list_models

    # Matches ``model = timm.create_model('name', ...)`` doc lines, accepting
    # either quote style around the model name.
    create_model_line = re.compile(
        r"""model = timm\.create_model\(\s*(['"])(.+?)\1"""
    )

    def read_models_from_docs():
        """Collect the model names passed to ``timm.create_model`` in the docs."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn, encoding="utf-8") as f:
                for line in f:
                    match = create_model_line.match(line)
                    if match is not None:
                        models.add(match.group(2))
        return models

    def get_family_name(name):
        """Map a model name to the coarse family key used for de-duplication."""
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group *models* by family, keeping their input order within each group."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    # Sort the docs models so each family's first entry no longer depends on
    # set iteration order, which varies between runs with string hashing.
    docs_models = sorted(read_models_from_docs())
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    # Families covered by the docs are represented by a docs model instead.
    # pop() tolerates a docs family that has no pretrained (non-in21k) model,
    # where ``del`` raised KeyError.
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = {value[0] for value in docs_models_family.values()}
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = os.path.join("benchmarks", filename)
    with open(filename, "w", encoding="utf-8") as fw:
        fw.writelines(f"{model_name}\n" for model_name in sorted(chosen_models))


class TimmRunner(BenchmarkRunner):
    """Benchmark runner for the timm (PyTorch Image Models) suite.

    Model names and recorded batch sizes come from ``TIMM_MODELS``. The
    module-level tables adjust batch size, loss scaling, AMP handling and
    accuracy tolerances per model.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        """Model names reported to the base runner for forced AMP in fp16/bf16 runs."""
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        """No timm model is reported here."""
        return set()

    @property
    def get_output_amp_train_process_func(self):
        """No per-model output post-processing for AMP training."""
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        """Models with a non-deterministic eager baseline; only for training accuracy runs."""
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric components of *version* as a tuple of ints.

        ``"0.9.16"`` -> ``(0, 9, 16)`` and ``"0.8.2.dev0"`` -> ``(0, 8, 2)``.
        An unparsable string yields ``()``, which sorts below every release.
        Comparing tuples avoids the lexicographic pitfall of plain string
        comparison (``"0.10.0" < "0.8.0"`` as strings).
        """
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    @download_retry_decorator
    def _download_model(self, model_name):
        """Create the pretrained timm model *model_name* with dropout disabled."""
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Create *model_name* on *device* together with a synthetic input batch.

        Args:
            device: device for the model, inputs, targets and loss module.
            model_name: key into ``TIMM_MODELS``.
            batch_size: overrides the recorded (divisor-adjusted) batch size
                when truthy.
            extra_args: not used by this runner.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``.

        Raises:
            NotImplementedError: if activation checkpointing is requested.
            RuntimeError: if the model could not be created.
        """
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # timm >= 0.8.0 is given a plain dict of the CLI args, older releases
        # the namespace itself. Compare versions numerically, not as strings.
        args_for_timm = (
            vars(self._args)
            if self._version_tuple(timmversion) >= (0, 8, 0)
            else self._args
        )
        data_config = resolve_data_config(
            args_for_timm,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                recorded_batch_size // BATCH_SIZE_DIVISORS[model_name], 1
            )
        batch_size = batch_size or recorded_batch_size

        # Fixed seed so every run builds the same synthetic batch.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        # Standardize the random pixel values to zero mean and unit variance.
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss
        else:
            # The override above is an instance attribute. Drop one left over
            # from a previously loaded model so this model falls back to the
            # class-level compute_loss instead of the scaled variant.
            vars(self).pop("compute_loss", None)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the names from ``TIMM_MODELS`` selected by the CLI filters.

        Names are taken in sorted order, restricted to this shard's index
        range, and filtered by ``args.filter``, ``args.exclude``,
        ``args.exclude_exact`` and ``self.skip_models``.
        """
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return a grad-enabled context for training, a no-grad context otherwise."""
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of model *name*.

        The default is 1e-3; freezing-sensitive models get 8e-2. In training,
        the tolerance is 8e-2, 4e-2 or 1e-2 depending on the per-model tables.
        ``cosine`` is passed through from ``self.args.cosine``.
        """
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical difference. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Return random int64 class indices of shape ``(batch_size,)`` in ``[0, num_classes)``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need zoom out further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run *mod* on *inputs* under the runner's autocast settings."""
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step (forward, scaled backward, optimizer step).

        Inputs are cloned first. A tuple prediction is reduced to its first
        element before computing the loss. Returns the collected results when
        *collect_outputs* is true, otherwise ``None``.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def timm_main():
    """Command-line entry point: run the timm suite through the shared ``main``."""
    # Show only WARNING-and-above log records and suppress Python warnings.
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    runner = TimmRunner()
    main(runner)


if __name__ == "__main__":
    timm_main()