1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
# TODO (matthias) This file currently requires manual imports to let
# TorchScript work on decorated functions. Not totally sure why :(
from torch_geometric.utils import * # noqa
# Registry of every known experimental option, mapped to whether it is
# currently enabled. All options start out disabled.
__experimental_flag__: Dict[str, bool] = dict(
    disable_dynamic_shapes=False,
)

# Accepted spellings of an option selection: a single option name, a list of
# option names, or `None` (which `get_options` expands to all options).
Options = Optional[Union[str, List[str]]]
def get_options(options: Options) -> List[str]:
    r"""Normalizes an option selection into a list of option names.

    Args:
        options (str or list, optional): A single option name is wrapped into
            a one-element list. :obj:`None` expands to the names of all
            registered experimental options. A list is returned unchanged.
    """
    if isinstance(options, str):
        return [options]
    if options is None:
        return list(__experimental_flag__)
    return options
def is_experimental_mode_enabled(options: Options = None) -> bool:
    r"""Checks whether all of the requested experimental options are enabled.

    Always reports :obj:`False` while running under TorchScript scripting or
    tracing. See :class:`torch_geometric.experimental_mode` for a list of
    (optional) options.

    Args:
        options (str or list, optional): The option(s) to check. If set to
            :obj:`None`, all registered options are checked.
            (default: :obj:`None`)
    """
    jit_active = torch.jit.is_scripting() or torch.jit.is_tracing()
    if jit_active:
        return False
    # Every flag is looked up eagerly, so an unknown name raises `KeyError`
    # regardless of the values of the other flags.
    flags = [__experimental_flag__[name] for name in get_options(options)]
    return all(flags)
def set_experimental_mode_enabled(mode: bool, options: Options = None) -> None:
    r"""Sets the given experimental option(s) to :obj:`mode`.

    Args:
        mode (bool): The value to assign to each selected option.
        options (str or list, optional): The option(s) to set. If set to
            :obj:`None`, all registered options are set.
            (default: :obj:`None`)

    Raises:
        KeyError: If any of :obj:`options` is not a registered experimental
            option. No flag is modified in that case.
    """
    options = get_options(options)
    # Validate before mutating anything. Otherwise a misspelled name would
    # silently register a bogus flag, or leave the registry half-updated.
    unknown = [option for option in options
               if option not in __experimental_flag__]
    if len(unknown) > 0:
        raise KeyError(f"Unknown experimental option(s): {unknown}")
    for option in options:
        __experimental_flag__[option] = mode
class experimental_mode:
    r"""Context-manager that enables the experimental mode to test new but
    potentially unstable features.

    .. code-block:: python

        with torch_geometric.experimental_mode():
            out = model(data.x, data.edge_index)

    On exit, every selected option is restored to the value it had when the
    :obj:`with` block was entered.

    Args:
        options (str or list, optional): The experimental option(s) to enable,
            *e.g.*, :obj:`"disable_dynamic_shapes"`. If set to :obj:`None`,
            all experimental options are enabled. (default: :obj:`None`)
    """
    def __init__(self, options: Options = None) -> None:
        self.options = get_options(options)
        # Captured eagerly so that unknown option names fail at construction
        # time (`KeyError`). It is refreshed on every `__enter__`, so a
        # manager created early or reused never restores a stale snapshot.
        self.previous_state = self._current_state()
        # One saved snapshot per active `with` block, which makes nested reuse
        # of the same instance restore correctly.
        self._saved_states: List[Dict[str, bool]] = []

    def _current_state(self) -> Dict[str, bool]:
        # Current values of the selected options.
        return {
            option: __experimental_flag__[option]
            for option in self.options
        }

    def __enter__(self) -> None:
        self.previous_state = self._current_state()
        self._saved_states.append(self.previous_state)
        set_experimental_mode_enabled(True, self.options)

    def __exit__(self, *args: Any) -> None:
        if len(self._saved_states) > 0:
            state = self._saved_states.pop()
        else:  # `__exit__` called without a matching `__enter__`.
            state = self.previous_state
        for option, value in state.items():
            __experimental_flag__[option] = value
class set_experimental_mode:
    r"""Context-manager that switches the experimental mode on or off.

    Unlike :class:`experimental_mode`, the requested :attr:`mode` is applied
    as soon as the object is created. It can therefore be used either as a
    plain function call (the change persists) or as a context-manager (the
    previous values are restored when the :obj:`with` block is left).

    See :class:`experimental_mode` for more details.
    """
    def __init__(self, mode: bool, options: Options = None) -> None:
        self.options = get_options(options)
        flags = __experimental_flag__
        # Remember the old values before overwriting them, so that leaving a
        # `with` block can put them back.
        self.previous_state = {name: flags[name] for name in self.options}
        set_experimental_mode_enabled(mode, self.options)

    def __enter__(self) -> None:
        # Nothing to do here: the mode was already applied in `__init__`.
        return None

    def __exit__(self, *args: Any) -> None:
        __experimental_flag__.update(self.previous_state)
def disable_dynamic_shapes(required_args: List[str]) -> Callable:
    r"""A decorator that disables the usage of dynamic shapes for the given
    arguments, *i.e.*, it will raise an error in case :obj:`required_args` are
    not passed and need to be automatically inferred.

    The check only runs while the :obj:`"disable_dynamic_shapes"`
    experimental option is enabled. Otherwise the decorated function is
    called directly.

    Args:
        required_args (List[str]): Names of positional-or-keyword arguments of
            the decorated function that must resolve to a non-:obj:`None`
            value (passed positionally, by keyword, or via a non-:obj:`None`
            default).

    Raises:
        ValueError: At decoration time, if the decorated function has no
            positional-or-keyword argument with one of the given names.
            At call time, if a required argument resolves to :obj:`None`.
    """
    def decorator(func: Callable) -> Callable:
        spec = inspect.getfullargspec(func)

        # Positional index of every required argument. The mapping is built
        # once, so later mutation of the caller's `required_args` list
        # cannot affect the wrapper.
        required_args_pos: Dict[str, int] = {}
        for arg_name in required_args:
            if arg_name not in spec.args:
                raise ValueError(f"The function '{func}' does not have a "
                                 f"'{arg_name}' argument")
            required_args_pos[arg_name] = spec.args.index(arg_name)

        # `spec.defaults` holds the defaults of the *trailing* arguments of
        # `spec.args`. Arguments before `num_positional_args` have none.
        defaults = spec.defaults if spec.defaults is not None else ()
        num_positional_args = len(spec.args) - len(defaults)

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
                return func(*args, **kwargs)

            for required_arg, index in required_args_pos.items():
                value: Optional[Any] = None
                if index < len(args):
                    value = args[index]
                elif required_arg in kwargs:
                    value = kwargs[required_arg]
                elif index >= num_positional_args:
                    # Only read a default for arguments that actually have
                    # one. A negative offset would silently pick up another
                    # argument's default value.
                    value = defaults[index - num_positional_args]

                if value is None:
                    raise ValueError(f"Dynamic shapes disabled. Argument "
                                     f"'{required_arg}' needs to be set")

            return func(*args, **kwargs)

        return wrapper

    return decorator
|