1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
|
import getpass
import os
import re
import tempfile
# Factoring out to file without torch dependencies
def cache_dir() -> str:
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
if cache_dir is None:
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def default_cache_dir() -> str:
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
return os.path.join(
tempfile.gettempdir(),
"torchinductor_" + sanitized_username,
)
def triton_cache_dir(device: int) -> str:
if (directory := os.getenv("TRITON_CACHE_DIR")) is not None:
return directory
return os.path.join(
cache_dir(),
"triton",
str(device),
)
|