1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
|
#ifndef CAFFE2_UTILS_GPU_BITONIC_SORT_H_
#define CAFFE2_UTILS_GPU_BITONIC_SORT_H_
#include "caffe2/utils/math.h"
#include "caffe2/utils/GpuDefs.cuh"
namespace caffe2 {
// integerIsPowerOf2(v): true iff the integer value v is a power of two
// (1, 2, 4, 8, ...). Zero yields false; only positive inputs are meant to
// be passed (negative values are not supported).
//
// Note(jiayq): Windows (MSVC) builds reported an error with the function
// template form (see https://github.com/caffe2/caffe2/issues/997), so under
// _MSC_VER the check is provided as a macro instead. Beware that the macro
// evaluates its argument more than once.
#ifdef _MSC_VER
#define integerIsPowerOf2(v) (((v) != 0) && (((v) & ((v) - 1)) == 0))
#else // _MSC_VER
template <typename T>
constexpr bool integerIsPowerOf2(T v) {
  // A power of two has a single set bit, so clearing its lowest set bit
  // (v & (v - 1)) must leave zero.
  return (v != 0) && ((v & (v - 1)) == 0);
}
#endif // _MSC_VER
/// Largest element count (Power2SortSize) accepted by the in-block and
/// in-warp bitonic sorts in this header: 2^12 == 4096.
constexpr int kMaxBitonicSortSize = 1 << 12;
/// Exchanges the contents of two device-side lvalues via a temporary copy.
template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T saved(t1);
  t1 = t2;
  t2 = saved;
}
/// Compare-and-exchange step of a bitonic network on two (key, value) slots.
///
/// `comp(kA, vA, kB, vB)` is evaluated exactly once. When its boolean result
/// equals `dir`, slot A and slot B trade both their keys and their values;
/// otherwise nothing changes. `dir` therefore selects which comparison
/// outcome triggers a swap, i.e. the direction of this step.
template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA,
                                   K& kB, V& vB,
                                   bool dir,
                                   const Comparator& comp) {
  const bool compared = comp(kA, vA, kB, vB);
  if (compared != dir) {
    return;
  }
  swapVars(kA, kB);
  swapVars(vA, vB);
}
/// Block-wide, in-place bitonic sort of `Power2SortSize` (key, value) pairs
/// stored in `keys[0..Power2SortSize)` / `values[0..Power2SortSize)`.
///
/// The final order is determined by `comp` through bitonicSwap.
///
/// Every thread of the block must call this function, because it executes
/// __syncthreads().
///
/// Each compare-exchange is handled by a "virtual thread" `t` in
/// [0, Power2SortSize / 2). A real thread runs `loopPerThread` of them.
/// Within one stage, virtual thread t touches only the pair
/// (pos, pos + stride) with pos = 2t - (t & (stride - 1)). These pairs are
/// disjoint, so one barrier at the end of each stage is sufficient.
///
/// Preconditions (not checked here):
///  - presumably launched with a 1-D block where
///    blockDim.x == ThreadsPerBlock (indexing strides by ThreadsPerBlock);
///    TODO confirm at call sites;
///  - the caller must make the input visible to all threads (e.g. with
///    __syncthreads()) before calling. No barrier precedes the first stage.
///
/// Every stage ends with __syncthreads(). For Power2SortSize >= 2 the sorted
/// data is therefore visible to the whole block on return.
template <typename Comparator, typename K, typename V,
          int Power2SortSize,
          int ThreadsPerBlock>
