# mypy: allow-untyped-defs
import contextlib
import torch

__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
    "semi_sparse_scaled_mm",
]


@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    with no_dispatch():
        return func(*args)
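
# Illustrative sketch (an assumption, not this module's actual wiring): handlers
# like the ones below are meant to be registered in a dispatch table keyed by
# aten op and consulted from a tensor subclass's __torch_dispatch__, e.g.:
#
#   DISPATCH_TABLE = {
#       torch.ops.aten.values: semi_sparse_values,
#       torch.ops.aten.t: semi_sparse_t,
#       torch.ops.aten.mm: semi_sparse_mm,
#       torch.ops.aten.detach: semi_sparse_detach,
#   }
#
# `fallback_dispatcher` simply re-runs an op with torch dispatch disabled, which
# is useful for ops that can operate directly on the underlying storage.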


def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements].view(m, -1)
    else:
        return A.packed.detach()
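
# When `A.meta` is None, `A.packed` is a single fused buffer holding the kept
# values followed by the sparsity metadata, so the values are the first
# m * k // 2 elements (2:4 sparsity keeps exactly half of every group of four).
# When `A.meta` is set, `A.packed` already contains only the values. This likely
# corresponds to a fused-buffer backend (e.g. cuSPARSELt) versus one that stores
# metadata separately (e.g. CUTLASS), but the code above only relies on whether
# `meta` is populated.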


def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta
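
# Worked example (illustrative): a 128x128 2:4-sparse tensor keeps
# 128 * 128 // 2 == 8192 values, so everything in the fused `packed` buffer
# past that offset is sparsity metadata, reinterpreted with the index dtype
# selected above.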


def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we cannot currently go from the compressed representation back to
    # the dense representation, we just keep track of how many times we have
    # been transposed. Depending on whether the sparse matrix is the first or
    # second argument, we expect an even / odd number of calls to transpose
    # respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=(
            self.compressed_swizzled_bitmask.transpose(0, 1)
            if self.compressed_swizzled_bitmask is not None
            else None
        ),
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
    )
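
# Transposing is O(1): it swaps the (packed, meta) and (packed_t, meta_t) pairs
# and flips the stored shape instead of re-compressing. A minimal sanity sketch
# (illustrative, assuming `x` is a 2-D SparseSemiStructuredTensor):
#
#   y = x.t()   # new wrapper reusing x's buffers, with shape reversed
#   z = y.t()   # swaps back; z refers to the same packed/meta buffers as x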


def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self


def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
        requires_grad=False,
    )


def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]
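
# Both orderings funnel into the sparse operand's `_mm`, which expects the
# sparse matrix on the left:
#   * sparse @ dense: pad the dense input, multiply, then slice off the padded
#     columns.
#   * dense @ sparse: rewrite A @ B as (B.t() @ A.t()).t(), so the (transposed)
#     sparse operand still sits on the left of the kernel call.
# Illustrative shapes: A dense of shape (m, k), B sparse of shape (k, n) gives
# a result of shape (m, n) once the `[:row, :]` slice removes any padding rows.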


def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, _col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]
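
# addmm computes bias + A @ B, with the sparse operand required to be B. The
# same transpose trick as in `semi_sparse_mm` keeps the sparse matrix on the
# left of the kernel call, and (assuming `_mm` adds the bias inside that fused
# call) the final transpose yields A @ B with `bias` broadcast over the rows of
# the result, matching the usual linear-layer bias semantics.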


def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)
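
# Illustrative usage (assuming the weight was converted with
# torch.sparse.to_sparse_semi_structured): F.linear(x, sparse_weight, bias)
# with x of shape (batch, seq, in_features) is flattened to
# (batch * seq, in_features), multiplied against the transposed weight, and
# reshaped back to (batch, seq, out_features).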


def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    # Pull all args, excluding the use_fast_accum flag if set.
    A, B, A_scale, B_scale, bias, scale_result, out_dtype = args[:7]
    assert A.dtype == torch.float8_e4m3fn
    assert B.dtype == torch.float8_e4m3fn
    # Only cuSPARSELt supports float8_e4m3fn currently.
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
    assert A.packed is not None
    # Currently we only support per-tensor scaling with float32 scales.
    assert A_scale.numel() == 1 and B_scale.numel() == 1
    assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
    # cuSPARSELt lacks support for scaling the A and B operands directly, so we
    # use alpha to scale the result instead. Note that this limits us to
    # per-tensor scaling only.
    sparse_result = torch._cslt_sparse_mm(
        A.packed,
        B,
        alpha=A_scale * B_scale,
        out_dtype=out_dtype,
    )
    return sparse_result
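
# Because only per-tensor float32 scales are accepted, the two scales can be
# folded into a single alpha: out = (A_scale * B_scale) * (A @ B). Worked
# example (illustrative): A_scale = 0.5 and B_scale = 2.0 give alpha = 1.0,
# i.e. the product of the dequantized operands with no extra scaling.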