File: experimental.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (139 lines) | stat: -rw-r--r-- 4,756 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import torch

# TODO (matthias) This file currently requires manual imports to let
# TorchScript work on decorated functions. Not totally sure why :(
from torch_geometric.utils import *  # noqa

# Global registry of experimental feature flags, keyed by option name.
# Every flag starts out disabled; `experimental_mode` and
# `set_experimental_mode` toggle entries of this dictionary.
__experimental_flag__: Dict[str, bool] = dict(disable_dynamic_shapes=False)

# Accepted forms of an `options` argument: `None` (meaning every registered
# flag), a single flag name, or a list of flag names.
Options = Optional[Union[str, List[str]]]


def get_options(options: Options) -> List[str]:
    r"""Normalizes an :obj:`options` argument into a list of flag names.

    A single string is wrapped into a one-element list, :obj:`None` expands
    to the names of all registered experimental flags, and a list is
    returned as-is (the same object, not a copy).
    """
    if isinstance(options, str):
        return [options]
    if options is None:
        return list(__experimental_flag__)
    return options


def is_experimental_mode_enabled(options: Options = None) -> bool:
    r"""Returns :obj:`True` if the experimental mode is enabled. See
    :class:`torch_geometric.experimental_mode` for a list of (optional)
    options.

    Args:
        options (str or list, optional): The option(s) to check. If set to
            :obj:`None`, all registered options are checked.
            (default: :obj:`None`)

    Returns:
        :obj:`True` only if every requested option is enabled; an empty
        list of options therefore yields :obj:`True`. Always returns
        :obj:`False` while scripting or tracing with TorchScript.

    Raises:
        KeyError: If an option name is not a registered experimental flag.
    """
    # Experimental features are never reported as active under TorchScript
    # scripting or tracing.
    # NOTE(review): keep this guard as the first statement and in its exact
    # form -- TorchScript may rely on it to avoid compiling the Python-only
    # code below. Confirm before restructuring.
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return False
    options = get_options(options)
    # NOTE(review): a list (rather than a generator) is passed to `all()`,
    # possibly for TorchScript compatibility -- confirm before changing.
    return all([__experimental_flag__[option] for option in options])


def set_experimental_mode_enabled(mode: bool, options: Options = None) -> None:
    r"""Assigns :obj:`mode` to each selected experimental flag.

    Args:
        mode (bool): The value to store for every selected flag.
        options (str or list, optional): The flag(s) to update. If set to
            :obj:`None`, every registered flag is updated.
            (default: :obj:`None`)
    """
    # NOTE(review): names that are not registered yet are silently added to
    # the registry rather than rejected.
    selected = get_options(options)
    __experimental_flag__.update({name: mode for name in selected})


class experimental_mode:
    r"""Context-manager that enables the experimental mode to test new but
    potentially unstable features.

    .. code-block:: python

        with torch_geometric.experimental_mode():
            out = model(data.x, data.edge_index)

    On exit, every selected option is restored to the value it had when the
    :obj:`with` block was entered.

    Args:
        options (str or list, optional): The experimental option(s) to
            enable, *e.g.*, :obj:`"disable_dynamic_shapes"`. If set to
            :obj:`None`, all experimental options are enabled.
            (default: :obj:`None`)
    """
    def __init__(self, options: Options = None) -> None:
        self.options = get_options(options)
        # Reading every flag here raises `KeyError` for unknown option names
        # already at construction time.
        self.previous_state = self._snapshot()

    def _snapshot(self) -> Dict[str, bool]:
        # Current value of each selected flag.
        return {
            option: __experimental_flag__[option]
            for option in self.options
        }

    def __enter__(self) -> None:
        # Re-snapshot on entry: the state to restore is the one at `with`
        # entry, not at construction. Otherwise, changes made between
        # constructing and entering (or by an earlier use of this instance)
        # would be overwritten on exit.
        # Note that a single instance must not be nested inside itself, as
        # the inner entry would overwrite the outer snapshot.
        self.previous_state = self._snapshot()
        set_experimental_mode_enabled(True, self.options)

    def __exit__(self, *args: Any) -> None:
        # Roll every selected flag back to its value at entry; exceptions
        # are not suppressed.
        for option, value in self.previous_state.items():
            __experimental_flag__[option] = value


