File: conv2d_tile_iterator.h

package info (click to toggle)
nvidia-cutlass 3.4.1%2Bds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (337 lines) | stat: -rw-r--r-- 11,202 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates that wrap the tile access iterator concept to load whole tiles from tensors
      in memory, as used for implicit GEMM convolution.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileAccessIterator_>
class TileIterator {
public:
  /// Access iterator that this wrapper drives, one access at a time, to fill a whole fragment
  using TileAccessIterator = TileAccessIterator_;

  //
  // Types and constants re-exported unchanged from the access iterator
  //

  using Shape = typename TileAccessIterator::Shape;
  using Element = typename TileAccessIterator::Element;
  using Layout = typename TileAccessIterator::Layout;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = typename TileAccessIterator::ThreadMap;
  using AccessType = typename TileAccessIterator::AccessType;
  using TensorRef = typename TileAccessIterator::TensorRef;
  using Index = typename TileAccessIterator::Index;
  using LongIndex = typename TileAccessIterator::LongIndex;
  using Params = typename TileAccessIterator::Params;
  using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;

  static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm;
  static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport;
  static int const kConvDim = TileAccessIterator::kConvDim;
  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;

  /// Per-thread fragment filled by load(): ThreadMap::Iterations::kCount accesses of
  /// ThreadMap::kElementsPerAccess elements each
  using Fragment = cutlass::Array<
    Element,
    ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

private:

  /// The wrapped access iterator; it holds all iteration state
  TileAccessIterator tile_access_iterator_;

public:

  /// Constructs the wrapped access iterator, forwarding every argument as given
  CUTLASS_HOST_DEVICE
  TileIterator(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    tile_access_iterator_(
      params,
      problem_size,
      ptr,
      thread_idx,
      threadblock_offset) { }

  /// Builds the Params object by delegating to the access iterator
  CUTLASS_HOST_DEVICE
  static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
    return TileAccessIterator::getParams(problem_size, layout);
  }

