# mypy: ignore-errors

import torch
from copy import deepcopy
from torch.utils._pytree import tree_map
import torch.utils._pytree as pytree

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor

# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now, as it is not clear how to
        # safely access them on a Tensor (if at all possible).
        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API.
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the Python side is wrong as it would not be properly
        # reflected on the C++ side.
        # This doesn't catch attributes set in __init__.
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such a change would "
                                   "not be reflected to C++ callers.")

class WrapperTensorWithCustomSizes(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, t, requires_grad=False):
        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "sizes"}

    def __init__(self, t, requires_grad=False):
        self.t = t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.t if isinstance(e, WrapperTensorWithCustomSizes) else e

        def wrap(e):
            return WrapperTensorWithCustomSizes(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"t={self.t}")

class WrapperTensorWithCustomStrides(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, t, requires_grad=False):
        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "strides"}

    def __init__(self, t, requires_grad=False):
        self.t = t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.t if isinstance(e, WrapperTensorWithCustomStrides) else e

        def wrap(e):
            return WrapperTensorWithCustomStrides(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"t={self.t}")
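
# Illustrative sketch (assumed usage, not referenced elsewhere in this file): the two
# wrapper classes above route size / stride queries through their dispatch policy and
# re-wrap op results into the same subclass.
def _demo_custom_sizes_strides():
    base = torch.randn(2, 3)
    ws = WrapperTensorWithCustomSizes(base)
    wt = WrapperTensorWithCustomStrides(base)
    # These calls are served through __torch_dispatch__ because of the
    # "sizes" / "strides" dispatch policies set in get_wrapper_properties.
    ws.size()
    wt.stride()
    out = ws + ws  # unwrapped, computed on the plain inner tensor, then re-wrapped
    assert isinstance(out, WrapperTensorWithCustomSizes)
    return out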

class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # If there is a dedicated handler for this op, use it:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **(kwargs or {}))
        else:
            # Note that here, because we don't need to provide the autograd formulas,
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.
            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")
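
# Illustrative sketch (assumed usage): DiagTensorBelow stores only the diagonal but
# advertises a square size; the default fallback densifies the diagonal, runs the op,
# and re-wraps results whose non-zero entries all lie on the diagonal.
def _demo_diag_tensor_below():
    d = DiagTensorBelow(torch.tensor([1.0, 2.0, 3.0]))
    assert d.shape == (3, 3)
    doubled = d + d  # densified via diag.diag(), added, then re-wrapped
    assert isinstance(doubled, DiagTensorBelow)
    return doubled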

class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses them to compute the values.
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)
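
# Illustrative sketch (assumed usage): round-trip a dense tensor through the COO-style
# storage, then run an op via the dense fallback, which re-sparsifies the result.
def _demo_sparse_tensor():
    dense = torch.randn(4, 4).relu()
    sp = SparseTensor.from_dense(dense)
    assert torch.equal(sp.sparse_to_dense(), dense)
    doubled = sp * 2  # densified, multiplied, then re-wrapped via from_dense
    assert isinstance(doubled, SparseTensor)
    return doubled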

# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))
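
# Illustrative sketch (assumed usage): the extra state records the name of the last
# torch function that produced a result, and deepcopy carries the state over unchanged
# thanks to the special case above.
def _demo_non_wrapper_tensor():
    t = NonWrapperTensor(torch.randn(3))
    out = torch.add(t, 2)
    assert out.extra_state['last_func_called'] == 'add'
    copied = deepcopy(out)
    assert copied.extra_state == out.extra_state
    return copied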

# Class used to store info about subclass tensors used in testing.
class SubclassInfo:
    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


# Helper function to create an instance of the given subclass and possibly cache sizes / strides.
def _create_and_access_shape(cls, shape):
    sub = cls(torch.randn(shape))
    # NB: Wrapper subclasses with custom dispatched sizes / strides cache this info
    # on the first call via non-serializable PyCapsules. We purposefully trigger cache
    # population here for serialization / deepcopy tests to verify that the presence of this
    # cache info doesn't cause problems.
    sub.size()
    sub.stride()
    return sub

subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=torch.randn
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
    WrapperTensorWithCustomSizes: SubclassInfo(
        'wrapper_with_custom_sizes',
        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomSizes, shape),
        closed_under_ops=False,
    ),
    WrapperTensorWithCustomStrides: SubclassInfo(
        'wrapper_with_custom_strides',
        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomStrides, shape),
        closed_under_ops=False,
    ),
}
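
# Illustrative sketch (assumed usage): tests are expected to create instances through
# SubclassInfo.create_fn and consult closed_under_ops to decide whether op outputs
# should still be instances of the same subclass. A 1-D shape is used here because
# DiagTensorBelow only accepts a 1-D diagonal.
def _demo_subclass_db():
    created = {}
    for cls, info in subclass_db.items():
        inst = info.create_fn((3,))
        assert isinstance(inst, cls)
        created[info.name] = inst
    return created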

class SubclassWithTensorFactory(torch.Tensor):
    @staticmethod
    def __new__(cls, src):
        shape = src.shape
        kwargs = {}
        kwargs["strides"] = src.stride()
        kwargs["storage_offset"] = src.storage_offset()
        kwargs["device"] = src.device
        kwargs["layout"] = src.layout
        kwargs["requires_grad"] = src.requires_grad
        kwargs["dtype"] = src.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, src):
        self.src = src

    def __repr__(self):
        return f"{self.__class__.__name__}"

    def __tensor_flatten__(self):
        return ["src"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        src = inner_tensors["src"]
        return cls(src)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}

        def _fn(x):
            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src

        _args = pytree.tree_map_only(cls, _fn, args)
        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)

        _out = func(*_args, **_kwargs)

        _out_flat, _out_spec = pytree.tree_flatten(_out)
        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
        return pytree.tree_unflatten(out_flat, _out_spec)
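
# Illustrative sketch (assumed usage): SubclassWithTensorFactory participates in the
# traceable wrapper-subclass protocol, so it can be flattened to its inner tensor and
# rebuilt, and ops on it re-wrap their outputs into the subclass.
def _demo_subclass_with_tensor_factory():
    sub = SubclassWithTensorFactory(torch.randn(2, 2))
    attrs, meta = sub.__tensor_flatten__()
    rebuilt = SubclassWithTensorFactory.__tensor_unflatten__(
        {name: getattr(sub, name) for name in attrs}, meta, sub.size(), sub.stride()
    )
    assert isinstance(rebuilt, SubclassWithTensorFactory)
    out = sub + sub  # inner tensors go through the factory-based _fn, then get re-wrapped
    assert isinstance(out, SubclassWithTensorFactory)
    return out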