1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
|
from __future__ import annotations
import atexit
import collections
import dataclasses
import functools
import json
import logging
import os
import sys
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing_extensions import override, TypeAlias
from torch._dynamo.utils import dynamo_timed
from torch._inductor import config
from torch.monitor import _WaitCounter
try:
import redis
except ImportError:
redis = None # type: ignore[assignment]
log = logging.getLogger(__name__)

# Type of the logging sample handed to the RemoteCache `_create_sample` /
# `_log_sample` hooks. The scuba `Sample` class is imported only in fbcode;
# elsewhere this is just a placeholder alias.
if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]

# Type variables shared by the generic cache classes below. In RemoteCache,
# _T is the structured value type and _U the raw type stored by the backend.
_T = TypeVar("_T")
_U = TypeVar("_U")
# Pre-bound `dynamo_timed` calls for the remote FX graph cache get/put paths.
# The two differ only in event name, phase name and compile-metrics column;
# both log a wait counter and are not logged as PT2 compile events.
remote_fx_cache_get_timed = functools.partial(
    dynamo_timed,
    "FbRemoteFxGraphCache.get",
    log_waitcounter=True,
    log_pt2_compile_event=False,
    phase_name="remote_fx_graph_cache_get",
    dynamo_compile_column_us="remote_fx_graph_cache_get_time_us",
)
remote_fx_cache_put_timed = functools.partial(
    dynamo_timed,
    "FbRemoteFxGraphCache.put",
    log_waitcounter=True,
    log_pt2_compile_event=False,
    phase_name="remote_fx_graph_cache_put",
    dynamo_compile_column_us="remote_fx_graph_cache_put_time_us",
)
class RemoteCacheBackend(Generic[_T]):
    """
    A backend implementation for accessing a remote/distributed cache. It
    stores raw (already serialized) data, usually bytes; for structured data
    use a RemoteCache.

    Subclasses implement ``_get``/``_put``. The public ``get``/``put`` wrap
    them and record hits, misses, puts and exceptions in ``cache_stats``
    under the name ``backend:<ClassName>``.
    """

    def __init__(self) -> None:
        self._name = "backend:" + type(self).__name__

    @abstractmethod
    def _get(self, key: str) -> Optional[_T]:
        """Fetch the raw value for ``key``; ``None`` means it is not present."""

    @abstractmethod
    def _put(self, key: str, data: _T) -> None:
        """Store the raw ``data`` under ``key``."""

    def get(self, key: str) -> Optional[_T]:
        """Return the raw value for ``key`` (or ``None``), counting hit/miss.

        Any exception from the lookup is counted and re-raised.
        """
        try:
            found = self._get(key)
            cache_stats.get(self._name, found)
        except Exception:
            cache_stats.exception(self._name)
            raise
        else:
            return found

    def put(self, key: str, data: _T) -> None:
        """Store ``data`` under ``key``, counting the put.

        Any exception from the store is counted and re-raised.
        """
        try:
            self._put(key, data)
            cache_stats.put(self._name)
        except Exception:
            cache_stats.exception(self._name)
            raise
class RemoteCacheSerde(Generic[_T, _U]):
    """Serde that encodes from _T to _U and decodes from _U to _T."""

    @abstractmethod
    def encode(self, data: _T) -> _U:
        """Convert a _T value into its _U representation."""

    @abstractmethod
    def decode(self, data: _U) -> _T:
        """Convert a _U representation back into a _T value."""