  /// Sets the access iterator's iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    tile_access_iterator_.set_iteration_index(index);
  }

  /// Forwards a pointer offset (in units of Element) to the access iterator
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    tile_access_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Pre-increment: moves the access iterator to the next tile
  CUTLASS_HOST_DEVICE
  TileIterator &operator++() {
    tile_access_iterator_.advance();
    return *this;
  }

  /// Post-increment: moves to the next tile and returns a copy of the prior state
  CUTLASS_HOST_DEVICE
  TileIterator operator++(int) {
    TileIterator previous(*this);
    tile_access_iterator_.advance();
    return previous;
  }

  /// Fills `frag` by issuing one global_load per access.
  ///
  /// The fragment is cleared first. Accesses are visited strided-major, then contiguous,
  /// then vector, and the access iterator is incremented after every load.
  ///
  /// NOTE(review): `pointer_offset` is added directly to the pointer returned by get(), so
  /// its unit follows that pointer's type -- confirm against the access iterator in use.
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {

    frag.clear();

    // View the fragment as an array of whole accesses
    AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int strided = 0; strided < ThreadMap::Iterations::kStrided; ++strided) {
      CUTLASS_PRAGMA_UNROLL
      for (int contiguous = 0; contiguous < ThreadMap::Iterations::kContiguous; ++contiguous) {
        CUTLASS_PRAGMA_UNROLL
        for (int vec = 0; vec < kAccessesPerVector; ++vec) {

          // Linearized (strided, contiguous, vec) position inside the fragment
          int access_idx =
            (strided * ThreadMap::Iterations::kContiguous + contiguous) * kAccessesPerVector + vec;

          cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
            access_ptr[access_idx],
            tile_access_iterator_.get() + pointer_offset,
            tile_access_iterator_.valid());

          ++tile_access_iterator_;
        }
      }
    }
  }

  /// Resets the iteration index to zero, then loads a full fragment with no pointer offset
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    set_iteration_index(0);
    load_with_pointer_offset(frag, 0);
  }

  /// Moves the access iterator to the next tile
  CUTLASS_DEVICE
  void advance() {
    tile_access_iterator_.advance();
  }

  /// Reports whether the Implicit GEMM can execute the given problem; the decision is
  /// delegated entirely to the access iterator.
  CUTLASS_HOST_DEVICE
  static Status can_implement(ConvProblemSize const &problem_size) {
    return TileAccessIterator::can_implement(problem_size);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// Strided Dgrad Tile Iterator
template <typename TileAccessIterator_>
class TileIteratorStridedDgrad {
public:
  /// Access iterator that this wrapper drives, one access at a time, to fill a whole fragment
  using TileAccessIterator = TileAccessIterator_;

  // Types and constants re-exported unchanged from the access iterator
  using Shape = typename TileAccessIterator::Shape;
  using Element = typename TileAccessIterator::Element;
  using Layout = typename TileAccessIterator::Layout;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = typename TileAccessIterator::ThreadMap;
  using AccessType = typename TileAccessIterator::AccessType;
  using TensorRef = typename TileAccessIterator::TensorRef;
  using Index = typename TileAccessIterator::Index;
  using LongIndex = typename TileAccessIterator::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm;
  static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport;
  using Params = typename TileAccessIterator::Params;
  static int const kConvDim = TileAccessIterator::kConvDim;
  using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;

  /// Fragment object to be loaded or stored: ThreadMap::Iterations::kCount accesses of
  /// ThreadMap::kElementsPerAccess elements each
  using Fragment = cutlass::Array<
    Element, 
    ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

private:

  /// Internal state: the wrapped access iterator
  TileAccessIterator tile_access_iterator_;

public:

  /// Constructor (output gradient (Dy) OperandA ctor)
  ///
  /// All arguments, including the stride divmod objects and the starting filter
  /// position (start_r, start_s), are forwarded unchanged to the access iterator.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    tile_access_iterator_(
      params, 
      problem_size, 
      ptr, 
      thread_idx, 
      stride_h_divmod, stride_w_divmod, 
      start_r, start_s, 
      threadblock_offset) { }

  /// Constructor (filter (w) OperandB ctor)
  ///
  /// Same as the OperandA constructor but without the stride divmod objects.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    tile_access_iterator_(params, 
      problem_size, 
      ptr, 
      thread_idx, 
      start_r, start_s, 
      threadblock_offset) { }

  /// Builds the Params object by delegating to the access iterator
  CUTLASS_HOST_DEVICE
  static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
    return TileAccessIterator::getParams(problem_size, layout);
  }

  /// Overrides the internal iteration index
  ///
  /// Mirrors TileIterator::set_iteration_index so both tile iterators expose the same
  /// interface; callers of load_with_pointer_offset() can use it to reset the index.
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    tile_access_iterator_.set_iteration_index(index);
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    tile_access_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad &operator++() {
    tile_access_iterator_.advance();
    return *this;
  }

  /// Advances to the next tile in memory and returns a copy of the prior state.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad operator++(int) {
    TileIteratorStridedDgrad self(*this);
    operator++();
    return self;
  }

  /// Loads a fragment from memory
  ///
  /// The fragment is cleared first. One global_load is issued per (strided, contiguous)
  /// access, and the access iterator is incremented after every load.
  ///
  /// NOTE(review): `pointer_offset` is added directly to the pointer returned by get(), so
  /// its unit follows that pointer's type -- confirm against the access iterator in use.
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {

    frag.clear();

    // View the fragment as an array of whole accesses
    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

        cutlass::arch::global_load<
          AccessType,
          sizeof(AccessType)
        >(
          frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
          tile_access_iterator_.get() + pointer_offset,
          tile_access_iterator_.valid()
        );

        ++tile_access_iterator_;
      }
    }
  }

  /// Loads a fragment from memory, starting from iteration index zero with no pointer offset
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    set_iteration_index(0);
    load_with_pointer_offset(frag, 0);
  }

  /// Moves the access iterator to the next tile
  CUTLASS_DEVICE
  void advance() {
    tile_access_iterator_.advance();
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(ConvProblemSize const &problem_size) {

    // dispatch to iterator implementation
    return TileAccessIterator::can_implement(problem_size);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////