File: pytorch.py

package info (click to toggle)
nvidia-cutlass 3.4.1%2Bds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (925 lines) | stat: -rw-r--r-- 37,350 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.

Example usage with JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
    op = plan.construct()
    mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = mod.run(A, B, C)


Example usage without JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    op = plan.construct()
    cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')

After this call, the directory ``output`` contains ``setup.py``,
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.

The module can later be used in Python via:

.. highlight:: python
.. code-block:: python

    import torch
    import cutlass_gemm

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = cutlass_gemm.run(A, B, C)
"""

import logging
import os

from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate

from cutlass import CUTLASS_PATH, logger, swizzle
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.conv2d_operation import Conv2dOperation
from cutlass.backend.library import ApiVersion
from cutlass.emit import common
from cutlass.utils.datatypes import is_torch_available

if is_torch_available():
    import torch


# CUDA-side source skeleton: the auto-generation banner, shared includes, the
# ``device_memory_allocation`` helper (allocates an int8 tensor of ``size`` elements
# on CUDA device ``device_id`` and returns its data pointer, or ``nullptr`` when
# ``size`` is 0), followed by the ``${includes}``, ``${declaration}`` and ``${impl}``
# placeholders.
# NOTE(review): ``device_memory_allocation`` returns ``data_ptr()`` of a tensor that
# is local to the helper, and its ``stream`` local is never read -- confirm that the
# allocation is expected to outlive the returned pointer.
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
    if (size > 0) {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size,}, options);
        return reinterpret_cast<void*>(device_tensor.data_ptr());
    } else {
        return nullptr;
    }
}

${includes}
${declaration}
${impl}
"""

# C++ binding template for a single GEMM. It forward-declares ``${name}_kernel``,
# defines ``${name}`` as a thin wrapper that forwards all arguments to it, and
# registers that wrapper with pybind11 under the Python name ``run`` with keyword
# arguments ``A``, ``B``, ``C`` (default ``nullptr``), ``alpha`` (default 1) and
# ``beta`` (default 0).
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""

# C++ binding template for a grouped GEMM. Same shape as the single-GEMM binding,
# except that ``A``, ``B`` and the optional ``C`` are ``std::vector<at::Tensor>`` and
# the bound ``run`` function returns a ``std::vector<at::Tensor>``.
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
        py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""

# C++ binding template for a Conv2d forward-propagation kernel. It forward-declares
# ``${name}_kernel``, defines ``${name}`` as a forwarding wrapper, and binds it as
# ``run`` with keyword arguments for the tensors, stride/padding/dilation tuples,
# ``alpha``/``beta`` and the split-K settings.
# NOTE(review): the pybind11 default for ``padding`` is ``(1, 1)`` while the C++
# default in the declarations of this same template is ``{0, 0}`` -- confirm which
# default is intended for Python callers.
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
    const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
    const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
  py::overload_cast<
    const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
    std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float,  std::string, int>(
        &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
        py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
        py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
        py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""

# C++ binding template for Conv2d gradient kernels. It differs from the fprop binding
# only by the leading ``result_size`` 4-tuple argument, which is forwarded to
# ``${name}_kernel`` together with the remaining arguments.
# NOTE(review): the pybind11 default for ``padding`` is ``(1, 1)`` while the C++
# default in the declarations of this same template is ``{0, 0}`` -- confirm which
# default is intended for Python callers.
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
    std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
    std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
  py::overload_cast<
    std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
    std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
        &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
        py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
        py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
        py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""

# ``#include`` lines for GEMM kernels, keyed by the CUTLASS API version.
# NOTE(review): the v3x entry lists ``gemm_universal_adapter.h`` twice.
_PYTORCH_GEMM_INCLUDES = {
    ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
""",
    ApiVersion.v3x: """
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}

# ``#include`` lines for grouped GEMM kernels.
_PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
"""

# ``#include`` lines for Conv2d kernels; the fprop, dgrad and wgrad default-kernel
# headers are all listed, along with the implicit-GEMM convolution device header.
_PYTORCH_CONV2D_INCLUDES = """
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
"""

# Maps a CUTLASS ``DataType`` to the spelling of the matching dtype constant in the
# PyTorch C++ API (presumably substituted for ``${torch_type_C}`` in the templates of
# this module -- the substitution site is not in this part of the file).
#
# The short dtype aliases of the PyTorch C++ API all carry a ``k`` prefix, so the
# integer entries must be ``torch::kI8`` / ``torch::kI32``; this also matches the
# ``torch::kI8`` already used inside ``_PYTORCH_CUDA_TEMPLATE``.
_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::kI8",
    DataType.s32: "torch::kI32",
}

# CUDA-side implementation template for a CUTLASS 2.x GEMM: the shared
# ``common._CUTLASS_KERNEL_RUN_GEMM_2x`` text followed by the definition of
# ``${name}_kernel``. That function reads M and K from ``A`` and N from ``B``,
# passes ``nullptr`` for C when it is not supplied, allocates the output ``D`` as an
# ``{M, N}`` tensor of dtype ``${torch_type_C}`` via ``B.new_empty``, calls
# ``${name}_kernel_run`` and raises through ``TORCH_CHECK`` unless the returned status
# is ``kSuccess``.
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_GEMM_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

    cutlass::Status status = ${name}_kernel_run(M, N, K,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)

# CUDA-side implementation template for a CUTLASS 3.x GEMM: the shared
# ``common._CUTLASS_KERNEL_RUN_GEMM_3x`` text followed by the definition of
# ``${name}_kernel``. Compared with the 2.x template it additionally passes a batch
# count ``L`` (fixed to 1) and a ``cutlass::KernelHardwareInfo`` to
# ``${name}_kernel_run``.
#
# The hardware info lives in file-scope variables of the generated source and is
# guarded by ``hw_info_queried``. The flag is set once the query has run, so the
# multiprocessor count is looked up on the first call only instead of on every call.
# The query always targets device 0 (``hw_info.device_id = 0``).
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
    common._CUTLASS_KERNEL_RUN_GEMM_3x
    + """
bool hw_info_queried = false;
cutlass::KernelHardwareInfo hw_info;

at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);
    int L = 1;

    // Query hardware info if we haven't already
    if (!hw_info_queried) {
        hw_info.device_id = 0;
        hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
        hw_info_queried = true;
    }

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

    cutlass::Status status = ${name}_kernel_run(M, N, K, L,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta),
                                                hw_info);

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)


