1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
|
# mypy: allow-untyped-defs
import inspect
import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from types import ModuleType
import torch
# Registry of the tags an ExportCase may carry. A tag is a dot-separated path
# into this nested mapping (e.g. "torch.dynamic-shape", "python.closure");
# every leaf is its own empty dict.
_TAGS: Dict[str, Dict[str, Any]] = {
    # Tags under the "torch" namespace.
    "torch": {
        leaf: {}
        for leaf in (
            "cond",
            "dynamic-shape",
            "escape-hatch",
            "map",
            "dynamic-value",
            "operator",
            "mutation",
        )
    },
    # Tags under the "python" namespace.
    "python": {
        leaf: {}
        for leaf in (
            "assert",
            "builtin",
            "closure",
            "context-manager",
            "control-flow",
            "data-structure",
            "standard-library",
            "object-model",
        )
    },
}
class SupportLevel(Enum):
    """
    Export support status of the feature that an example demonstrates.
    """

    # The feature shown by the example is handled by export.
    SUPPORTED = 1
    # The feature shown by the example is not handled by export yet.
    NOT_SUPPORTED_YET = 0


# Type of positional example inputs: a tuple of arbitrary values.
ArgsType = Tuple[Any, ...]
def check_inputs_type(args, kwargs):
    """
    Validate the container types of an example's inputs.

    ``args`` must be a tuple, and ``kwargs`` must be a dict whose keys are all
    strings. The first violation found raises ``ValueError``. On success
    nothing is returned.
    """
    if not isinstance(args, tuple):
        raise ValueError(f"Expecting args type to be a tuple, got: {type(args)}")
    if not isinstance(kwargs, dict):
        raise ValueError(f"Expecting kwargs type to be a dict, got: {type(kwargs)}")
    # type() of the first non-string key, or None when every key is a string
    # (a type object is never None, so None is a safe "not found" marker).
    bad_key_type = next(
        (type(key) for key in kwargs if not isinstance(key, str)), None
    )
    if bad_key_type is not None:
        raise ValueError(f"Expecting kwargs keys to be a string, got: {bad_key_type}")
def _validate_tag(tag: str):
    """
    Check that ``tag`` names a registered entry in ``_TAGS``.

    ``tag`` is a dot-separated path such as ``"torch.dynamic-shape"``. Each
    component may contain only lowercase ASCII letters and hyphens. Each
    component must also be a key at the corresponding level of the ``_TAGS``
    registry. A bare prefix such as ``"torch"`` is also accepted, because it
    is itself a registered key.

    Raises:
        ValueError: if a component contains other characters, or if the path
            is not present in the registry.
    """
    # Loop-invariant: build the allowed character set once.
    allowed_chars = set(string.ascii_lowercase + "-")
    node = _TAGS
    for part in tag.split("."):
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``. ValueError matches the "not found" error below.
        if not set(part) <= allowed_chars:
            raise ValueError(f"Tag contains invalid characters: {part}")
        if part not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        node = node[part]
@dataclass(frozen=True)
class ExportCase:
    """
    One entry of the export example bank.

    An entry is a model plus the example inputs and metadata used to exercise
    it. Inputs, tags and the description are validated when the case is
    constructed.
    """

    # Positional example inputs; must be a tuple.
    example_args: ArgsType
    # Human-readable description of the use case; must be a non-empty string.
    description: str
    model: torch.nn.Module
    name: str
    # Keyword example inputs; must be a dict with string keys.
    example_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Optional alternative positional inputs, for testing graph generalization.
    extra_args: Optional[ArgsType] = None
    # Tags associated with the use case, as dotted paths registered in _TAGS
    # (e.g. "torch.dynamic-shape", "torch.escape-hatch").
    tags: Set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    # NOTE(review): presumably the dynamic-shape spec used when exporting the
    # model; the contents are not checked here.
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Every set of example inputs must satisfy check_inputs_type. The
        # extra args carry no keyword inputs, so an empty dict is checked
        # alongside them.
        input_sets = [(self.example_args, self.example_kwargs)]
        if self.extra_args is not None:
            input_sets.append((self.extra_args, {}))
        for positional, keyword in input_sets:
            check_inputs_type(positional, keyword)

        for tag in self.tags:
            _validate_tag(tag)

        has_description = isinstance(self.description, str) and len(self.description) > 0
        if not has_description:
            raise ValueError(f'Invalid description: "{self.description}"')
