1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
import hashlib
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Callable, Optional, Tuple, Union
from typing_extensions import TypeAlias
import torch.fx.graph
class CustomGraphPass(ABC):
"""
Implement this interface for custom Graph passes:
1) The __call__() method contains the implementation of the custom pass.
2) The uuid() method enables inductor to cache compiled graphs when your custom
passes are applied. This method can return any identifier as long as it uniquely
identifies your implementation (and can be pickled). The caching logic includes this
identifier in its key calculation, i.e., any new value will effectively invalidate
existing entries. We expect custom passes would typically depend purely on the
textual reprensentation of the implementation. In that case, we recommend using the
'get_hash_for_files' helper below to compute a unique hash from the contents of a
static list of source files, i.e., the source(s) containing the custom pass
implementation. That approach ensures that any change to the implementation will
mean a new uuid.
** IMPORTANT ** If your custom pass's behavior depends on some external state, then
you'll need to implement something more complicated (or disable caching).
EXAMPLE:
class MyCustomGraphPass(CustomGraphPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
# my custom graph optimization pass
# ...
def uuid(self) -> Optional[Any]:
return get_hash_for_files((__file__,))
"""
@abstractmethod
def __call__(self, graph: torch.fx.graph.Graph) -> None:
"""
Implementation of the custom pass.
"""
@abstractmethod
def uuid(self) -> Optional[Any]:
"""
Return an ID to uniquely identify your custom pass implementation. Return None
to skip inductor code caching entirely.
"""
CustomGraphPassType: TypeAlias = Optional[
Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
]
@lru_cache(1)
def get_hash_for_files(paths: Tuple[str], extra: str = "") -> bytes:
"""
Helper to compute a unique string by hashing the contents of a list of files.
"""
hasher = hashlib.sha256()
hasher.update(extra.encode("utf-8"))
for path in paths:
with open(path, "rb") as f:
hasher.update(path.encode("utf-8"))
hasher.update(f.read())
return hasher.digest()
|