# CUDA-side implementation template for a grouped GEMM: the shared
# ``common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x`` text followed by the definition of
# ``${name}_kernel``.
#
# The generated function packs all per-problem arguments into one host byte buffer,
# in this order: problem sizes, padding, A/B/C/D pointer arrays, then the lda/ldb/ldc
# arrays. It records the byte offset of each array, copies the buffer to one device
# allocation, and hands ``${name}_kernel_run`` device pointers computed from those
# offsets. One output tensor per problem is allocated with ``new_empty``. C pointers
# are only filled in when ``C`` is provided and ``beta`` is non-zero; otherwise they
# are ``nullptr``.
# NOTE(review): the device ``ldc`` array is passed for both of the last two stride
# arguments of ``${name}_kernel_run`` -- presumably because D shares C's layout;
# confirm against the ``_kernel_run`` signature in ``common``.
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
    common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
    + """
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
    size_t num = A.size();

    // To avoid performing many small cudaMallocs and host-to-device copies,
    // we serialize the grouped GEMM arguments on the host, allocate one
    // large chunk of device memory, and perform a single cudaMemcpy to
    // copy the host data to the device. Allocation overheads could be
    // avoided by using a memory pool.

    // Calculate the total size of the data to be copied from host to device
    size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
                        sizeof(DeviceKernel::ElementA*) +
                        sizeof(DeviceKernel::ElementB*) +
                        sizeof(DeviceKernel::ElementC*) +
                        sizeof(DeviceKernel::ElementC*) +
                        sizeof(int64_t) +
                        sizeof(int64_t) +
                        sizeof(int64_t);
    total_size *= num;

    // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
    // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
    // To ensure that we don't end up having misaligned loads in the kernel,
    // we pad to the nearest multiple of 8.
    //
    // Note that, even on a 32-bit system (for which sizeof(X*) will not equal
    // sizeof(int64_t)), only padding between the list of GemmCoords and the
    // list of ptr_As is sufficient because the set of four equal-length lists of pointers
    // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
    // start on a multiple of 8.
    int64_t padding = 8 - (total_size % 8);
    total_size += padding;

    uint8_t* host_data = new uint8_t[total_size];
    cutlass::DeviceAllocation<uint8_t> device_data(total_size);

    uint8_t* start = host_data;
    cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);

    // Apply the padding after the list of GemmCoords
    start += num * sizeof(cutlass::gemm::GemmCoord) + padding;

    int64_t ptr_A_offset = start - host_data;
    DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
    start += num * sizeof(DeviceKernel::ElementA*);

    int64_t ptr_B_offset = start - host_data;
    DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
    start += num * sizeof(DeviceKernel::ElementB*);

    int64_t ptr_C_offset = start - host_data;
    DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
    start += num * sizeof(DeviceKernel::ElementC*);

    int64_t ptr_D_offset = start - host_data;
    DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
    start += num * sizeof(DeviceKernel::ElementC*);

    int64_t lda_offset = start - host_data;
    int64_t* lda_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    int64_t ldb_offset = start - host_data;
    int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    int64_t ldc_offset = start - host_data;
    int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    std::vector<at::Tensor> D(num);

    bool need_C = (C != at::nullopt) && (beta != 0.f);
    for (size_t i = 0; i < num; ++i) {
        int M = A[i].size(0);
        int N = B[i].size(1);
        int K = A[i].size(1);
        *(problem_sizes_host + i) = {M, N, K};
        *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
        *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());

        if (need_C) {
            *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
        }
        else {
            *(ptr_C_host + i) = nullptr;
        }

        D[i] = B[i].new_empty({M, N}, ${torch_type_C});
        *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());

        *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
        *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
        *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
    }

    device_data.copy_from_host(host_data);

    cutlass::Status status = ${name}_kernel_run(
        num,
        reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
        reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
        reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
        reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
        reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
        reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
        ElementCompute(alpha), ElementCompute(beta));

    delete[] host_data;

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)

# Shared tail of the Conv2d ``${name}_kernel`` functions. It is not a complete
# function on its own: it is appended to the fprop, dgrad and wgrad templates in this
# module, each of which opens the function and defines ``problem_size``, ``ptrC`` and
# ``D``. The tail fetches the current CUDA stream, calls ``${name}_kernel_run``,
# checks the returned status and closes the function body.
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    cutlass::Status status = ${name}_kernel_run(
        &problem_size,
        reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
        reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
        ptrC,
        reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
        alpha, beta,
        split_k_mode, stream, B.device().index());

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""

# CUDA-side implementation template for Conv2d forward propagation: the shared
# ``common._CUTLASS_KERNEL_RUN_CONV2D_2x`` text, the opening of ``${name}_kernel``,
# and the common tail ``_PYTORCH_CONV2D_IMPL_TEMPLATE_2x``.
#
# The generated function reads N, C, H, W from dimensions 0..3 of ``A`` and K, R, S
# from dimensions 0, 2, 3 of ``B``, builds a ``Conv2dProblemSize`` from those and the
# stride/padding/dilation tuples, and takes the output extents P and Q from the
# problem size. The output ``D`` is a zero-initialised channels-last tensor of shape
# ``{N, K, P, Q}`` on ``B``'s device (the dgrad and wgrad templates use
# ``torch::empty`` instead).
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
    int N, H, W, C_, K, R, S, P, Q;
    N = A.size(0);
    C_ = A.size(1);
    H = A.size(2);
    W = A.size(3);

    K = B.size(0);
    R = B.size(2);
    S = B.size(3);

    cutlass::conv::Conv2dProblemSize problem_size(
        cutlass::Tensor4DCoord(N, H, W, C_),
        cutlass::Tensor4DCoord(K, R, S, C_),
        cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
        cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
        cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
        cutlass::conv::Mode::kCrossCorrelation,
        split_k_slices
    );

    P = problem_size.P;
    Q = problem_size.Q;

    typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

    torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
    at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)


# CUDA-side implementation template for the Conv2d data gradient: the shared
# ``common._CUTLASS_KERNEL_RUN_CONV2D_2x`` text, the opening of ``${name}_kernel``,
# and the common tail ``_PYTORCH_CONV2D_IMPL_TEMPLATE_2x``.
#
# N, C, H, W come from the ``input_size`` tuple (elements 0..3) and K, R, S from
# dimensions 0, 2, 3 of ``B``. The output ``D`` is an uninitialised channels-last
# tensor of shape ``{N, C, H, W}`` on ``B``'s device.
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    int N, H, W, C_, K, R, S;
    N = std::get<0>(input_size);
    C_ = std::get<1>(input_size);
    H = std::get<2>(input_size);
    W = std::get<3>(input_size);

    K = B.size(0);
    R = B.size(2);
    S = B.size(3);

    cutlass::conv::Conv2dProblemSize problem_size(
        cutlass::Tensor4DCoord(N, H, W, C_),
        cutlass::Tensor4DCoord(K, R, S, C_),
        cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
        cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
        cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
        cutlass::conv::Mode::kCrossCorrelation,
        split_k_slices
    );

    typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

    torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
    at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)


# CUDA-side implementation template for the Conv2d weight gradient: the shared
# ``common._CUTLASS_KERNEL_RUN_CONV2D_2x`` text, the opening of ``${name}_kernel``,
# and the common tail ``_PYTORCH_CONV2D_IMPL_TEMPLATE_2x``.
#
# K, C, R, S come from the ``weight_size`` tuple (elements 0..3) and N, H, W from
# dimensions 0, 2, 3 of ``B``. The output ``D`` is an uninitialised channels-last
# tensor of shape ``{K, C, R, S}`` on ``B``'s device.
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    int N, H, W, C_, K, R, S;
    K = std::get<0>(weight_size);
    C_ = std::get<1>(weight_size);
    R = std::get<2>(weight_size);
    S = std::get<3>(weight_size);

    N = B.size(0);
    H = B.size(2);
    W = B.size(3);

    cutlass::conv::Conv2dProblemSize problem_size(
        cutlass::Tensor4DCoord(N, H, W, C_),
        cutlass::Tensor4DCoord(K, R, S, C_),
        cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
        cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
        cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
        cutlass::conv::Mode::kCrossCorrelation,
        split_k_slices
    );

    typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

    torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
    at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)


# Template for the generated ``setup.py``. It declares one ``CUDAExtension`` named
# ``${name}`` built from ``${name}.cpp`` and ``${name}_kernel.cu``, adds the CUTLASS
# ``include`` and ``tools/util/include`` directories under ``${cutlass_path}`` to the
# include path, and compiles with ``-std=c++17``. ``_generate_setup`` fills in the
# ``${name}`` and ``${cutlass_path}`` placeholders.
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='${name}',
    ext_modules=[
        CUDAExtension('${name}', [
            '${name}.cpp',
            '${name}_kernel.cu',
        ],
        include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
        extra_compile_args=['-std=c++17']
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

"""