__device__ inline void bitonicSort(K* keys,
                                   V* values,
                                   const Comparator& comp) {
  static_assert(Power2SortSize <= kMaxBitonicSortSize,
                "sort size <= 4096 only supported");
  // Assume the sort is taking place in shared memory
  // static_assert(Power2SortSize * (sizeof(K) + sizeof(V)) < 32768,
  //               "sort data too large (>32768 bytes)");
  static_assert(integerIsPowerOf2(Power2SortSize),
                "sort size must be power of 2");
  static_assert(integerIsPowerOf2(ThreadsPerBlock),
                "threads in block must be power of 2");

  // One virtual thread per compare-exchange pair.
  constexpr int numThreadsForSort = Power2SortSize / 2;
  // If what we are sorting is too small, then not all threads participate.
  constexpr bool allThreads = numThreadsForSort >= ThreadsPerBlock;
  // If what we are sorting is too large, then each thread loops over several
  // virtual threads.
  constexpr int loopPerThread =
      allThreads ? numThreadsForSort / ThreadsPerBlock : 1;

  // Build phase: sort runs of `size` elements. Neighboring runs are sorted in
  // opposite directions, so each pair of runs forms a bitonic sequence.
#pragma unroll
  for (int size = 2; size < Power2SortSize; size *= 2) {
#pragma unroll
    for (int stride = size / 2; stride > 0; stride /= 2) {
#pragma unroll
      for (int loop = 0; loop < loopPerThread; ++loop) {
        int threadId = loop * ThreadsPerBlock + threadIdx.x;
        // The direction flips for each consecutive block of `size` elements.
        bool flag = ((threadId & (size / 2)) != 0);
        // Lower index of this virtual thread's pair for the current stride.
        int pos = 2 * threadId - (threadId & (stride - 1));

        if (allThreads || (threadId < numThreadsForSort)) {
          bitonicSwap<Comparator, K, V>(
              keys[pos], values[pos],
              keys[pos + stride], values[pos + stride],
              flag, comp);
        }
      }
      // The pairs of one stage are disjoint, so one barrier per stage
      // (instead of one per virtual-thread iteration) suffices before the
      // next stage reads elements written by other threads.
      __syncthreads();
    }
  }

  // Final merge: a single direction (dir = false) over the whole sequence.
#pragma unroll
  for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < loopPerThread; ++loop) {
      int threadId = loop * ThreadsPerBlock + threadIdx.x;
      int pos = 2 * threadId - (threadId & (stride - 1));

      if (allThreads || (threadId < numThreadsForSort)) {
        bitonicSwap<Comparator, K, V>(
            keys[pos], values[pos],
            keys[pos + stride], values[pos + stride],
            false, comp);
      }
    }
    // One barrier per stage (see above).
    __syncthreads();
  }
}
// Inter-stage barrier for warpBitonicSort among the lanes of the calling warp.
__device__ inline void warpBitonicSortSync() {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
  // __syncwarp() waits for every lane of the warp (full mask) and orders
  // their memory accesses. This is required on Volta+ (independent thread
  // scheduling), where lanes are not guaranteed to run in lockstep. There, a
  // fence alone does not ensure that another lane's writes have happened
  // before they are read.
  __syncwarp();
#else
  // Toolchains without __syncwarp: keep the original behavior (implicit warp
  // synchrony plus a block-scope memory fence).
  __threadfence_block();
#endif
}

/// Warp-wide, in-place bitonic sort of `Power2SortSize` (key, value) pairs
/// stored in `keys[0..Power2SortSize)` / `values[0..Power2SortSize)`.
///
/// The final order is determined by `comp` through bitonicSwap.
///
/// Every lane of the calling warp must call this function. The lane ids from
/// getLaneId() together cover all compare-exchange pairs, and the inter-stage
/// barrier uses a full-warp mask.
///
/// Each lane runs `loopPerThread` virtual threads. The pairs touched by one
/// stage are disjoint, so one warp barrier per stage suffices.
///
/// Precondition (not checked here): the caller must make the input visible to
/// all lanes (e.g. with __syncwarp()/__syncthreads()) before calling. The
/// data is presumably in shared memory; confirm at call sites.
template <typename Comparator, typename K, typename V, int Power2SortSize>
__device__ inline void warpBitonicSort(K* keys,
                                       V* values,
                                       const Comparator& comp) {
  // Smaller sorts should use a warp shuffle sort
  static_assert(Power2SortSize > kWarpSize,
                "sort not large enough");
  static_assert(integerIsPowerOf2(Power2SortSize),
                "sort size must be power of 2");
  static_assert(Power2SortSize <= kMaxBitonicSortSize,
                "sort size <= 4096 only supported");

  // If what we are sorting is too large, then lanes must loop more than once.
  constexpr int loopPerThread = (Power2SortSize / 2) / kWarpSize;
  int laneId = getLaneId();

  // Build phase: neighboring `size`-element runs are sorted in opposite
  // directions, forming bitonic sequences for the next merge.
#pragma unroll
  for (int size = 2; size < Power2SortSize; size *= 2) {
#pragma unroll
    for (int stride = size / 2; stride > 0; stride /= 2) {
#pragma unroll
      for (int loop = 0; loop < loopPerThread; ++loop) {
        int threadId = loop * kWarpSize + laneId;
        bool flag = ((threadId & (size / 2)) != 0);
        int pos = 2 * threadId - (threadId & (stride - 1));
        bitonicSwap<Comparator, K, V>(
            keys[pos], values[pos],
            keys[pos + stride], values[pos + stride],
            flag, comp);
      }
      // The next stage reads elements written by other lanes.
      warpBitonicSortSync();
    }
  }

  // Final merge: a single direction (dir = false) over the whole sequence.
#pragma unroll
  for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < loopPerThread; ++loop) {
      int threadId = loop * kWarpSize + laneId;
      int pos = 2 * threadId - (threadId & (stride - 1));
      bitonicSwap<Comparator, K, V>(
          keys[pos], values[pos],
          keys[pos + stride], values[pos + stride],
          false, comp);
    }
    warpBitonicSortSync();
  }
}
} // namespace caffe2
#endif // CAFFE2_UTILS_GPU_BITONIC_SORT_H_
|