1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698
|
#################################################################################################
#
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from typing import Union
from cuda import cuda
from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import (
ConvKindNames,
ConvKindTag,
DataTypeNames,
DataTypeSize,
DataTypeTag,
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
LayoutType,
MathOperation,
MathOperationTag,
OpcodeClass,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
ShortDataTypeNames,
ShortLayoutTypeNames,
SplitKMode,
StrideSupport,
StrideSupportTag,
SwizzlingFunctor,
SwizzlingFunctorTag,
get_complex_from_real,
)
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import dim3_, get_conv2d_arguments
from cutlass.backend.library import (
EmissionType,
TensorDescription,
TileDescription,
)
from cutlass.backend.memory_manager import device_mem_alloc
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.utils.device import to_device_ptr
from cutlass.shape import GemmCoord
class Conv2dArguments(ArgumentBase):
    """
    Argument wrapper for Conv2d. It encodes problem information and
    user-provided tensors into the kernel's arguments.

    :param operation: the Conv2d operation to take the argument
    :type operation: :class:`cutlass.backend.Conv2dOperation`

    :param problem_size: the Conv2d problem size; its ``split_k_slices``
        attribute is overwritten with the slice count used by these arguments
    :type problem_size: :class:`cutlass.shape.Conv2dProblemSize`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial.
        Only honored when the ``split_k_slices`` keyword is greater than 1
    :type split_k_mode: cutlass_library.library.SplitKMode, optional

    :param split_k_slices: (keyword) number of split-K slices; treated as 1 when absent
    :type split_k_slices: int, optional

    :param output_op: (keyword) output operator; ignored in parallel split-K mode,
        where ``operation.epilogue_type(1.0, 0.0)`` is used instead
    :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`

    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`

    :raises Exception: if the layout of operand C is ``TensorNC32HW32``
    :raises RuntimeError: if zero-initializing the device workspace fails
    """

    def __init__(self, operation, problem_size, A, B, C, D,
                 split_k_mode=SplitKMode.Serial, **kwargs) -> None:
        self.operation = operation
        self.conv_kind = operation.conv_kind
        self.layout_A = operation.A.layout
        self.layout_B = operation.B.layout
        self.layout_C = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        if self.layout_C == LayoutType.TensorNC32HW32:
            raise Exception("Layout type TensorNC32HW32 is not currently supported")

        super().__init__(A, B, C, D, **kwargs)

        # The requested split-K mode only takes effect when more than one
        # slice is requested; otherwise fall back to a single serial slice.
        if "split_k_slices" in kwargs and kwargs["split_k_slices"] > 1:
            self.split_k_mode = split_k_mode
            self.split_k_slices = kwargs["split_k_slices"]
        else:
            self.split_k_mode = SplitKMode.Serial
            self.split_k_slices = 1

        # In parallel split-K mode any user-supplied output_op is replaced by
        # epilogue_type(1.0, 0.0).
        if "output_op" in kwargs and self.split_k_mode != SplitKMode.Parallel:
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.problem_size = problem_size
        self.problem_size.split_k_slices = self.split_k_slices

        self.initialize()

    def get_arguments(self):
        """
        Build ``self.c_arguments`` (the ctypes argument structure of the
        operation) from the current pointers, problem size and epilogue.
        """
        # -1 signals that no element count was recorded for tensor C.
        tc_numel = getattr(self, "tensor_c_numel", -1)

        self.c_arguments = self.operation.argument_type(
            int(self.conv_kind),
            self.problem_size.ctype,
            int(to_device_ptr(self.ptr_A)),
            int(to_device_ptr(self.ptr_B)),
            int(to_device_ptr(self.ptr_C)),
            int(to_device_ptr(self.ptr_D)),
            tc_numel,
            self.output_op,
            int(self.split_k_mode)
        )

    def initialize(self):
        """
        Plan the launch, allocate the device workspace if one is required and
        serialize the kernel parameters into ``self.host_workspace``.

        :raises RuntimeError: if clearing the device workspace fails
        """
        self.launch_config = self.operation.rt_module.plan(self)

        self.get_arguments()

        # Allocate and initialize device workspace
        device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            # cuMemsetD32 counts 32-bit words, hence the division by 4.
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error: {err}")
        else:
            workspace_ptr = None

        self.semaphore = 0
        if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
            # Parallel split-K writes its output into the workspace.
            self.ptr_D = workspace_ptr
            # Reset arguments now that ptr_D has been updated
            self.get_arguments()
        elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
            # Serial split-K uses the workspace as the semaphore.
            self.semaphore = workspace_ptr

        params_ = self.operation.rt_module.get_args(
            self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
        self.host_workspace = bytearray(params_.contents)
        self.device_workspace = None

    def sync(self):
        """
        Synchronize the arguments by delegating to the base class ``sync()``.
        """
        return super().sync()
class Conv2dRT(ExecutableOperation):
    """
    Conv2dRT manages the CUTLASS runtime components of a Conv2d operation:
    the device kernel entry point, the host-side helper functions emitted
    alongside it, and the launch planning.
    """

    # Device entry point: forwards the params and the dynamic shared memory
    # to the generated kernel functor.
    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
${operation_name}${operation_suffix} op;
op(params, *shared_storage);
}
"""

    # Host helpers: size queries, argument construction, params serialization,
    # grid shape and workspace size computation.
    HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}
// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
using ElementA = typename ${operation_name}_base::ElementA;
using ElementB = typename ${operation_name}_base::ElementB;
using ElementC = typename ${operation_name}_base::ElementC;
using LayoutA = typename ${operation_name}_base::LayoutA;
using LayoutB = typename ${operation_name}_base::LayoutB;
using LayoutC = typename ${operation_name}_base::LayoutC;
using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;
struct ${operation_name}_TemporaryArgs {
int conv_kind;
cutlass::conv::Conv2dProblemSize problem_size;
ElementA* ptr_A;
ElementB* ptr_B;
ElementC* ptr_C;
ElementC* ptr_D;
int tensor_c_numel;
typename EpilogueOutputOp::Params epilogue_params;
int split_k_mode;
};
typename ${operation_name}${operation_suffix}::Arguments
construct_arguments(${operation_name}_TemporaryArgs args) {
cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
// C is interpreted as bias
tc_C = {0, 0, 0, 0};
}
// Each tensor reference carries the layout type of its own operand
cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
cutlass::TensorRef<ElementB, LayoutB> tref_B(args.ptr_B, LayoutB::packed(tc_B));
cutlass::TensorRef<ElementC, LayoutC> tref_C(args.ptr_C, LayoutC::packed(tc_C));
cutlass::TensorRef<ElementC, LayoutC> tref_D(args.ptr_D, LayoutC::packed(tc_D));
return {
args.problem_size,
tref_A,
tref_B,
tref_C,
tref_D,
args.epilogue_params,
static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
};
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
auto arguments = construct_arguments(args);
typename ${operation_name}${operation_suffix}::Params* params;
params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
// Only the byte copy is returned; release the temporary Params object
delete params;
return output;
}
dim3 ${operation_name}_get_grid_shape(
int conv_kind,
cutlass::conv::Conv2dProblemSize problem_size,
cutlass::gemm::GemmCoord tile_size,
int split_k_slices
) {
using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
auto tiled_shape = Swizzle::get_tiled_shape(
static_cast<cutlass::conv::Operator>(conv_kind),
problem_size,
tile_size,
split_k_slices);
return Swizzle::get_grid_shape(tiled_shape);
}
size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
auto arguments = construct_arguments(args);
// Temporarily define device::-level Conv2d so that we can call get_workspace_size
using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
return DeviceConv::get_workspace_size(arguments);
}
}
"""

    def __init__(self, operation: "Conv2dOperation"):
        """
        :param operation: the Conv2d operation this runtime module serves
        :type operation: Conv2dOperation
        """
        super().__init__(operation)
        # Return types of the additional host functions declared in HostTemplate.
        self.extra_funcs = {
            "get_grid_shape": dim3_,
            "get_workspace_size": ctypes.c_uint64
        }
        self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
        # Signature of the generated `_get_params`: (TemporaryArgs, semaphore pointer).
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
        self.conv_kind = operation.conv_kind

        self.operation: Conv2dOperation = operation

        # The "_type" suffix names the struct that wraps the `_base` kernel.
        self.emitter = EmitConv2dInstance("_type")

        self.threads = operation.tile_description.num_threads

        self.swizzle_functor = operation.swizzling_functor

    def emit(self):
        """Return the C++ source that instantiates the operation's kernel."""
        return self.emitter.emit(self.operation)

    def plan(self, arguments: Conv2dArguments):
        """
        Compute the launch configuration for the given arguments.

        :param arguments: conv2d arguments providing the problem size and
            the number of split-K slices
        :type arguments: Conv2dArguments

        :return: grid dimensions, block dimensions ``[threads, 1, 1]`` and
            the shared memory capacity
        :rtype: LaunchConfiguration
        """
        threadblock_shape = self.operation.tile_description.threadblock_shape
        tile_size = GemmCoord(
            threadblock_shape[0],
            threadblock_shape[1],
            threadblock_shape[2],
        )

        grid = self.get_grid_shape(
            int(self.conv_kind),
            arguments.problem_size.ctype,
            tile_size.ctype,
            arguments.split_k_slices
        )

        return LaunchConfiguration(
            [grid.x, grid.y, grid.z], [self.threads, 1, 1],
            self.shared_memory_capacity)

    def initialize(self):
        """
        Set the kernel's maximum dynamic shared memory size to
        ``self.shared_memory_capacity``.

        :raises RuntimeError: if the CUDA driver call fails
        """
        err, = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")