def _generate_setup(name: str, sourcedir: str):
    """
    Renders the ``setup.py`` template for the extension and writes it into ``sourcedir``

    The ``name`` and ``cutlass_path`` placeholders of ``_PYTORCH_SETUP_PY`` are filled
    with ``name`` and ``CUTLASS_PATH`` respectively. An existing ``setup.py`` in
    ``sourcedir`` is overwritten.

    :param name: name of the module to generate
    :type name: str
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str
    """
    substitutions = {"name": name, "cutlass_path": CUTLASS_PATH}
    rendered = SubstituteTemplate(_PYTORCH_SETUP_PY, substitutions)

    destination = os.path.join(sourcedir, "setup.py")
    with open(destination, "w") as handle:
        handle.write(rendered)


class _ArchListSetter:
    """
    Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
    environment variable when building a PyTorch CUDA module.

    ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch
    CUDA module should be compiled.

    For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
    ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
    compilation of the module.

    This utility wraps the building of a PyTorch CUDA module with a setting of this environment
    variable according to the current compute capability being targeted.

    The compute capability is given as an integer whose last digit is the minor version and
    whose leading digits are the major version (``80`` -> ``"8.0"``, ``100`` -> ``"10.0"``).

    Example usage:

    .. highlight:: python
    .. code-block:: python

        # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
        with _ArchListSetter(80):
            # Perform JIT compilation and loading of the module
            mod = torch.utils.cpp_extension.load(...)

    :param cc: compute capability
    :type cc: int
    """

    _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"

    def __init__(self, cc: int):
        # Split arithmetically rather than joining every character with ".", so that
        # three-digit capabilities become "10.0" instead of "1.0.0".
        major, minor = divmod(int(cc), 10)
        self.cc_str = f"{major}.{minor}"
        # Populated by __enter__; None means the variable was not set beforehand.
        self.old_arch_list = None

    def __enter__(self):
        """
        Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``

        :return: this context manager
        """
        self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
        os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str

        return self

    def __exit__(self, exc_type, exc_val, traceback):
        """
        Restores the old value of TORCH_CUDA_ARCH_LIST

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        if self.old_arch_list is None:
            # pop() rather than del: the body may already have removed the variable,
            # and restoring the environment must not raise in that case.
            os.environ.pop(_ArchListSetter._TORCH_CUDA_ARCH_LIST, None)
        else:
            os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list


def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
    """
    Just-in-time builds a PyTorch CUDA extension from the given sources and loads it.

    The build runs with ``TORCH_CUDA_ARCH_LIST`` temporarily set according to ``cc``.

    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param cpp_file: path to file containing extension's C++ interface
    :type cpp_file: str
    :param cuda_file: path to file containing extension's CUDA interface
    :type cuda_file: str

    :return: loaded PyTorch module
    """

    from torch.utils.cpp_extension import load as load_extension

    nvcc_flags = ["-std=c++17"]
    if cc == 90:
        # PyTorch does not currently add the sm_90a target when compute capability
        # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")

    # Header search paths are resolved relative to the CUTLASS checkout
    include_dirs = [
        os.path.join(CUTLASS_PATH, subdir)
        for subdir in ("include", "tools/util/include")
    ]

    # Only surface build output when the logger's own level is set to DEBUG
    show_build_log = logger.level == logging.DEBUG

    with _ArchListSetter(cc):
        return load_extension(
            name,
            [cpp_file, cuda_file],
            extra_cuda_cflags=nvcc_flags,
            extra_include_paths=include_dirs,
            verbose=show_build_log,
        )


def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    Three files are written to ``sourcedir``: ``<name>_kernel.cu``, ``<name>.cpp``, and the
    setup script produced by ``_generate_setup``.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")

    # Select the implementation template once. Only the 2.x template takes an
    # ``args`` substitution, whose value depends on the swizzling functor.
    extra_kw = {}
    if op.api == ApiVersion.v3x:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x

    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GEMM_INCLUDES[op.api],
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None


def _pytorch_grouped_gemm(
    op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
    """
    Emits the sources of a PyTorch CUDA module wrapping the CUTLASS grouped GEMM described
    by ``op``, and optionally just-in-time compiles and loads that module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    :raises Exception: if ``op`` does not use the CUTLASS 2.x API
    """
    if op.api != ApiVersion.v2x:
        raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")

    # Create the output directory on demand; an empty ``sourcedir`` needs no directory
    must_create_dir = sourcedir != "" and not os.path.isdir(sourcedir)
    if must_create_dir:
        os.makedirs(sourcedir)

    # CUDA translation unit: kernel declaration plus the PyTorch-facing implementation
    kernel_path = os.path.join(sourcedir, name + "_kernel.cu")
    kernel_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
    kernel_values = {
        "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
        "declaration": op.rt_module.emit(),
        "procedural_name": op.procedural_name(),
        "impl": kernel_impl,
        "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
    }
    kernel_text = SubstituteTemplate(_PYTORCH_CUDA_TEMPLATE, kernel_values)
    with open(kernel_path, "w") as kernel_out:
        kernel_out.write(kernel_text)

    # C++ translation unit: the extension's binding layer
    binding_path = os.path.join(sourcedir, name + ".cpp")
    binding_values = {
        "name": name,
        "description": f"CUTLASS {op.procedural_name()} grouped GEMM",
    }
    binding_text = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, binding_values)
    with open(binding_path, "w") as binding_out:
        binding_out.write(binding_text)

    _generate_setup(name, sourcedir)

    return _jit(name, cc, binding_path, kernel_path) if jit else None


def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    :raises Exception: if the conv kind of ``op`` is not fprop, dgrad, or wgrad
    """
    # Resolve the templates before touching the filesystem so that an unsupported
    # conv kind fails with a clear message instead of an unbound local variable.
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    else:
        raise Exception(
            f"Conv kind {op.conv_kind} is not currently supported for PyTorch emission."
        )

    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)
    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")

    # The torch output type is substituted into both the implementation and the
    # enclosing CUDA template, so look it up once.
    torch_type_c = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
    extra_kw = {
        "conv_kind_name": ConvKindNames[op.conv_kind].capitalize(),
        "torch_type_C": torch_type_c,
    }
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_CONV2D_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": torch_type_c,
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        cpp_template,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None


def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Emits the files needed to build a PyTorch CUDA module around the CUTLASS kernel
    described by ``op``, dispatching on the operation's type. When ``jit`` is true, the
    module is additionally just-in-time compiled, loaded, and returned.

    The result of this method is files within ``sourcedir`` that can be used for building
    a PyTorch module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module (if ``jit=True``) or None
    :raises Exception: if the type of ``op`` has no PyTorch emitter
    """
    device_op = op.device_op()

    # Ordered (operation type, emitter) pairs; the first matching type wins
    emitters = (
        (GemmOperationUniversal, _pytorch_gemm),
        (GemmOperationGrouped, _pytorch_grouped_gemm),
        (Conv2dOperation, _pytorch_conv2d),
    )
    for operation_type, emit in emitters:
        if isinstance(op, operation_type):
            return emit(device_op, name, cc, jit, sourcedir)

    raise Exception(
        f"Operation type {type(op)} is not currently supported for PyTorch emission."
    )