# Any JSON-compatible value: None, a scalar, a dict with string keys, or a
# list, where dict values and list items are themselves JsonDataTy. The
# recursive references are written as string forward references.
JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]
class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde between JSON-compatible values and ASCII-encoded JSON bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        # json.dumps escapes non-ASCII characters by default, so the text is
        # always ASCII-encodable.
        text = json.dumps(data)
        return text.encode("ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        # json.loads accepts bytes directly.
        return json.loads(data)
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Identity serde: values go to and come from the backend unchanged."""

    def encode(self, data: _T) -> _T:
        """Return ``data`` as-is."""
        return data

    def decode(self, data: _T) -> _T:
        """Return ``data`` as-is."""
        return data
# This class is the top of a RemoteCache. A RemoteCache is fundamentally made of
# three parts:
#
# 1. The controller (this class).
# 2. A serializer/deserializer (instance of RemoteCacheSerde).
# 3. A backend (instance of RemoteCacheBackend).
#
# To write (`put`), the RemoteCache takes data, uses the RemoteCacheSerde to
# convert it for the backend and passes it to the backend.
#
# Conversely, when reading (`get`), the RemoteCache takes data from the backend,
# uses the RemoteCacheSerde to convert it and returns it.
#
# The RemoteCacheBackend is generic on _U - which is the type of data the
# backend can directly cache (usually `bytes`).
#
# The RemoteCacheSerde is responsible for converting between _T (the type of
# data the RemoteCache accepts in `put` and returns in `get`) and _U.
#
# Don't instantiate RemoteCache directly; create a subclass and instantiate
# that instead. When logging cache use (`TORCH_LOGS=cache`) the concrete type
# of the RemoteCache is used as the name of the reported cache. See
# RemoteFxGraphCache below as an example.
class RemoteCache(Generic[_T]):
    """Controller of a remote cache, combining a serde with a backend.

    ``put`` encodes a structured ``_T`` value with the serde and hands the
    result to the backend. ``get`` fetches from the backend and decodes it.
    Hits, misses, puts and exceptions are counted in ``cache_stats`` under the
    concrete subclass name.
    """

    # Support for testing: when set on a class, ``__init__`` ignores the
    # ``backend`` argument and uses ``backend_override_cls()`` instead.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing to mock out the backend on a class-by-class basis.
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """Return the value cached under ``key``, or ``None`` if not present.

        Runs under the ``pytorch.remote_cache.get`` wait counter. Exceptions
        are counted in ``cache_stats`` and re-raised. In that case the sample
        is not logged.
        """
        with _WaitCounter("pytorch.remote_cache.get").guard():
            sample = self._create_sample()
            try:
                result = self._get(key, sample)
                cache_stats.get(type(self).__name__, result)
            except Exception:
                cache_stats.exception(type(self).__name__)
                raise
            self._log_sample(sample)
            return result

    def put(self, key: str, value: _T) -> None:
        """Add ``value`` to the cache under ``key``.

        ``None`` is not a valid value even if _T supports it, because it could
        not be told apart from a missing entry on ``get``. Exceptions are
        counted in ``cache_stats`` and re-raised. In that case the sample is
        not logged.
        """
        with _WaitCounter("pytorch.remote_cache.put").guard():
            assert value is not None
            sample = self._create_sample()
            try:
                self._put(key, value, sample)
                cache_stats.put(type(self).__name__)
            except Exception:
                cache_stats.exception(type(self).__name__)
                raise
            self._log_sample(sample)

    # Used to convert data from the cache into structured data.
    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:  # type: ignore[override]
        return self.serde.decode(data)  # type: ignore[arg-type]

    # Used to convert structured data into data for the cache.
    def _encode(self, value: _T, sample: Optional[Sample]) -> object:  # returns _U
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        """Get structured data from the cache (overridable hook for ``get``)."""
        data = self._backend_get(key)
        # `None` is the only "missing" sentinel (see `put`). Compare against it
        # explicitly so that falsy-but-present payloads (b"", {}, 0, ... with a
        # passthrough serde) are decoded rather than reported as misses. Such
        # payloads have already been counted as hits by the backend.
        if data is None:
            return None
        return self._decode(data, sample)

    # Get unstructured data from the cache.
    # Separate from `get` so that it can be overridden.
    # Returns _U - but we aren't actually generic on _U
    def _backend_get(self, key: str) -> object:
        return self.backend.get(key)

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        """Put structured data into the cache (overridable hook for ``put``)."""
        data = self._encode(value, sample)
        self._backend_put(key, data)

    # Put unstructured data into the cache.
    # Separate from `put` so that it can be overridden.
    # Takes data: _U - but we aren't actually generic on _U
    def _backend_put(self, key: str, data: object) -> None:
        self.backend.put(key, data)

    # Create a logging Sample - used with internal loggers to monitor cache
    # effectiveness. The base implementation creates none.
    def _create_sample(self) -> Optional[Sample]:
        return None

    # Write the logging Sample to the logger. The base implementation is a
    # no-op.
    def _log_sample(self, sample: Optional[Sample]) -> None:
        pass
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.

    The host and port come from TORCHINDUCTOR_REDIS_HOST (default
    "localhost") and TORCHINDUCTOR_REDIS_PORT (default 6379). If redis
    cannot be imported, or a connection error occurs, the backend disables
    itself and gets return None while puts are dropped.
    """

    # The client, or None when redis is unavailable or has been disabled.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        super().__init__()
        # `cache_id` is not used by this backend.
        if redis:
            self._redis = redis.Redis(
                host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
                port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
            )
        # Otherwise redis failed to import: `_redis` stays None.

    @override
    def _get(self, key: str) -> Optional[bytes]:
        client = self._redis
        if not client:
            # Redis was not found, or an earlier call ran into trouble.
            return None
        try:
            value = client.get(key)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't attempt to connect until first use,
            # so the failure surfaces here. Mark it as unavailable from now on.
            self._redis = None
            return None
        # In theory redis.get() can also return an Awaitable; only bytes or
        # None are expected here.
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def _put(self, key: str, data: bytes) -> None:
        client = self._redis
        if not client:
            # Redis was not found, or an earlier call ran into trouble.
            return
        try:
            client.set(key, data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't attempt to connect until first use,
            # so the failure surfaces here. Mark it as unavailable from now on.
            self._redis = None
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """RemoteCache storing JSON-compatible values in Redis.

    Keys are namespaced as ``pt2:<cache_id>::<key>:c<version>``.
    """

    def __init__(self, cache_id: str) -> None:
        backend: RemoteCacheBackend[bytes]
        if self.__class__.backend_override_cls:
            # Special test handling: the backend is going to be overridden
            # anyway, so don't require redis. This placeholder is bogus, but
            # RemoteCache.__init__ replaces it with the override.
            backend = typing.cast(RemoteCacheBackend[bytes], None)
        else:
            backend = RedisRemoteCacheBackend(cache_id)
        super().__init__(backend, RemoteCacheJsonSerde())
        # Consistency between various types of keys; changing `version`
        # changes every generated key.
        version = 1
        self._key_fmt = f"pt2:{cache_id}::{{key}}:c{version}"

    def _get_key(self, key: str) -> str:
        """Return the namespaced Redis key for ``key``."""
        return self._key_fmt.format(key=key)

    @override
    def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]:
        return super()._get(self._get_key(key), sample)

    @override
    def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None:
        super()._put(self._get_key(key), value, sample)
class RemoteAutotuneCache(RedisRemoteCache):
    """Redis-backed remote cache (by name, for autotune results).

    Adds no behavior. It is a distinct class so that cache stats report its
    usage under its own type name.
    """
class RemoteBundledAutotuneCache(RedisRemoteCache):
    """Redis-backed remote cache (by name, for bundled autotune results).

    Adds no behavior. It is a distinct class so that cache stats report its
    usage under its own type name.
    """
class RemoteFxGraphCache(RedisRemoteCache):
    """Redis-backed remote cache (by name, for FX graphs).

    Adds no behavior. It is a distinct class so that cache stats report its
    usage under its own type name.
    """
class RemoteAOTAutogradCache(RedisRemoteCache):
    """Redis-backed remote cache (by name, for AOTAutograd artifacts).

    Adds no behavior. It is a distinct class so that cache stats report its
    usage under its own type name.
    """
class RemoteDynamoPGOCache(RedisRemoteCache):
    """Redis-backed remote cache (by name, for Dynamo PGO data).

    Adds no behavior. It is a distinct class so that cache stats report its
    usage under its own type name.
    """
def create_cache(
    key: str,
    is_fbcode: bool,
    fb_cache_cls: str,
    oss_cache_cls: str,
) -> Optional[RemoteCache[JsonDataTy]]:
    """Instantiate a remote cache class by name, passing it ``key``.

    When ``is_fbcode`` is true, the class named ``fb_cache_cls`` is looked up
    in ``torch._inductor.fb.remote_cache``. Otherwise the class named
    ``oss_cache_cls`` is looked up in this module. Any failure (import,
    lookup or construction) is logged as a warning and ``None`` is returned.
    """
    try:
        if is_fbcode:
            import torch._inductor.fb.remote_cache

            owner, cls_name = torch._inductor.fb.remote_cache, fb_cache_cls
        else:
            owner, cls_name = sys.modules[__name__], oss_cache_cls
        cache_cls = getattr(owner, cls_name)
        return cache_cls(key)
    except Exception:
        log.warning("Unable to create a remote cache", exc_info=True)
        return None
# Some simple stat capture
@dataclasses.dataclass
class _CacheStat:
    """Usage counters for a single cache name."""

    miss: int = 0
    hit: int = 0
    put: int = 0
    exception: int = 0

    def __str__(self) -> str:
        # Rendered as "{hit: H, miss: M, put: P, exception: E}".
        body = ", ".join(
            f"{label}: {getattr(self, label)}"
            for label in ("hit", "miss", "put", "exception")
        )
        return "{" + body + "}"
class _CacheStats:
    """Hit/miss/put/exception counters keyed by cache name."""

    _stats: Dict[str, _CacheStat]

    def __init__(self) -> None:
        # Names seen for the first time start from a zeroed _CacheStat.
        self._stats = collections.defaultdict(_CacheStat)

    def miss(self, name: str, count: int = 1) -> None:
        entry = self._stats[name]
        entry.miss += count

    def hit(self, name: str, count: int = 1) -> None:
        entry = self._stats[name]
        entry.hit += count

    def get(self, name: str, value: Optional[object]) -> None:
        """Record a lookup: ``None`` counts as a miss, anything else a hit."""
        if value is not None:
            self.hit(name)
        else:
            self.miss(name)

    def put(self, name: str, count: int = 1) -> None:
        entry = self._stats[name]
        entry.put += count

    def exception(self, name: str, count: int = 1) -> None:
        entry = self._stats[name]
        entry.exception += count
# Module-wide usage counters, updated by RemoteCacheBackend.get/put and
# RemoteCache.get/put, and logged at exit by dump_cache_stats when INFO
# logging is enabled.
cache_stats = _CacheStats()
@atexit.register
def dump_cache_stats() -> None:
    """At interpreter exit, log the accumulated cache metrics at INFO level.

    Does nothing unless this module's logger is enabled for INFO. Names are
    listed in sorted order, one per line; " None" is logged if nothing was
    recorded.
    """
    if log.isEnabledFor(logging.INFO):
        import io

        report = io.StringIO()
        if cache_stats._stats:
            report.write("\n")
            for stat_name, stat in sorted(cache_stats._stats.items()):
                print(f" {stat_name}: {stat}", file=report)
        else:
            print(" None", file=report)
        log.info("Cache Metrics:%s", report.getvalue())
|