class set_experimental_mode:
    r"""Context-manager that sets the experimental mode on or off.

    The new mode is applied immediately on construction, so
    :class:`set_experimental_mode` works both as a plain function call
    (the change persists) and as a context-manager (the previous values are
    restored when the :obj:`with` block exits).

    See :class:`experimental_mode` for more details.

    Args:
        mode (bool): Whether to enable or disable the selected options.
        options (str or list, optional): The experimental option(s) to set.
            If set to :obj:`None`, all experimental options are set.
            (default: :obj:`None`)
    """
    def __init__(self, mode: bool, options: Options = None) -> None:
        self.options = get_options(options)
        # Record the current values before mutating them, so that `__exit__`
        # can roll back.
        self.previous_state = dict(
            (name, __experimental_flag__[name]) for name in self.options)
        set_experimental_mode_enabled(mode, self.options)

    def __enter__(self) -> None:
        # The mode was already applied in `__init__`.
        return None

    def __exit__(self, *args: Any) -> None:
        __experimental_flag__.update(self.previous_state)


def disable_dynamic_shapes(required_args: List[str]) -> Callable:
    r"""A decorator that disables the usage of dynamic shapes for the given
    arguments, *i.e.*, it raises an error in case any of
    :obj:`required_args` is not passed (or resolves to :obj:`None`) and
    would thus need to be inferred automatically.

    The check is only performed while the :obj:`"disable_dynamic_shapes"`
    experimental option is enabled; otherwise the decorated function is
    called directly.

    Args:
        required_args (List[str]): Names of the decorated function's
            arguments that must be set. Both positional-or-keyword and
            keyword-only arguments are supported.

    Raises:
        ValueError: At decoration time, if the decorated function has no
            argument with one of the given names. At call time, if a
            required argument resolves to :obj:`None`.
    """
    def decorator(func: Callable) -> Callable:
        spec = inspect.getfullargspec(func)

        positional_defaults = spec.defaults or ()
        kwonly_defaults = spec.kwonlydefaults or {}
        # Index in `spec.args` of the first parameter that has a default.
        first_default = len(spec.args) - len(positional_defaults)

        # Position of each required positional-or-keyword argument.
        arg_positions: Dict[str, int] = {}
        # Declared default of each required argument that has one.
        arg_defaults: Dict[str, Any] = {}

        for arg_name in required_args:
            if arg_name in spec.args:
                position = spec.args.index(arg_name)
                arg_positions[arg_name] = position
                # Only parameters at or after `first_default` have a default.
                # Indexing `spec.defaults` for earlier ones would go
                # negative and read an unrelated default.
                if position >= first_default:
                    arg_defaults[arg_name] = positional_defaults[
                        position - first_default]
            elif arg_name in spec.kwonlyargs:
                if arg_name in kwonly_defaults:
                    arg_defaults[arg_name] = kwonly_defaults[arg_name]
            else:
                raise ValueError(f"The function '{func}' does not have a "
                                 f"'{arg_name}' argument")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
                return func(*args, **kwargs)

            for required_arg in required_args:
                position = arg_positions.get(required_arg)
                if position is not None and position < len(args):
                    value = args[position]
                elif required_arg in kwargs:
                    value = kwargs[required_arg]
                else:
                    # Not passed: fall back to the declared default, or
                    # `None` if the parameter has no default at all.
                    value = arg_defaults.get(required_arg)

                if value is None:
                    raise ValueError(f"Dynamic shapes disabled. Argument "
                                     f"'{required_arg}' needs to be set")

            return func(*args, **kwargs)

        return wrapper

    return decorator