class Conv2dOperation:
    """
    CUTLASS Conv2d operation description.

    :param conv_kind: convolution operator
    :type conv_kind: :class:`cutlass_library.library.ConvKind`

    :param iterator_algorithm: Selects among several implementation
        variants trading off performance with simplicity
    :type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`

    :param arch: GPU compute capability (sm_xx)
    :type arch: int

    :param tile_description: tile description
    :type tile_description: :class:`cutlass.backend.TileDescription`

    :param A: tensor A description
    :type A: :class:`cutlass.backend.TensorDescription`

    :param B: tensor B description
    :type B: :class:`cutlass.backend.TensorDescription`

    :param C: tensor C description
    :type C: :class:`cutlass.backend.TensorDescription`

    :param stride_support: distinguish among partial specializations that
        accelerate certain problems where convolution stride is unit
    :type stride_support: :class:`cutlass_library.library.StrideSupport`

    :param epilogue_functor: convolution epilogue functor
    :type epilogue_functor: :class:`EpilogueFunctor`

    :param swizzling_functor: threadblock swizzling functor

    :param emission_type: selects the kernel-level or the device-level
        source template when the operation is emitted
    :type emission_type: :class:`cutlass.backend.library.EmissionType`
    """

    def __init__(
        self,
        conv_kind,
        iterator_algorithm,
        arch: int,
        tile_description: TileDescription,
        A: TensorDescription,
        B: TensorDescription,
        C: TensorDescription,
        stride_support,
        epilogue_functor,
        swizzling_functor=SwizzlingFunctor.Identity1,
        emission_type=EmissionType.Kernel,
        **kwargs
    ):
        self.operation_kind: OperationKind = OperationKind.Conv2d

        # Problem and target description
        self.conv_kind = conv_kind
        self.arch: int = arch
        self.tile_description: TileDescription = tile_description

        # Operand descriptions
        self.A: TensorDescription = A
        self.B: TensorDescription = B
        self.C: TensorDescription = C

        # Kernel variant selection
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.emission_type = emission_type

        # The runtime module reads the attributes above, so it is built last.
        runtime = Conv2dRT(self)
        self.rt_module: Conv2dRT = runtime
        self.argument_type = runtime.argument_type
        self.epilogue_type = runtime.epilogue_type

    def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
        """
        Launch the cuda kernel with input arguments

        :param arguments: conv2d arguments
        :type arguments: :class:`cutlass.backend.Conv2dArguments`

        :return: the CUDA result code (always success, since failures raise)
        :raises RuntimeError: if the launch reports a CUDA error
        """
        status = self.rt_module.run(
            arguments.host_workspace,
            arguments.device_workspace,
            arguments.launch_config,
            arguments.stream
        )

        if status == cuda.CUresult.CUDA_SUCCESS:
            return status
        raise RuntimeError(f"CUDA Error {status}")

    #
    # Get function name
    #

    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, tile size, and layout."""
        return self.configuration_name()

    def configuration_name(self):
        """
        Build the configuration name from architecture, opcode class, extended
        name, threadblock shape and stage count, layout and alignment of A.
        Unit-stride specializations carry an extra ``unity_stride`` marker.
        """
        tile = self.tile_description
        shape = tile.threadblock_shape
        threadblock = "%dx%d_%dx%d" % (shape[0], shape[1], shape[2], tile.stages)

        stride_marker = "_unity_stride" if self.stride_support == StrideSupport.Unity else ""
        name_template = (
            "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}"
            + stride_marker
            + "_align${alignment}"
        )

        substitutions = {
            "arch": str(self.arch),
            "opcode_class": OpcodeClassNames[tile.math_instruction.opcode_class],
            "extended_name": self.extended_name(),
            "threadblock": threadblock,
            "layout": self.layout_name(),
            "alignment": "%d" % self.A.alignment
        }
        return SubstituteTemplate(name_template, substitutions)

    def extended_name(self):
        """Append data types if they differ from compute type."""
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.A.element == accumulator:
            # A matches the accumulator: no element types are appended.
            pattern = "${core_name}"
        elif self.C.element != accumulator:
            pattern = "${element_c}_${core_name}_${element_a}"
        else:
            pattern = "${core_name}_${element_a}"

        return SubstituteTemplate(pattern, {
            "element_a": DataTypeNames[self.A.element],
            "element_c": DataTypeNames[self.C.element],
            "core_name": self.core_name(),
        })

    def layout_name(self):
        """Return the short name of operand A's layout."""
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""
        math_inst = self.tile_description.math_instruction

        inst_shape = ""
        intermediate_type = ""
        if math_inst.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%dx%dx%d" % tuple(math_inst.instruction_shape)
            # Name the instruction's operand type only when it is neither the
            # type of A nor the accumulator type.
            if math_inst.element_a != self.A.element and \
                    math_inst.element_a != self.accumulator_type():
                intermediate_type = DataTypeNames[math_inst.element_a]

        return "%s%s%s%s_%s" % (
            ShortDataTypeNames[self.accumulator_type()],
            inst_shape,
            intermediate_type,
            ConvKindNames[self.conv_kind],
            IteratorAlgorithmNames[self.iterator_algorithm]
        )

    def is_complex(self):
        """Return True when the math operation is one of the complex multiply-adds."""
        math_operation = self.tile_description.math_instruction.math_operation
        return math_operation in (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
        )

    def accumulator_type(self):
        """Return the accumulator element type, converted to complex for complex operations."""
        real_type = self.tile_description.math_instruction.element_accumulator
        return get_complex_from_real(real_type) if self.is_complex() else real_type

    def device_op(self):
        """
        Returns a new Conv2dOperation object that is constructed with emission type
        ``EmissionType.Device``.

        :return: operation ready for device-level code emission
        :rtype: Conv2dOperation
        """
        return Conv2dOperation(
            conv_kind=self.conv_kind,
            iterator_algorithm=self.iterator_algorithm,
            arch=self.arch,
            tile_description=self.tile_description,
            A=self.A,
            B=self.B,
            C=self.C,
            stride_support=self.stride_support,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self.swizzling_functor,
            emission_type=EmissionType.Device,
        )
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################
class EmitConv2dInstance:
    """
    Emits the C++ source for a single Conv2d kernel instance, either as a
    kernel-level type (``template``) or wrapped in the device-level
    ``ImplicitGemmConvolution`` (``template_device``).
    """

    def __init__(self, operation_suffix=""):
        """
        :param operation_suffix: suffix appended to the operation name for the
            struct that derives from the ``_base`` kernel type
        :type operation_suffix: str
        """
        self.operation_suffix = operation_suffix
        # Headers required by the emitted source.
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/conv/kernel/default_conv2d_fprop.h",
            "cutlass/conv/kernel/default_conv2d_dgrad.h",
            "cutlass/conv/kernel/default_conv2d_wgrad.h",
            "cutlass/conv/device/implicit_gemm_convolution.h"
        ]
        # Kernel-level emission: `_base` alias plus a suffixed derived struct.
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""
        # Device-level emission: kernel alias wrapped in ImplicitGemmConvolution.
        self.template_device = """
// Conv2d operation ${operation_name}
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
using DeviceKernel =
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
"""

    def emit(self, operation):
        """
        Render the source for ``operation``.

        :param operation: the Conv2d operation to emit
        :return: the kernel-level template when the operation's emission type
            is ``EmissionType.Kernel``, otherwise the device-level template,
            with all placeholders substituted
        :rtype: str
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction

        # Per-warp tile: threadblock shape divided by the warp count per dimension.
        warp_shape = [
            int(tile.threadblock_shape[dim] / tile.warp_count[dim])
            for dim in range(3)
        ]

        # Elements per epilogue access: the alignment of C, capped at 128 bits.
        c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(
            min(operation.C.alignment * c_bits, 128) / DataTypeSize[operation.C.element])

        if operation.is_complex():
            math_operator = "cutlass::arch::OpMultiplyAddComplex"
        else:
            math_operator = MathOperationTag[math_inst.math_operation]

        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "conv_kind": ConvKindTag[operation.conv_kind],
            "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[math_inst.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(tile.threadblock_shape[0]),
            "threadblock_shape_n": str(tile.threadblock_shape[1]),
            "threadblock_shape_k": str(tile.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(math_inst.instruction_shape[0]),
            "instruction_shape_n": str(math_inst.instruction_shape[1]),
            "instruction_shape_k": str(math_inst.instruction_shape[2]),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(tile.stages),
            "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
            "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            "stride_support": StrideSupportTag[operation.stride_support],
            "math_operator": math_operator,
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
        }

        kernel_level = operation.emission_type == EmissionType.Kernel
        conv2d_template = self.template if kernel_level else self.template_device

        return SubstituteTemplate(conv2d_template, values)
|