1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
from typing import Any, Dict, Set
NodeMetadataValue = Any
PROTECTED_KEYS: Set[str] = {
"val",
"stack_trace",
"nn_module_stack",
"debug_handle",
"tensor_meta",
}
class NodeMetadata:
def __init__(self, data: Dict[str, Any]) -> None:
self.data: Dict[str, Any] = data.copy()
def __getitem__(self, key: str) -> NodeMetadataValue:
return self.data[key]
def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
if key in PROTECTED_KEYS:
raise RuntimeError(f"Could not override node key: {key}")
self.data[key] = value
def __contains__(self, key: str) -> bool:
return key in self.data
def copy(self) -> "NodeMetadata":
return NodeMetadata(self.data.copy())
|