File: __init__.py

package info (click to toggle)
python-e3nn 0.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,700 kB
  • sloc: python: 13,368; makefile: 23
file content (82 lines) | stat: -rw-r--r-- 2,697 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Release version string of this package.
__version__: str = "0.6.0"


from typing import Dict, Union

import packaging.version
import torch

# torch.jit.script is deprecated in PT 2.10+, so newer torch defaults to eager.
# Any local build suffix after "+" is stripped before parsing.
_TORCH_VERSION = packaging.version.parse(torch.__version__.split("+")[0])
# Compare only the release tuple: under PEP 440 ordering a pre-release or
# nightly of 2.10 (e.g. "2.10.0.dev20250101", "2.10.0a0") sorts *below* 2.10,
# which would wrongly select the deprecated "script" mode for those builds.
_DEFAULT_JIT_MODE = "eager" if _TORCH_VERSION.release >= (2, 10) else "script"

# Global optimization defaults. "jit_mode" holds one of "script", "inductor"
# or "eager"; the remaining options are booleans by default.
_OPT_DEFAULTS: Dict[str, Union[bool, str]] = dict(
    specialized_code=True,
    optimize_einsums=True,
    jit_script_fx=True,
    jit_mode=_DEFAULT_JIT_MODE,
)


def _handle_jit_script_fx_legacy(jit_script_fx: bool, current_jit_mode: str) -> str:
    """Map the legacy ``jit_script_fx`` flag onto a ``jit_mode`` value.

    Parameters
    ----------
    jit_script_fx : bool
        The legacy jit_script_fx flag value.
    current_jit_mode : str
        The current jit_mode value.

    Returns
    -------
    str
        ``"eager"`` when ``jit_script_fx`` is falsy. Otherwise
        ``current_jit_mode`` if it is already ``"script"`` or ``"inductor"``,
        and ``"script"`` for any other current mode.
    """
    if not jit_script_fx:
        # Disabling the legacy flag always means eager execution,
        # whatever the current mode is.
        return "eager"
    if current_jit_mode in ("script", "inductor"):
        # Already a compiled mode: the legacy flag must not override it.
        return current_jit_mode
    # Legacy flag enabled on a non-compiled mode: fall back to scripting.
    return "script"


def _validate_and_set_jit_mode(jit_mode: str) -> None:
    """Validate ``jit_mode`` and store it in ``_OPT_DEFAULTS["jit_mode"]``.

    Parameters
    ----------
    jit_mode : str
        Requested mode; must be ``"script"``, ``"inductor"`` or ``"eager"``.

    Raises
    ------
    AssertionError
        If ``jit_mode`` is not one of the accepted values. The error is raised
        explicitly rather than via an ``assert`` statement so the check still
        runs under ``python -O``; the exception type is kept for callers that
        already catch it.
    """
    if jit_mode not in ("script", "inductor", "eager"):
        raise AssertionError(
            f"Invalid jit_mode: {jit_mode}. Expected 'script', 'inductor', or 'eager'."
        )
    _OPT_DEFAULTS["jit_mode"] = jit_mode


def set_optimization_defaults(**kwargs) -> None:
    r"""Globally set the default optimization settings.

    Parameters
    ----------
    **kwargs
        Keyword arguments to set the default optimization settings. Only keys
        already present in the defaults are accepted (``specialized_code``,
        ``optimize_einsums``, ``jit_script_fx`` and ``jit_mode``).

    Raises
    ------
    ValueError
        If any key is not a known optimization option.
    AssertionError
        If ``jit_mode`` is not ``"script"``, ``"inductor"`` or ``"eager"``.

    Notes
    -----
    The update is all-or-nothing: keys are checked before anything is written,
    and if applying a value fails, the previous settings are restored.
    The legacy ``jit_script_fx`` flag is applied before an explicit
    ``jit_mode`` passed in the same call, so the explicit ``jit_mode`` wins
    regardless of keyword order.
    """
    # Reject unknown keys up front so a typo cannot leave a partial update.
    for k in kwargs:
        if k not in _OPT_DEFAULTS:
            raise ValueError(f"Unknown optimization option: {k}")

    # Apply the legacy flag first (sorted() is stable and False < True), so an
    # explicit jit_mode given alongside it takes precedence.
    ordered = sorted(kwargs.items(), key=lambda item: item[0] != "jit_script_fx")

    snapshot = dict(_OPT_DEFAULTS)
    try:
        for k, v in ordered:
            if k == "jit_script_fx":
                # Legacy mapping of jit_script_fx onto jit_mode so that old
                # code can still work with the new defaults.
                new_jit_mode = _handle_jit_script_fx_legacy(v, _OPT_DEFAULTS["jit_mode"])
                _validate_and_set_jit_mode(new_jit_mode)
                _OPT_DEFAULTS[k] = v
            elif k == "jit_mode":
                _validate_and_set_jit_mode(v)
            else:
                _OPT_DEFAULTS[k] = v
    except BaseException:
        # Restore in place (same dict object) so references held elsewhere
        # see the rollback too, then propagate the original error.
        _OPT_DEFAULTS.clear()
        _OPT_DEFAULTS.update(snapshot)
        raise


def get_optimization_defaults() -> Dict[str, Union[bool, str]]:
    r"""Get the global default optimization settings.

    Returns
    -------
    dict
        A shallow copy of the current settings, so callers can inspect or
        modify it without affecting the global defaults. ``"jit_mode"`` maps
        to a string; the other options hold whatever was last set via
        ``set_optimization_defaults`` (booleans by default).
    """
    return dict(_OPT_DEFAULTS)