File: threadblock_swizzle.h

package info (click to toggle)
nvidia-cutlass 3.4.1%2Bds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (193 lines) | stat: -rw-r--r-- 8,050 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements several possible threadblock-swizzling functions mapping blockIdx to 
      Convolution problems.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns the number of threadblock tiles (CTAs) to launch along the GEMM M dimension
/// for a strided dgrad convolution.
///
/// \param problem_size  Conv2d problem description; its stride() is read here.
/// \param tile_size_m   Threadblock tile extent along M.
/// \return The per-starting-filter-position tile count multiplied by the product of the
///         stride components.
CUTLASS_HOST_DEVICE
static int get_strided_dgrad_tile_m(
  cutlass::conv::Conv2dProblemSize const &problem_size,
  int tile_size_m) {

  // Number of CTAs along M needed for a single starting filter position.
  int const ctas_per_start_position =
    strided_dgrad_tile_m_per_filter(problem_size, tile_size_m);

  // One group of CTAs is launched per starting filter position (stride product). This
  // deliberately includes positions that may fall outside the valid MMA (Dy * w): they are
  // still required so the epilogue (beta * Dx_source) and point-wise fusion get applied.
  int const start_position_count = static_cast<int>(problem_size.stride().product());

  // Possible performance optimization (not implemented here):
  //
  //   Launch only those CTAs along M that contribute to a row of the Dx output. This could
  //   yield up to a 2x speedup over the current CUTLASS strided dgrad performance when
  //   stride > filter (e.g. stride={2x2} and filter={1x1}).
  //
  //   Constraints:
  //   (A) stride <= filter, e.g. stride={2x2} and filter={3x3}:
  //         (A.1) No constraints apply; presumably the optimization leaves functionality
  //               and performance unchanged for this case - TODO confirm.
  //   (B) stride > filter, e.g. stride={2x2} and filter={1x1}:
  //         (B.1) The Dx output tensor should be zero initialized.
  //         (B.2) The kernel epilogue cannot apply beta, so beta should be zero.

  return ctas_per_start_position * start_position_count;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Horizontal threadblock swizzle for strided dgrad convolution.
///
/// Inherits from gemm::threadblock::GemmHorizontalThreadblockSwizzle and provides a
/// Conv2d-specific get_tiled_shape(); the inherited get_tiled_shape() overloads are made
/// private so they cannot be called through this type.
struct StridedDgradHorizontalThreadblockSwizzle :
  public gemm::threadblock::GemmHorizontalThreadblockSwizzle {

  using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle;

  CUTLASS_HOST_DEVICE
  StridedDgradHorizontalThreadblockSwizzle() { }

  /// Computes the problem shape in units of logical tiles for an ImplicitGemmConvolution
  /// Conv2d problem: conv_operator(NPQK, NHWC, KRSC).
  ///
  /// \param conv_operator   Convolution operator kind, forwarded to implicit_gemm_problem_size().
  /// \param problem_size    Conv2d problem description.
  /// \param tile_size       Threadblock tile shape; m() and n() are used.
  /// \param split_k_slices  Passed through unchanged as the K component of the result.
  /// \return GemmCoord(tiles along M, tiles along N, split_k_slices).
  CUTLASS_HOST_DEVICE
  static gemm::GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size,
    gemm::GemmCoord tile_size,
    int split_k_slices) {

    // GEMM-equivalent extent of the convolution; only its N component is consumed below.
    gemm::GemmCoord gemm_extent(
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size));

    // M: strided dgrad has a dedicated tile-count computation.
    int const tiles_along_m = get_strided_dgrad_tile_m(problem_size, tile_size.m());

    // N: round the N extent up to a whole number of tiles.
    int const padded_extent_n = gemm_extent.n() + tile_size.n() - 1;
    int const tiles_along_n = padded_extent_n / tile_size.n();

    return gemm::GemmCoord(tiles_along_m, tiles_along_n, split_k_slices);
  }

  /// The base-class get_tiled_shape() (GEMM problem size, MxNxK) must not be used with
  /// this swizzle, so it is re-declared with private access.
  private:
    using Base::get_tiled_shape;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Identity threadblock swizzle for strided dgrad convolution.
