1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
# mypy: allow-untyped-defs
import functools
import re

import torch
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
# It is not a good idea to directly apply hipify_torch to codegen, because it would be vulnerable to cases like:
#   "...
#    from ..codecache import CudaKernelParamCache
#   ..."
# In such cases, we do not want to hipify the original class/file names used in codegen/codecache.
def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
if torch.version.hip is None and not force_hipify:
return source_codes
def c2_repl(m):
return PYTORCH_MAP[m.group(0)]
# We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
# it will apply positive lookbehind (?<=\W) to the pattern to avoid matching
# keyword at the beginning of code line. However, this can happen in codegen,
# which will cause the pattern to not match.
# Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example
# we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
return source_codes
|