# Case name -> the first ExportCase registered under that name.
_EXAMPLE_CASES: Dict[str, ExportCase] = dict()
# Case name -> every case registered under an already-taken name. The list
# starts with the case that was kept in _EXAMPLE_CASES.
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = dict()
# Parent case name -> rewritten variants registered via export_rewrite_case.
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = dict()
# Modules from which export_case has been used.
_MODULES: Set[ModuleType] = set()
def register_db_case(case: ExportCase) -> None:
    """
    Add ``case`` to the example bank, recording any name collision.

    The first case seen for a given ``case.name`` is stored in
    ``_EXAMPLE_CASES``. A later case with the same name does not replace it.
    Instead, it is appended to ``_EXAMPLE_CONFLICT_CASES[case.name]``, a list
    that begins with the case already stored.
    """
    name = case.name
    if name not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[name] = case
        return
    conflicts = _EXAMPLE_CONFLICT_CASES.setdefault(name, [_EXAMPLE_CASES[name]])
    conflicts.append(case)
def to_snake_case(name):
    """
    Convert a CamelCase identifier to snake_case.

    Two regex passes insert underscores:

    1. Before every capitalised word (an uppercase letter followed by
       lowercase letters) that is preceded by some character.
    2. Between a lowercase letter or digit and a following uppercase letter.

    The result is then lowercased, e.g. ``"HTTPServer"`` -> ``"http_server"``.
    """
    substitutions = (
        ("(.)([A-Z][a-z]+)", r"\1_\2"),
        ("([a-z0-9])([A-Z])", r"\1_\2"),
    )
    result = name
    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)
    return result.lower()
def _make_export_case(m, name, configs):
    """
    Build an ExportCase wrapping the module instance ``m``.

    Args:
        m: the ``torch.nn.Module`` instance the case exports.
        name: name recorded on the case. It overrides any ``"name"`` in
            ``configs``.
        configs: remaining ExportCase fields as a mapping. When it has no
            ``"description"``, ``m.__doc__`` is used instead. ``configs``
            itself is not modified.

    Returns:
        The new ExportCase. Its ``model`` is ``m``, overriding any ``"model"``
        in ``configs``.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module`` instance.
        ValueError: if no description is given and ``m`` has no docstring.
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")
    if "description" not in configs:
        # Fall back to the docstring when no description is given. Raise
        # explicitly (not ``assert``) so the check also runs under
        # ``python -O``. ValueError matches ExportCase's own description check.
        if m.__doc__ is None:
            raise ValueError(
                f"Could not find description or docstring for export case: {m}"
            )
        configs = {**configs, "description": m.__doc__}
    return ExportCase(**{**configs, "model": m, "name": name})
def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.

    The decorated object (a ``torch.nn.Module`` instance, as required by
    ``_make_export_case``) is wrapped in an ExportCase. The case is named
    after the last dotted component of its defining module's ``__name__`` and
    registered with ``register_db_case``. Keyword arguments are forwarded as
    ExportCase fields. The decorator returns the ExportCase, not the decorated
    object.

    Raises (when the returned decorator is applied):
        RuntimeError: if the defining module cannot be determined, or if a
            case has already been registered from that module.
    """

    def wrapper(m):
        module = inspect.getmodule(m)
        # Explicit check rather than ``assert``. Under ``python -O`` the assert
        # vanished, None was recorded in _MODULES, and ``module.__name__``
        # then failed with AttributeError.
        if module is None:
            raise RuntimeError(
                f"Could not determine the module defining export case: {m}"
            )
        if module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")
        module_name = module.__name__.split(".")[-1]
        # _make_export_case does not mutate its configs argument, so kwargs can
        # be passed directly.
        case = _make_export_case(m, module_name, kwargs)
        register_db_case(case)
        # Mark the module as used only once its case is registered, so a
        # failed attempt does not block a corrected retry.
        _MODULES.add(module)
        return case

    return wrapper
def export_rewrite_case(**kwargs):
    """
    Decorator for registering a rewritten variant of an existing ExportCase.

    Required keyword argument:
        parent: the ExportCase being rewritten. The new case reuses
            ``parent.example_args`` and is appended to
            ``_EXAMPLE_REWRITE_CASES[parent.name]``.

    The other keyword arguments are forwarded as ExportCase fields. The new
    case is named after the snake_case form of the decorated object's
    ``__name__``, or of its class name when it has none (``torch.nn.Module``
    instances normally do not). The decorator returns the new ExportCase.

    Raises (when the returned decorator is applied):
        KeyError: if ``parent`` was not given.
        TypeError: if ``parent`` is not an ExportCase.
    """

    def wrapper(m):
        # Work on a copy. ``kwargs`` is shared by every application of this
        # decorator, and popping from it made a second application fail.
        configs = dict(kwargs)
        parent = configs.pop("parent")
        if not isinstance(parent, ExportCase):
            raise TypeError(f"parent must be an ExportCase, got: {type(parent)}")
        configs["example_args"] = parent.example_args
        # _make_export_case requires an nn.Module *instance*, and instances have
        # no ``__name__`` of their own, so fall back to the class name.
        case_name = to_snake_case(getattr(m, "__name__", type(m).__name__))
        case = _make_export_case(m, case_name, configs)
        # Create the bucket only once the case has been built successfully.
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper
|