1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
|
# Version string of this module, exposed under the conventional dunder name.
__version__: str = "0.6.0"
from typing import Dict, Union

import packaging.version
import torch
# torch.jit.script is deprecated in PT 2.10+, so eager execution is the default there.
# The local build tag (e.g. "+cu121") is dropped before parsing.
_TORCH_VERSION = packaging.version.parse(torch.__version__.split("+")[0])
# Compare only the numeric release segment: a full Version comparison would rank
# dev / pre-release builds of 2.10 (e.g. "2.10.0.dev20250101", "2.10.0a0") *below*
# "2.10" and wrongly fall back to the deprecated "script" mode.
_DEFAULT_JIT_MODE = "eager" if _TORCH_VERSION.release >= (2, 10) else "script"
# Global optimization defaults, updated through set_optimization_defaults().
# "jit_mode" holds a string; the other options start out as booleans.
_OPT_DEFAULTS: Dict[str, Union[bool, str]] = dict(
    specialized_code=True,
    optimize_einsums=True,
    jit_script_fx=True,
    jit_mode=_DEFAULT_JIT_MODE,
)
def _handle_jit_script_fx_legacy(jit_script_fx: bool, current_jit_mode: str) -> str:
"""Handle the legacy jit_script_fx flag mapping to jit_mode.
Parameters
----------
jit_script_fx : bool
The legacy jit_script_fx flag value
current_jit_mode : str
The current jit_mode value
Returns
-------
str
The new jit_mode value based on the legacy mapping rules
"""
if not jit_script_fx and current_jit_mode == "eager":
# Keep it eager
return "eager"
elif not jit_script_fx:
# Map False to eager if not already eager
return "eager"
elif jit_script_fx and current_jit_mode not in ["script", "inductor"]:
# Map True to script only if not already script or inductor
return "script"
# In all other cases, keep current jit_mode
return current_jit_mode
def _validate_and_set_jit_mode(jit_mode: str) -> None:
"""Validate and set the jit_mode in _OPT_DEFAULTS."""
assert jit_mode in [
"script",
"inductor",
"eager",
], f"Invalid jit_mode: {jit_mode}. Expected 'script', 'inductor', or 'eager'."
_OPT_DEFAULTS["jit_mode"] = jit_mode
def set_optimization_defaults(**kwargs) -> None:
    r"""Globally set the default optimization settings.

    Parameters
    ----------
    **kwargs
        Keyword arguments to set the default optimization settings. Every key
        must already be an entry of the defaults dictionary.

    Raises
    ------
    ValueError
        If a key is not a known optimization option. All keys are checked
        before any setting is written, so such a call changes nothing.

    Notes
    -----
    An explicit ``jit_mode`` is validated (via ``_validate_and_set_jit_mode``)
    before any other option is applied. Setting the legacy ``jit_script_fx``
    flag also remaps ``jit_mode`` through ``_handle_jit_script_fx_legacy`` so
    old code keeps working — unless ``jit_mode`` is passed in the same call,
    in which case the explicit ``jit_mode`` wins regardless of keyword order.
    """
    # Reject unknown keys up front so a bad call never leaves a partial update.
    unknown = [k for k in kwargs if k not in _OPT_DEFAULTS]
    if unknown:
        raise ValueError(f"Unknown optimization option: {unknown[0]}")

    # Apply the explicit jit_mode first: an invalid value is rejected before
    # any other option changes.
    explicit_jit_mode = "jit_mode" in kwargs
    if explicit_jit_mode:
        _validate_and_set_jit_mode(kwargs["jit_mode"])

    for k, v in kwargs.items():
        if k == "jit_mode":
            continue  # already handled above
        if k == "jit_script_fx" and not explicit_jit_mode:
            # Legacy flag: map it onto jit_mode, but never override a jit_mode
            # given explicitly in the same call.
            new_jit_mode = _handle_jit_script_fx_legacy(v, _OPT_DEFAULTS["jit_mode"])
            _validate_and_set_jit_mode(new_jit_mode)
        _OPT_DEFAULTS[k] = v
def get_optimization_defaults() -> Dict[str, Union[bool, str]]:
    r"""Get the global default optimization settings.

    Returns
    -------
    dict
        A shallow copy of the current defaults; mutating it does not change
        the global settings. ``jit_mode`` maps to a string. The other options
        start out as booleans, but ``set_optimization_defaults`` stores
        whatever value it is given.
    """
    return dict(_OPT_DEFAULTS)
|