File: case.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (174 lines) | stat: -rw-r--r-- 5,022 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# mypy: allow-untyped-defs
import inspect
import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from types import ModuleType

import torch

# Registry of permitted example tags, grouped by top-level namespace.
# Every tag name maps to its own empty dict. ``_validate_tag`` walks this
# nested mapping one dotted component at a time (e.g. "torch.cond").
_TAGS: Dict[str, Dict[str, Any]] = {
    namespace: {tag_name: {} for tag_name in tag_names}
    for namespace, tag_names in (
        (
            "torch",
            (
                "cond",
                "dynamic-shape",
                "escape-hatch",
                "map",
                "dynamic-value",
                "operator",
                "mutation",
            ),
        ),
        (
            "python",
            (
                "assert",
                "builtin",
                "closure",
                "context-manager",
                "control-flow",
                "data-structure",
                "standard-library",
                "object-model",
            ),
        ),
    )
}


class SupportLevel(Enum):
    """
    Export support status of the feature that an example demonstrates.
    """

    # The feature used by the example is handled by export.
    SUPPORTED = 1
    # The feature used by the example is not handled by export yet.
    NOT_SUPPORTED_YET = 0


# Type of positional example inputs: a tuple of arbitrary values
# (``check_inputs_type`` rejects anything that is not a tuple).
ArgsType = Tuple[Any, ...]


def check_inputs_type(args, kwargs):
    """
    Validate the container types of an example's inputs.

    Args:
        args: positional inputs; must be a ``tuple``.
        kwargs: keyword inputs; must be a ``dict`` whose keys are all ``str``.

    Raises:
        ValueError: for the first requirement that is violated, checked in
            the order listed above. For keys, the first non-``str`` key in
            iteration order is the one reported.
    """
    if not isinstance(args, tuple):
        raise ValueError(
            f"Expecting args type to be a tuple, got: {type(args)}"
        )
    if not isinstance(kwargs, dict):
        raise ValueError(
            f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
        )

    # Collect every non-string key; only the first one is reported.
    bad_keys = [k for k in kwargs if not isinstance(k, str)]
    if bad_keys:
        raise ValueError(
            f"Expecting kwargs keys to be a string, got: {type(bad_keys[0])}"
        )

def _validate_tag(tag: str):
    """
    Check that a dotted tag such as ``"torch.dynamic-shape"`` is registered.

    The tag is split on ``"."`` and each component is looked up one level
    deeper in the nested ``_TAGS`` mapping. Only membership is checked at
    each level, so a namespace prefix on its own (e.g. ``"torch"``) is
    accepted as well.

    Raises:
        ValueError: if a component contains characters other than lowercase
            ASCII letters and ``"-"``, or if a component is not present at
            its level of ``_TAGS``.
    """
    # Built once per call rather than once per component.
    allowed_chars = set(string.ascii_lowercase + "-")
    node = _TAGS
    for part in tag.split("."):
        # Explicit raise instead of ``assert`` so this input validation is
        # not stripped when Python runs with ``-O``.
        if not set(part).issubset(allowed_chars):
            raise ValueError(f"Tag contains invalid characters: {part}")
        if part not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        node = node[part]


@dataclass(frozen=True)
class ExportCase:
    """
    An example model bundled with the inputs used to exercise it and the
    metadata that describes the case. Fields are validated on construction.
    """

    # Positional inputs for the model; must be a tuple.
    example_args: ArgsType
    # Human-readable description of the use case; must be a non-empty str.
    description: str
    # The example model.
    model: torch.nn.Module
    # Name the case is identified by.
    name: str
    # Keyword inputs for the model; keys must be strings.
    example_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Optional extra positional inputs, for testing graph generalization.
    extra_args: Optional[ArgsType] = None
    # Dotted tags registered in _TAGS (e.g. "torch.dynamic-shape").
    tags: Set[str] = field(default_factory=set)
    # Export support status of the demonstrated feature.
    support_level: SupportLevel = SupportLevel.SUPPORTED
    # Optional dynamic-shape specification; not validated here.
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """
        Validate inputs, tags and description, in that order.

        Errors raised by ``check_inputs_type`` and ``_validate_tag`` propagate.
        A description that is not a non-empty ``str`` raises ValueError.
        """
        check_inputs_type(self.example_args, self.example_kwargs)
        extra = self.extra_args
        if extra is not None:
            # Extra inputs are positional-only, so validate with no kwargs.
            check_inputs_type(extra, {})

        for tag_name in self.tags:
            _validate_tag(tag_name)

        desc = self.description
        if isinstance(desc, str) and desc:
            return
        raise ValueError(f'Invalid description: "{self.description}"')


