1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715
|
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running GEMMs.
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS GEMMs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Gemm(A, B, C, D)
plan.run()
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
.. highlight:: python
.. code-block:: python
# The following is shorthand for:
# cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32,
# layout=cutlass.LayoutType.RowMajor)
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
A0 = torch.rand((128, 256), device='cuda')
B0 = torch.rand((256, 64), device='cuda')
C0 = torch.zeros((128, 64), device='cuda')
D0 = torch.zeros((128, 64), device.'cuda')
plan.run(A0, B0, C0, D0)
A = torch.rand((32, 128), device='cuda')
B = torch.rand((128, 256), device='cuda')
C = torch.zeros((32, 256), device='cuda')
D = torch.zeros((32, 256), device.'cuda')
plan.run(A1, B1, C1, D1)
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.compile()
# Do other work...
plan.run(A0, B0, C0, D0)
# Do other work...
plan.run(A1, B1, C1, D1)
Elementwise activation functions are easily fused to the GEMM via the interface:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
args = plan.run()
# Do other work...
args.sync()
"""
from math import prod
from cuda import cuda
from cutlass_library import (
DataType,
DataTypeSize,
GemmUniversalMode,
)
import cutlass
from cutlass import epilogue, swizzle
from cutlass.backend import compiler
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes
class Gemm(OperationBase):
"""
Constructs a ``Gemm`` object.
The data types and layouts of operands A, B, and C, along with the data type of output D
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
these are not to be changed after a ``Gemm`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. The following
constructors are equivalent:
.. highlight:: python
.. code-block:: python
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
# for operands to the same values.
Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type and layout as those passed in here).
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
Gemm(A=A, B=B, C=C, D=D)
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
# the same as that for D, at present)
Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
:param A: tensor representing data type and layout of operand A
:param B: tensor representing data type and layout of operand B
:param C: tensor representing data type and layout of operand C
:param D: tensor representing data type and layout of operand D
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:param layout_A: layout of operand A
:type layout_A: cutlass.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass.LayoutType
:param layout_D: layout of operand D
:type layout_D: cutlass.LayoutType
"""
def __init__(
self, A=None, B=None, C=None, D=None,
alpha=1.0, beta=0.0, element_accumulator=None,
element=None, layout=None,
element_A=None, element_B=None, element_C=None, element_D=None,
layout_A=None, layout_B=None, layout_C=None,
cc: int = None, kernel_cc: int = None
):
super().__init__(cc=cc, kernel_cc=kernel_cc)
self.name = "gemm"
self.compiled = False
elements = []
layouts = []
# Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
# ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
[layout_A, layout_B, layout_C, layout_C],
[A, B, C, D],
["A", "B", "C", "D"]):
if elt is not None and tens is not None:
raise Exception(f'Must not specify both element_{name} and tensor {name}')
if lay is not None and tens is not None:
raise Exception(f'Must not specify both layout_{name} and tensor {name}')
if elt is None and tens is None and element is None:
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
if lay is None and tens is None and layout is None:
raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
elt_to_set = None
lay_to_set = None
if tens is not None:
elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
else:
elt_to_set = elt if elt is not None else element
lay_to_set = lay if lay is not None else layout
elements.append(datatypes.library_type(elt_to_set))
layouts.append(lay_to_set)
self._element_a, self._element_b, self._element_c, self._element_d = elements
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
if element_accumulator is None:
self._element_accumulator = self._element_c
else:
self._element_accumulator = datatypes.library_type(element_accumulator)
self.A = A
self.B = B
self.C = C
self.D = D
self.alpha = alpha
self.beta = beta
self.epilogue_functor = None
self.op_class = None
self._tile_description = None
self._reset_operations()
self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1
def _reset_operations(self, reset_epilogue: bool = True):
# Set the default op class
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
layout_comb = (self._layout_a, self._layout_b)
self.possible_op_classes = self.options.supporting_opclasses(
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b, self._math_operation)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
else:
if self._math_operation is not None:
math_op_str = f' and math operation {self._math_operation}'
else:
math_op_str = ''
raise Exception(f'No kernel configuration found for supported data type and layout '
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
if reset_epilogue:
self._reset_epilogue_functor_activation(cutlass.epilogue.identity)
@property
def swizzling_functor(self):
"""
Returns the type of the swizzling functor currently being used by the GEMM
:return: swizzing functor type
"""
return self._swizzling_functor
@swizzling_functor.setter
def swizzling_functor(self, swizzling_functor):
"""
Sets the swizzling functor to the type specified by `swizzling_functor`
"""
if swizzling_functor == cutlass.swizzle.ThreadblockSwizzleStreamK:
if self.op_class == cutlass.OpcodeClass.Simt:
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
if self.current_cc == 90:
raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90')
self._swizzling_functor = swizzling_functor
#
# Tile description Related
#
@property
def tile_description(self) -> TileDescription:
"""
Returns the tile description
"""
return self._tile_description
@tile_description.setter
def tile_description(
self, td=None):
"""
Set the tile description
:param td: tile description
:type td: cutlass.backend.TileDescription, or a dict with keys
{
"threadblock_shape": [int, int, int],
"warp_count": [int, int, int],
"stages": int,
"instruction_shape": [int, int, int] (optional),
"cluster_shape": [int, int, int] (optional)
}
"""
if td is None:
return
if isinstance(td, dict):
if self._tile_description is None:
op = self.possible_operations.default_operation(self._math_operation)
self._tile_description = datatypes.td_from_profiler_op(op)
td = self._tile_description.clone_and_update(td)
valid, msg = self._valid_tile_description(td)
if valid:
self._tile_description = td
else:
raise Exception(msg)
def _valid_tile_description(self, td: TileDescription) -> tuple:
"""
Checks whether the provided tile description is valid for the given compute capability. At present,
this checks the following:
- Does the tile description use a number of stages supported by the compute capability in question?
- Does the tile size requested fit within shared memory?
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
more non-unit cluster dimensions for pre-SM90 architectures)?
- Is the kernel schedule being used supported on the architecture in question?
:param td: tile description to validate
:type td: cutlass.backend.TileDescription
:return: tuple in which the first element is a bool indicating that the tile description is valid
and the second element is a string providing an optional error message.
:rtype: tuple
"""
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
if not valid:
return (valid, msg)
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
if not valid:
return (valid, msg)
valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
return valid, msg
def tile_descriptions(self) -> list:
"""
Returns a list of valid tile descriptions for the operations
:returns: list of valid tile descriptions for the operations
:rtype: list
"""
tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
if self._math_operation is not None:
tds = [td for td in tds if td.tile_description.math_instruction == self._math_operation]
return tds
def construct(
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
"""
Constructs a ``cutlass.backend.GemmUniversalOperation`` based on the input parameters and current
kernel specification of the ``Gemm`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:return: operation that was constructed
:rtype: cutlass.backend.GemmOperationUniversal
"""
alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
if alignment_C is None:
alignment_C = max(self.possible_operations.alignments("C"))
if self._element_c != DataType.void:
alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
if tile_description is None:
if self._tile_description is None:
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
tile_description = datatypes.td_from_profiler_op(op)
# The selected op may have lower alignment than that determined above, so we must
# reset alignment here.
alignment_C = op.C.alignment
else:
tile_description = self._tile_description
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self._tile_description = tile_description
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
operation = GemmOperationUniversal(
arch=self.current_cc,
tile_description=tile_description,
A=tensor_A, B=tensor_B, C=tensor_C,
epilogue_functor=self.epilogue_functor,
swizzling_functor=self._swizzling_functor,
)
return operation
def compile(self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
print_module: bool = False) -> cutlass.backend.GemmOperationUniversal:
"""
Emits and compiles the kernel currently specified. If ``tile_description`` and any
of the ``alignment`` parameters are set, the kernel will be chosen using this
tile description and alignments. Otherwise, a default tile description and alignment
will be used.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:return: operation that was compiled
:rtype: cutlass.backend.GemmOperationUniversal
"""
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
if print_module:
print(self.operation.rt_module.emit())
compiler.add_module([self.operation,])
return self.operation
def _verify_rank(self, tensor):
"""
Verifies that ``tensor`` has rank greater than 1
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
"""
if len(tensor.shape) < 2:
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
def _get_batch_count(self, A, B, C, D) -> int:
"""
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
tensors match in batch size. Presence of a batch dimension is detected by one of the
tensors being rank 3. If a batch dimension is present, it must be present in one of
operands A, B, or C (but need not be in all), and must be present in D.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple of batch count dimensions
:rtype: tuple
"""
A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
if 1 not in [A_batch, B_batch]:
if A_batch != B_batch:
raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
return max(A_batch, B_batch)
def _get_batch_stride(self, tensor) -> int:
"""
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
:param tensor: tensor object to process
:type tensor: numpy/cupy/torch array/tensor object
:return: stride between each matrix in the batch
:rtype: int
"""
if tensor is not None and len(tensor.shape) > 2:
return tensor.shape[-2] * tensor.shape[-1]
else:
return 0
def _get_problem_args(self, A, B, C, D) -> tuple:
"""
Returns the problem size and GEMM universal mode to use for the
given operands.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int)
:rtype: tuple
"""
M, K = A.shape[-2:]
N = B.shape[-1]
mode = GemmUniversalMode.Gemm
batch_count = self._get_batch_count(A, B, C, D)
returned_batch_count = batch_count
# If we are running a batched GEMM in which there is a nonzero batch stride
# only for A, then we can fold the batched dimension of A into the M dimension
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
# and C are row major. A similar operation can be performed if only B has a nonzero
# batch dimension
if batch_count > 1:
A_row = self._layout_a == cutlass.LayoutType.RowMajor
B_row = self._layout_b == cutlass.LayoutType.RowMajor
C_row = self._layout_c == cutlass.LayoutType.RowMajor
# Consider a Tensor to be batched if its rank is > 2 and
# the product of the modes beyond rank 2 equals our pre-determined batch size.
batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)
if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
M *= batch_count
returned_batch_count = 1
elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
N *= batch_count
returned_batch_count = 1
else:
mode = GemmUniversalMode.Batched
return GemmCoord(M, N, K), mode, returned_batch_count
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
"""
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
is raised if it does not.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_dtype: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
"""
dtype, layout = datatypes.get_datatype_and_layout(tensor)
if dtype != ref_type or layout != ref_layout:
try:
# Attempt to transpose the tensor to fit the desired layout
tensor = tensor.transpose(-1, -2)
except:
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
f'does not match the expected type and '
f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
"""
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
parameters provided in this call, or from those
passed in on the construction of this object -- one of the two must be specified.
By default, this call returns only once the kernel has completed. To launch the kernel
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
caller to syncrhonize the results of the kernel before attempting to access outputs
by calling ``sync()`` on the arguments returned from this call.
:param A: tensor representing data type and layout of operand A
:param B: tensor representing data type and layout of operand B
:param C: tensor representing data type and layout of operand C
:param D: tensor representing data type and layout of operand D
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param sync: whether the call should wait for the kernel to complete before returning
:type sync: bool
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
:type stream: :class:`cuda.cuda.CUstream`
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmArguments
"""
super().run_setup()
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
is_void_c = self._element_c == DataType.void
self._verify_rank(A)
self._verify_rank(B)
if not is_void_c:
self._verify_rank(C)
self._verify_rank(D)
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")
# Set C alignment based on D.shape so as to correctly get an alignment with void-C
# kernels, for which `C` is None.
alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, print_module=print_module)
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
if mode == GemmUniversalMode.Gemm or batch_count == 1:
kwargs = {'split_k_slices': 1}
else:
kwargs = {
'batch': batch_count,
'batch_strides': {
'A': self._get_batch_stride(A),
'B': self._get_batch_stride(B),
'C': self._get_batch_stride(C),
'D': self._get_batch_stride(D)
}
}
kwargs['stream'] = stream
if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
output_op = self.operation.epilogue_type(visitor_args)
else:
output_op = self.operation.epilogue_type(alpha, beta)
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=A, B=B, C=C, D=D,
output_op=output_op,
gemm_mode=mode,
**kwargs
)
self.operation.run(arguments)
if sync:
arguments.sync()
return arguments
|