File: common.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (52 lines) | stat: -rw-r--r-- 2,191 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import importlib
from typing import List, Optional

from torch.testing._internal.common_utils import TestCase


class AOMigrationTestCase(TestCase):
    r"""Helpers for checking that names importable from ``torch.<base>`` compare
    equal to the same names importable from ``torch.ao.<base>``."""

    @staticmethod
    def _import_old_and_new(
        package_name: str,
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        r"""Imports the old and the new module and returns them as a pair.

        Args:
            package_name: Module name under ``torch.<base>``.
            base: Sub-package shared by both locations; ``None`` means
                ``"quantization"``.
            new_package_name: Module name under ``torch.ao.<base>``; ``None``
                means the same as ``package_name``.

        Returns:
            A tuple ``(old_location, new_location)`` of imported modules.
        """
        if base is None:
            base = "quantization"
        if new_package_name is None:
            new_package_name = package_name
        old_location = importlib.import_module(f"torch.{base}.{package_name}")
        new_location = importlib.import_module(
            f"torch.ao.{base}.{new_package_name}"
        )
        return old_location, new_location

    def _test_function_import(
        self,
        package_name: str,
        function_list: List[str],
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        r"""Tests individual function list import by comparing the functions
        and their hashes.

        Args:
            package_name: Module name under ``torch.<base>``.
            function_list: Attribute names to look up in both modules.
            base: Sub-package shared by both locations; ``None`` means
                ``"quantization"``.
            new_package_name: Module name under ``torch.ao.<base>``; ``None``
                means the same as ``package_name``.

        Raises:
            AssertionError: If an attribute differs, or hashes differently,
                between the two modules.
        """
        old_location, new_location = self._import_old_and_new(
            package_name, base, new_package_name
        )
        for fn_name in function_list:
            old_function = getattr(old_location, fn_name)
            new_function = getattr(new_location, fn_name)
            assert old_function == new_function, f"Functions don't match: {fn_name}"
            assert hash(old_function) == hash(new_function), (
                f"Hashes don't match: {old_function}({hash(old_function)}) vs. "
                f"{new_function}({hash(new_function)})"
            )

    def _test_dict_import(
        self,
        package_name: str,
        dict_list: List[str],
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        r"""Tests individual dict list import by comparing the dicts found in
        the old and the new module, key by key and as a whole.

        Args:
            package_name: Module name under ``torch.<base>``.
            dict_list: Attribute names of the dicts to look up in both modules.
            base: Sub-package shared by both locations; ``None`` means
                ``"quantization"``.
            new_package_name: Module name under ``torch.ao.<base>``; ``None``
                means the same as ``package_name``.

        Raises:
            AssertionError: If the key sets differ, if any value differs, or
                if the two dicts do not compare equal.
        """
        old_location, new_location = self._import_old_and_new(
            package_name, base, new_package_name
        )
        for dict_name in dict_list:
            old_dict = getattr(old_location, dict_name)
            new_dict = getattr(new_location, dict_name)
            # The fine-grained checks run before the whole-dict comparison so
            # that a failure names the offending key instead of only the dict.
            only_in_old = [key for key in old_dict if key not in new_dict]
            only_in_new = [key for key in new_dict if key not in old_dict]
            assert not only_in_old and not only_in_new, (
                f"Dicts don't match: {dict_name} "
                f"(keys only in old: {only_in_old}, "
                f"keys only in new: {only_in_new})"
            )
            for key in new_dict:
                assert (
                    old_dict[key] == new_dict[key]
                ), f"Dicts don't match: {dict_name} for key {key}"
            # Final guard: keeps the original whole-object equality requirement.
            assert old_dict == new_dict, f"Dicts don't match: {dict_name}"