1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
|
# mypy: ignore-errors
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that wraps an inner tensor and carries custom metadata plus a custom method.
class ConstantExtraMetadataTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
shape = elem.shape
kwargs = {}
kwargs["strides"] = elem.stride()
kwargs["storage_offset"] = elem.storage_offset()
kwargs["device"] = elem.device
kwargs["layout"] = elem.layout
kwargs["requires_grad"] = elem.requires_grad
kwargs["dtype"] = elem.dtype
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, elem):
self.elem = elem
self.constant_attribute = 4
def __repr__(self):
inner_repr = repr(self.elem)
return f"CustomTensor({inner_repr})"
def __tensor_flatten__(self):
return ["elem"], self.constant_attribute
def add_constant(self, a):
self.constant_attribute += a
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is not None
elem = inner_tensors["elem"]
out = ConstantExtraMetadataTensor(elem)
out.constant_attribute = meta
return out
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, args
)
kwargs_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
)
out_inner = func(*args_inner, **kwargs_inner)
out_inner_flat, spec = pytree.tree_flatten(out_inner)
# for aten ops that return non-tensors, just assume that
# our cust inner tensors return the same value
out_flat = [
ConstantExtraMetadataTensor(o_inner)
if isinstance(o_inner, torch.Tensor)
else o_inner
for o_inner in out_inner_flat
]
out = pytree.tree_unflatten(out_flat, spec)
return return_and_correct_aliasing(func, args, kwargs, out)
|