///
/// Inherits from gemm::threadblock::GemmIdentityThreadblockSwizzle<N> and provides a
/// Conv2d-specific get_tiled_shape(); the inherited get_tiled_shape() overloads are made
/// private so they cannot be called through this type.
///
/// \tparam N  Forwarded as the template argument of the base swizzle.
template <int N = 1>
struct StridedDgradIdentityThreadblockSwizzle :
  public gemm::threadblock::GemmIdentityThreadblockSwizzle<N> {

  using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle<N>;

  CUTLASS_HOST_DEVICE
  StridedDgradIdentityThreadblockSwizzle() { }

  /// Computes the problem shape in units of logical tiles for an ImplicitGemmConvolution
  /// Conv2d problem: conv_operator(NPQK, NHWC, KRSC).
  ///
  /// \param conv_operator   Convolution operator kind, forwarded to implicit_gemm_problem_size().
  /// \param problem_size    Conv2d problem description.
  /// \param tile_size       Threadblock tile shape; m() and n() are used.
  /// \param split_k_slices  Passed through unchanged as the K component of the result.
  /// \return GemmCoord(tiles along M, tiles along N, split_k_slices).
  CUTLASS_HOST_DEVICE
  static gemm::GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size,
    gemm::GemmCoord tile_size,
    int split_k_slices) {

    // GEMM-equivalent extent of the convolution; only its N component is consumed below.
    gemm::GemmCoord gemm_extent(
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size));

    // M: strided dgrad has a dedicated tile-count computation.
    int const tiles_along_m = get_strided_dgrad_tile_m(problem_size, tile_size.m());

    // N: round the N extent up to a whole number of tiles.
    int const padded_extent_n = gemm_extent.n() + tile_size.n() - 1;
    int const tiles_along_n = padded_extent_n / tile_size.n();

    return gemm::GemmCoord(tiles_along_m, tiles_along_n, split_k_slices);
  }

  /// The base-class get_tiled_shape() (GEMM problem size, MxNxK) must not be used with
  /// this swizzle, so it is re-declared with private access.
  private:
    using Base::get_tiled_shape;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Identity threadblock swizzle for depthwise direct 2D convolution.
///
/// Inherits from gemm::threadblock::GemmIdentityThreadblockSwizzle<N> and supplies a
/// Conv2d get_tiled_shape() whose M tile count is always 1.
///
/// \tparam N         Forwarded as the template argument of the base swizzle.
/// \tparam Output_N  Not referenced inside this struct's body.
/// \tparam Output_P  Not referenced inside this struct's body.
/// \tparam Output_Q  Not referenced inside this struct's body.
template <int N = 1, int Output_N = 1, int Output_P = 1, int Output_Q = 1>
struct DepthwiseDirect2dConvIdentityThreadblockSwizzle
    : public gemm::threadblock::GemmIdentityThreadblockSwizzle<N> {
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvIdentityThreadblockSwizzle() {}

  /// Computes the problem shape in units of logical tiles.
  ///
  /// \param conv_operator   Convolution operator kind, forwarded to implicit_gemm_problem_size().
  /// \param problem_size    Conv2d problem description.
  /// \param tile_size       Threadblock tile shape; only n() is used.
  /// \param split_k_slices  Passed through unchanged as the K component of the result.
  /// \return GemmCoord(1, tiles along N, split_k_slices).
  CUTLASS_HOST_DEVICE
  static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator,
                            cutlass::conv::Conv2dProblemSize const &problem_size,
                            gemm::GemmCoord tile_size,
                            int split_k_slices) {

    // GEMM-equivalent extent of the convolution; only its N component is consumed below.
    gemm::GemmCoord gemm_extent(
        cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size));

    // N: round the N extent up to a whole number of tiles.
    int const padded_extent_n = gemm_extent.n() + tile_size.n() - 1;
    int const tiles_along_n = padded_extent_n / tile_size.n();

    // The M dimension is covered by a single tile.
    return gemm::GemmCoord(1, tiles_along_n, split_k_slices);
  }
};

} // namespace threadblock
} // namespace conv
} // namespace cutlass