File: test_hub.py

package info (click to toggle)
pytorch-vision 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 15,188 kB
  • sloc: python: 49,008; cpp: 10,019; sh: 610; java: 550; xml: 79; objc: 56; makefile: 32
file content (46 lines) | stat: -rw-r--r-- 1,628 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import shutil
import sys
import tempfile

import pytest
import torch.hub as hub


def sum_of_model_parameters(model):
    """Return the total of every element of every parameter of ``model``.

    Each tensor yielded by ``model.parameters()`` is reduced with ``.sum()``.
    The partial sums are then added left to right, starting from ``0``.
    The tests compare the result via ``.item()``. A model without parameters
    yields the plain integer ``0``.
    """
    return sum(param.sum() for param in model.parameters())


# Reference value: the sum of all parameters of resnet18 loaded through
# torch.hub with weights="DEFAULT". The tests compare against it with
# pytest.approx, so small floating-point differences are tolerated.
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625


@pytest.mark.skipif("torchvision" in sys.modules, reason="TestHub must start without torchvision imported")
class TestHub:
    """End-to-end checks that torchvision's hub entrypoints work via ``torch.hub``.

    These tests fetch ``pytorch/vision`` from GitHub and therefore need
    network access.

    The skipif guard above is evaluated only ONCE, at collection time, before
    any test starts:

    - If torchvision were imported before the tests start, the installed
      wheel could be picked up instead of the hub checkout. For example, its
      ``_C.so`` exists in the wheel but not in the downloaded zip.
    - After the first test runs, torchvision is in ``sys.modules`` anyway,
      because all hub tests run in the same Python process.
    """

    def test_load_from_github(self):
        """Loading resnet18 from GitHub yields the expected pretrained weights."""
        hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
        assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)

    def test_set_dir(self):
        """``hub.set_dir`` controls where the repository checkout is placed.

        A private temporary directory is used instead of the shared system
        temp dir. This prevents a leftover checkout from another run from
        satisfying the check. The directory and everything hub wrote into it
        are removed afterwards.

        The previous hub directory is restored even if an assertion fails,
        so the global setting does not leak into other tests.
        """
        previous_dir = hub.get_dir()
        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                hub.set_dir(temp_dir)
                hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
                assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
                # torch.hub names the checkout "<owner>_<repo>_<ref>". This test
                # used to expect "pytorch_vision_master", but the default ref
                # (master vs main) differs between torch versions. Check only the
                # owner/repo prefix.
                repo_dirs = [name for name in os.listdir(temp_dir) if name.startswith("pytorch_vision_")]
                assert repo_dirs, f"expected torch.hub to place the pytorch/vision checkout in {temp_dir}"
            finally:
                hub.set_dir(previous_dir)

    def test_list_entrypoints(self):
        """The repository's hubconf exposes a ``resnet18`` entrypoint."""
        # force_reload=True re-downloads the repo instead of reusing a cached checkout.
        entry_lists = hub.list("pytorch/vision", force_reload=True)
        assert "resnet18" in entry_lists


if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this file directly
    # would exit 0 even when tests fail.
    sys.exit(pytest.main([__file__]))