# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
import torch
from torch.overrides import get_default_nowrap_functions
__all__ = [
"MaskedTensor",
"is_masked_tensor",
]
def is_masked_tensor(a):
r""" Returns True if the input is a MaskedTensor, else False
Args:
a: any input
Examples:
>>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt)
True
"""
return isinstance(a, MaskedTensor)
def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
if is_masked_tensor(a) or is_masked_tensor(b):
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
if a.layout != b.layout:
raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}")
if a.dtype != b.dtype:
b = b.type(a.dtype)
if a.layout == b.layout == torch.sparse_coo:
return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
a.indices(), b.indices(), exact
)
elif a.layout == b.layout == torch.sparse_csr:
return (
_tensors_match(a.crow_indices(), b.crow_indices(), exact)
and _tensors_match(a.col_indices(), b.col_indices(), exact)
and _tensors_match(a.values(), b.values(), exact)
)
if exact:
return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
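# A hedged usage sketch of the helper above (not part of the public API): for dense
# tensors an exact comparison reduces to a dimensionality check plus torch.eq, while
# exact=False falls back to torch.allclose with the given tolerances, e.g.
#   _tensors_match(torch.tensor([1., 2.]), torch.tensor([1., 2.]))                # True
#   _tensors_match(torch.tensor([1., 2.]), torch.tensor([1., 2.1]), exact=False)  # False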
def _masks_match(a, b):
if is_masked_tensor(a) and is_masked_tensor(b):
mask_a = a.get_mask()
mask_b = b.get_mask()
return _tensors_match(mask_a, mask_b, exact=True)
return True
def _map_mt_args_kwargs(args, kwargs, map_fn):
def _helper(a, map_fn):
if is_masked_tensor(a):
return map_fn(a)
elif torch.is_tensor(a):
return a
elif isinstance(a, list):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return a_impl
elif isinstance(a, tuple):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return tuple(a_impl)
else:
return a
if kwargs is None:
kwargs = {}
impl_args = []
for a in args:
impl_args.append(_helper(a, map_fn))
impl_kwargs = {}
for k, v in kwargs.items():
        impl_kwargs[k] = _helper(v, map_fn)
return impl_args, impl_kwargs
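# Hedged usage sketch: the dispatch handlers registered elsewhere typically call this
# helper twice, once to strip every MaskedTensor argument down to its data and once
# for its mask, leaving non-MaskedTensor arguments untouched, e.g.
#   data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda mt: mt.get_data())
#   mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda mt: mt.get_mask())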
def _wrap_result(result_data, result_mask):
if isinstance(result_data, list):
return list(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
if isinstance(result_data, tuple):
return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
if torch.is_tensor(result_data):
return MaskedTensor(result_data, result_mask)
# Expect result_data and result_mask to be Tensors only
return NotImplemented
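# Hedged sketch: given the result of an op applied to the underlying data plus the
# corresponding result mask, this re-wraps plain Tensors (and lists/tuples of them)
# into MaskedTensors; d0/d1 and m0/m1 below are placeholder data/mask tensors, e.g.
#   _wrap_result(torch.abs(data), mask)      # -> MaskedTensor
#   _wrap_result((d0, d1), (m0, m1))         # -> tuple of MaskedTensors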
def _masked_tensor_str(data, mask, formatter):
if data.layout in {torch.sparse_coo, torch.sparse_csr}:
data = data.to_dense()
mask = mask.to_dense()
if data.dim() == 1:
formatted_elements = [
formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
for d in data
]
max_len = max(
map(lambda x: 8 if x[1] else len(x[0]), zip(formatted_elements, ~mask))
)
return (
"["
+ ", ".join(
[
"--".rjust(max_len) if m else e
for (e, m) in zip(formatted_elements, ~mask)
]
)
+ "]"
)
sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
sub_strings = ["\n".join([" " + si for si in s.split("\n")]) for s in sub_strings]
return "[\n" + ",\n".join(sub_strings) + "\n]"
def _get_data(a):
if is_masked_tensor(a):
return a._masked_data
return a
def _maybe_get_mask(a):
if is_masked_tensor(a):
return a.get_mask()
return None
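# Hedged note: these two accessors let the dispatch code peel an argument apart without
# assuming it is a MaskedTensor, e.g.
#   _get_data(mt), _maybe_get_mask(mt)          # -> (underlying data, bool mask)
#   _get_data(plain), _maybe_get_mask(plain)    # -> (plain, None)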
class MaskedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, mask, requires_grad=False):
if is_masked_tensor(data) or not torch.is_tensor(data):
raise TypeError("data must be a Tensor")
if is_masked_tensor(mask) or not torch.is_tensor(mask):
raise TypeError("mask must be a Tensor")
        # Use a Tensor of the given size for the wrapper.
kwargs = {}
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = requires_grad
kwargs["dispatch_sizes_strides_policy"] = "strides"
kwargs["dispatch_layout"] = True
warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.masked "
"module for further information about the project."), UserWarning)
if data.requires_grad:
warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()", UserWarning)
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
def _preprocess_data(self, data, mask):
from .._ops import _sparse_coo_where, _sparse_csr_where
if data.layout != mask.layout:
raise TypeError("data and mask must have the same layout.")
if data.layout == torch.sparse_coo:
data = data.coalesce()
mask = mask.coalesce()
if data._nnz() != mask._nnz():
data = _sparse_coo_where(mask, data, torch.tensor(0))
elif data.layout == torch.sparse_csr:
if data._nnz() != mask._nnz():
data = _sparse_csr_where(mask, data, torch.tensor(0))
# Have to pick awkward names to not conflict with existing fields such as data
self._masked_data = data.clone()
self._masked_mask = mask.clone()
def _validate_members(self):
data = self._masked_data
mask = self.get_mask()
if type(data) != type(mask):
raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}")
if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
raise TypeError(f"data layout of {data.layout} is not supported.")
if data.layout == torch.sparse_coo:
if not _tensors_match(data.indices(), mask.indices(), exact=True):
raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.")
elif data.layout == torch.sparse_csr:
if not _tensors_match(
data.crow_indices(), mask.crow_indices(), exact=True
) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.")
if mask.dtype != torch.bool:
raise TypeError("mask must have dtype bool.")
        if data.dtype not in {
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        }:
raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
if data.dim() != mask.dim():
raise ValueError("data.dim() must equal mask.dim()")
if data.size() != mask.size():
raise ValueError("data.size() must equal mask.size()")
def __init__(self, data, mask, requires_grad=False):
self._preprocess_data(data, mask)
self._validate_members()
@staticmethod
def _from_values(data, mask):
""" Differentiable constructor for MaskedTensor """
class Constructor(torch.autograd.Function):
@staticmethod
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
result = Constructor.apply(data, mask)
return result
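    # Hedged usage sketch: unlike the plain constructor, this path is differentiable
    # with respect to ``data``, e.g.
    #   data = torch.randn(2, 3, requires_grad=True)
    #   mt = MaskedTensor._from_values(data, data > 0)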
def _set_data_mask(self, data, mask):
self._masked_data = data
self._masked_mask = mask
self._validate_members()
def __repr__(self):
formatter = "{0:8.4f}"
if self.dim() == 0:
scalar_data = self.get_data().item()
data_formatted = (
formatter.format(scalar_data)
if isinstance(scalar_data, float)
else str(scalar_data)
)
if not self.get_mask().item():
data_formatted = "--"
return (
"MaskedTensor("
+ data_formatted
+ ", "
+ str(self.get_mask().item())
+ ")"
)
s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
s = "\n".join(" " + si for si in s.split("\n"))
return "MaskedTensor(\n" + s + "\n)"
# Seems like this needs to be defined before torch_dispatch to work
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
if func in _MASKEDTENSOR_FUNCTION_TABLE:
return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
else:
return torch._tensor._convert(ret, cls)
@classmethod
def unary(cls, fn, data, mask):
return MaskedTensor(fn(data), mask)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
func = func.overloadpacket
from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
if func in _MASKEDTENSOR_DISPATCH_TABLE:
return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
msg = (
f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
"If you would like this operator to be supported, please file an issue for a feature request at "
"https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
"In the case that the semantics for the operator are not trivial, it would be appreciated "
"to also include a proposal for the semantics."
)
warnings.warn(msg)
return NotImplemented
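    # Hedged sketch of the dispatch flow: an op such as torch.abs(mt) is routed through
    # __torch_function__/__torch_dispatch__ to a handler registered in _ops_refs, which
    # for an elementwise op typically amounts to something like
    #   _wrap_result(func(_get_data(mt)), _maybe_get_mask(mt))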
def __lt__(self, other):
if is_masked_tensor(other):
return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
return MaskedTensor(self.get_data() < other, self.get_mask())
def to_tensor(self, value):
return self.get_data().masked_fill(~self.get_mask(), value)
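    # Hedged usage sketch: to_tensor materializes a regular Tensor, filling masked-out
    # slots with ``value``, e.g.
    #   mt.to_tensor(0)   # masked-out entries become 0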
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
def forward(ctx, self):
return self._masked_data
@staticmethod
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output
return MaskedTensor(grad_output, self.get_mask())
return GetData.apply(self)
def get_mask(self):
return self._masked_mask
def is_sparse_coo(self):
return self.layout == torch.sparse_coo
def is_sparse_csr(self):
return self.layout == torch.sparse_csr
# Update later to support more sparse layouts
def is_sparse(self):
return self.is_sparse_coo() or self.is_sparse_csr()
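# Hedged end-to-end sketch of the prototype API defined above:
#   data = torch.arange(6, dtype=torch.float).reshape(2, 3)
#   mask = torch.tensor([[True, False, False], [True, True, False]])
#   mt = MaskedTensor(data, mask)
#   dense = mt.to_tensor(0.)    # masked-out entries filled with 0.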