# Registered cases keyed by name; the first case registered under a name wins.
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
# Modules that have already used ``export_case`` (allowed once per module).
_MODULES: Set[ModuleType] = set()
# Name -> all cases registered under that name once a duplicate appears,
# starting with the case kept in ``_EXAMPLE_CASES``.
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
# Parent case name -> rewrite cases registered via ``export_rewrite_case``.
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}


def register_db_case(case: ExportCase) -> None:
    """
    Add a user-provided ExportCase to the example bank, keyed by its name.

    The first case registered under a name is stored in ``_EXAMPLE_CASES``.
    A later case with the same name is not stored there. Instead it is
    collected in ``_EXAMPLE_CONFLICT_CASES[case.name]``, whose list starts
    with the originally registered case and grows in registration order.
    """
    key = case.name
    if key not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[key] = case
        return

    # Seed the conflict list with the case that claimed the name first.
    _EXAMPLE_CONFLICT_CASES.setdefault(key, [_EXAMPLE_CASES[key]]).append(case)


def to_snake_case(name):
    """
    Convert a CamelCase identifier to snake_case.

    Two regex passes insert underscores at word boundaries, then the result
    is lowercased, e.g. ``"getHTTPResponse" -> "get_http_response"``.
    """
    substitutions = (
        # Underscore before a capitalized word: "HTTPServer" -> "HTTP_Server".
        ("(.)([A-Z][a-z]+)", r"\1_\2"),
        # Underscore between a lowercase letter/digit and an uppercase one.
        ("([a-z0-9])([A-Z])", r"\1_\2"),
    )
    result = name
    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)
    return result.lower()


def _make_export_case(m, name, configs):
    """
    Build an ExportCase for the model ``m`` from user-supplied settings.

    Args:
        m: the example model; must be a ``torch.nn.Module`` instance.
        name: the name stored on the resulting case.
        configs: the remaining ExportCase fields as a mapping. If it has no
            ``"description"`` key, ``m.__doc__`` is used instead. The mapping
            itself is not mutated.

    Returns:
        The constructed ExportCase.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module``.
        ValueError: if no description is supplied and ``m`` has no docstring.
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")

    if "description" not in configs:
        # Fall back to the docstring. Raise explicitly instead of using
        # ``assert``: under ``python -O`` the assert vanished and a None
        # description slipped through to a less informative error.
        if m.__doc__ is None:
            raise ValueError(
                f"Could not find description or docstring for export case: {m}"
            )
        configs = {**configs, "description": m.__doc__}
    # ``model`` and ``name`` always override same-named keys in configs.
    return ExportCase(**{**configs, "model": m, "name": name})


def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.

    The keyword arguments are used as the ExportCase fields. The case is
    named after the last dotted component of the defining module's
    ``__name__`` and is added with ``register_db_case``. A module may use
    this decorator only once; a second use raises RuntimeError.
    """

    def wrapper(m):
        owner = inspect.getmodule(m)
        # Enforce one registered example per example file.
        if owner in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")

        assert owner is not None
        _MODULES.add(owner)

        _, _, short_name = owner.__name__.rpartition(".")
        registered = _make_export_case(m, short_name, kwargs)
        register_db_case(registered)
        return registered

    return wrapper


def export_rewrite_case(**kwargs):
    """
    Decorator for registering a rewritten variant of an existing export case.

    The keyword arguments are used as ExportCase fields, with two exceptions:

    * ``parent`` (required): the ExportCase being rewritten. The new case is
      appended to ``_EXAMPLE_REWRITE_CASES[parent.name]``.
    * ``example_args`` is always taken from ``parent``.

    The case name is the snake_case form of the decorated object's
    ``__name__``, or of its class name when it has no ``__name__`` (as is the
    case for ``torch.nn.Module`` instances).

    Raises:
        KeyError: if ``parent`` is not supplied.
        TypeError: if ``parent`` is not an ExportCase.
    """

    def wrapper(m):
        # Work on a copy. Popping from or writing into ``kwargs`` directly
        # made a second application of the same decorator fail with KeyError.
        configs = dict(kwargs)

        parent = configs.pop("parent")
        if not isinstance(parent, ExportCase):
            raise TypeError(
                f"export_rewrite_case expects `parent` to be an ExportCase, got: {type(parent)}"
            )

        configs["example_args"] = parent.example_args
        # _make_export_case only accepts nn.Module instances, and instances
        # normally have no ``__name__``, so fall back to the class name.
        case_name = to_snake_case(getattr(m, "__name__", type(m).__name__))
        case = _make_export_case(m, case_name, configs)
        # Only create the registry slot once the case was built successfully.
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper