File: array_subbyte.h

package info (click to toggle)
nvidia-cutlass 3.4.1+ds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (573 lines) | stat: -rw-r--r-- 13,552 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
           and is safe to use in a union.
*/
/*
  Note:  CUTLASS 3x increases the host compiler requirements to C++17. However, certain
         existing integrations of CUTLASS require C++11 host compilers.

         Until this requirement can be lifted, certain headers with this annotation are required
         to remain consistent with C++11 syntax.

         C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Statically sized array for any data type
template <
  typename T,
  int N
>
class Array<T, N, false> {
public:

  static int const kSizeBits = sizeof_bits<T>::value * N;

  /// Storage type
  using Storage = typename platform::conditional<
    ((kSizeBits % 32) != 0),
    typename platform::conditional<
      ((kSizeBits % 16) != 0),
      uint8_t,
      uint16_t
    >::type,
    uint32_t
  >::type;

  /// Element type
  using Element = T;

  /// Number of logical elements per stored object
  static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;

  /// Number of storage elements
  static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;

  /// Number of logical elements
  static size_t const kElements = N;

  /// Bitmask for covering one item
  static Storage const kMask = ((Storage(1) << sizeof_bits<T>::value) - 1);

  //
  // C++ standard members with pointer types removed
  //

  typedef T value_type;
  typedef size_t size_type;
  typedef ptrdiff_t difference_type;
  typedef value_type *pointer;
  typedef value_type const *const_pointer;

  //
  // References
  //

  /// Reference object inserts or extracts sub-byte items
  class reference {
    /// Pointer to storage element
    Storage *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    /// Default ctor
    CUTLASS_HOST_DEVICE
    reference(): ptr_(nullptr), idx_(0) { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }

    /// Assignment
    CUTLASS_HOST_DEVICE
    reference &operator=(T x) {
      Storage item = (reinterpret_cast<Storage const &>(x) & kMask);

      Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits<T>::value)));
      *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits<T>::value)));

      return *this;
    }

    CUTLASS_HOST_DEVICE
    T get() const {
      Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask);
      return reinterpret_cast<T const &>(item);
    }

    /// Extract
    CUTLASS_HOST_DEVICE
    operator T() const {
      return get();
    }

    /// Explicit cast to int
    CUTLASS_HOST_DEVICE
    explicit operator int() const {
      return int(get());
    }

    /// Explicit cast to float
    CUTLASS_HOST_DEVICE
    explicit operator float() const {
      return float(get());
    }
  };

  /// Reference object extracts sub-byte items
  class const_reference {

    /// Pointer to storage element
    Storage const *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    /// Default ctor
    CUTLASS_HOST_DEVICE
    const_reference(): ptr_(nullptr), idx_(0) { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }

    CUTLASS_HOST_DEVICE
    const T get() const {
      Storage item = (*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask;
      return reinterpret_cast<T const &>(item);
    }

    /// Extract
    CUTLASS_HOST_DEVICE
    operator T() const {
      Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits<T>::value)) & kMask);
      return reinterpret_cast<T const &>(item);
    }

    /// Explicit cast to int
    CUTLASS_HOST_DEVICE
    explicit operator int() const {
      return int(get());
    }

    /// Explicit cast to float
    CUTLASS_HOST_DEVICE
    explicit operator float() const {
      return float(get());
    }
  };

  //
  // Iterators
  //

  /// Bidirectional iterator over elements
  class iterator {

    /// Pointer to storage element
    Storage *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    CUTLASS_HOST_DEVICE
    iterator(): ptr_(nullptr), idx_(0) { }

    CUTLASS_HOST_DEVICE
    iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }

    CUTLASS_HOST_DEVICE
    iterator &operator++() {
      ++idx_;
      if (idx_ == kElementsPerStoredItem) {
        ++ptr_;
        idx_ = 0;
      }
      return *this;
    }

    CUTLASS_HOST_DEVICE
    iterator &operator--() {
      if (!idx_) {
        --ptr_;
        idx_ = kElementsPerStoredItem - 1;
      }
      else {
        --idx_;
      }
      return *this;
    }

    CUTLASS_HOST_DEVICE
    iterator operator++(int) {
      iterator ret(*this);
      ++idx_;
      if (idx_ == kElementsPerStoredItem) {
        ++ptr_;
        idx_ = 0;
      }
      return ret;
    }

    CUTLASS_HOST_DEVICE
    iterator operator--(int) {
      iterator ret(*this);
      if (!idx_) {
        --ptr_;
        idx_ = kElementsPerStoredItem - 1;
      }
      else {
        --idx_;
      }
      return ret;
    }

    CUTLASS_HOST_DEVICE
    reference operator*() const {
      return reference(ptr_, idx_);
    }

    CUTLASS_HOST_DEVICE
    bool operator==(iterator const &other) const {
      return ptr_ == other.ptr_ && idx_ == other.idx_;
    }

    CUTLASS_HOST_DEVICE
    bool operator!=(iterator const &other) const {
      return !(*this == other);
    }
  };

  /// Bidirectional constant iterator over elements
  class const_iterator {

    /// Pointer to storage element
    Storage const *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    CUTLASS_HOST_DEVICE
    const_iterator(): ptr_(nullptr), idx_(0) { }

    CUTLASS_HOST_DEVICE
    const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }

    CUTLASS_HOST_DEVICE
    iterator &operator++() {
      ++idx_;
      if (idx_ == kElementsPerStoredItem) {
        ++ptr_;
        idx_ = 0;
      }
      return *this;
    }

    CUTLASS_HOST_DEVICE
    iterator &operator--() {
      if (!idx_) {
        --ptr_;
        idx_ = kElementsPerStoredItem - 1;
      }
      else {
        --idx_;
      }
      return *this;
    }

    CUTLASS_HOST_DEVICE
    iterator operator++(int) {
      iterator ret(*this);
      ++idx_;
      if (idx_ == kElementsPerStoredItem) {
        ++ptr_;
        idx_ = 0;
      }
      return ret;
    }

    CUTLASS_HOST_DEVICE
    iterator operator--(int) {
      iterator ret(*this);
      if (!idx_) {
        --ptr_;
        idx_ = kElementsPerStoredItem - 1;
      }
      else {
        --idx_;
      }
      return ret;
    }

    CUTLASS_HOST_DEVICE
    const_reference operator*() const {
      return const_reference(ptr_, idx_);
    }

    CUTLASS_HOST_DEVICE
    bool operator==(iterator const &other) const {
      return ptr_ == other.ptr_ && idx_ == other.idx_;
    }

    CUTLASS_HOST_DEVICE
    bool operator!=(iterator const &other) const {
      return !(*this == other);
    }
  };

  /// Bidirectional iterator over elements
  class reverse_iterator {

    /// Pointer to storage element
    Storage *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    CUTLASS_HOST_DEVICE
    reverse_iterator(): ptr_(nullptr), idx_(0) { }

    CUTLASS_HOST_DEVICE
    reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
  };

  /// Bidirectional constant iterator over elements
  class const_reverse_iterator {

    /// Pointer to storage element
    Storage const *ptr_;

    /// Index into elements packed into Storage object
    int idx_;

  public:

    CUTLASS_HOST_DEVICE
    const_reverse_iterator(): ptr_(nullptr), idx_(0) { }

    CUTLASS_HOST_DEVICE
    const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
  };

private:

  /// Internal storage
  Storage storage[kStorageElements];

public:

  #if 0
  CUTLASS_HOST_DEVICE
  Array() { }

  CUTLASS_HOST_DEVICE
  Array(Array const &x) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < int(kStorageElements); ++i) {
      storage[i] = x.storage[i];
    }
  }
  #endif

  /// Efficient clear method
  CUTLASS_HOST_DEVICE
  void clear() {

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < int(kStorageElements); ++i) {
      storage[i] = Storage(0);
    }
  }

  CUTLASS_HOST_DEVICE
  reference at(size_type pos) {
    return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
  }

  CUTLASS_HOST_DEVICE
  const_reference at(size_type pos) const {
    return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
  }

  CUTLASS_HOST_DEVICE
  reference operator[](size_type pos) {
    return at(pos);
  }

  CUTLASS_HOST_DEVICE
  const_reference operator[](size_type pos) const {
    return at(pos);
  }

  CUTLASS_HOST_DEVICE
  reference front() {
    return at(0);
  }

  CUTLASS_HOST_DEVICE
  const_reference front() const {
    return at(0);
  }

  CUTLASS_HOST_DEVICE
  reference back() {
    return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
  }

  CUTLASS_HOST_DEVICE
  const_reference back() const {
    return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
  }

  CUTLASS_HOST_DEVICE
  pointer data() {
    return reinterpret_cast<pointer>(storage);
  }

  CUTLASS_HOST_DEVICE
  const_pointer data() const {
    return reinterpret_cast<const_pointer>(storage);
  }
  
  CUTLASS_HOST_DEVICE
  Storage * raw_data() {
    return storage;
  }

  CUTLASS_HOST_DEVICE
  Storage const * raw_data() const {
    return storage;
  }


  CUTLASS_HOST_DEVICE
  constexpr bool empty() const {
    return !kElements;
  }

  CUTLASS_HOST_DEVICE
  constexpr size_type size() const {
    return kElements;
  }

  CUTLASS_HOST_DEVICE
  constexpr size_type max_size() const {
    return kElements;
  }

  CUTLASS_HOST_DEVICE
  void fill(T const &value) {

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerStoredItem; ++i) {
      reference ref(storage, i);
      ref = value;
    }

    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < kStorageElements; ++i) {
      storage[i] = storage[0];
    }
  }

  CUTLASS_HOST_DEVICE
  iterator begin() {
    return iterator(storage);
  }

  CUTLASS_HOST_DEVICE
  const_iterator cbegin() const {
    return const_iterator(storage);
  }

  CUTLASS_HOST_DEVICE
  iterator end() {
    return iterator(storage + kStorageElements);
  }

  CUTLASS_HOST_DEVICE
  const_iterator cend() const {
    return const_iterator(storage + kStorageElements);
  }

  CUTLASS_HOST_DEVICE
  reverse_iterator rbegin() {
    return reverse_iterator(storage + kStorageElements);
  }

  CUTLASS_HOST_DEVICE
  const_reverse_iterator crbegin() const {
    return const_reverse_iterator(storage + kStorageElements);
  }

  CUTLASS_HOST_DEVICE
  reverse_iterator rend() {
    return reverse_iterator(storage);
  }

  CUTLASS_HOST_DEVICE
  const_reverse_iterator crend() const {
    return const_reverse_iterator(storage);
  }

  //
  // Comparison operators
  //

};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////////////////////////