1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
|
namespace index_utils {

// Helpers for sizing and linearizing three-component (x, y, z) values.
// Every helper is templated on the argument type and only accesses its
// .x, .y and .z members.

// Returns the product of the x, y and z components of d. Each component is
// cast to nvfuser_index_t before the multiplication.
template <typename _dim3>
__device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
  const nvfuser_index_t extent_x = static_cast<nvfuser_index_t>(d.x);
  const nvfuser_index_t extent_y = static_cast<nvfuser_index_t>(d.y);
  const nvfuser_index_t extent_z = static_cast<nvfuser_index_t>(d.z);
  return extent_x * extent_y * extent_z;
}

// Linearizes idx against dim with z as the slowest-varying component and x
// as the fastest-varying one. A dimension whose template flag is false is
// skipped entirely: neither its index nor its extent enters the result.
// dim.z is never read.
template <bool X, bool Y, bool Z, typename _dim3, typename _dim3_2>
__device__ nvfuser_index_t maskedOffset(const _dim3& idx, const _dim3_2& dim) {
  nvfuser_index_t linear = 0;
  if (Z) {
    linear += idx.z;
  }
  if (Y) {
    linear = linear * dim.y + idx.y;
  }
  if (X) {
    linear = linear * dim.x + idx.x;
  }
  return linear;
}

// Linearizes idx against dim with every dimension participating:
// (idx.z * dim.y + idx.y) * dim.x + idx.x, accumulated step by step in
// nvfuser_index_t. dim.z is never read.
template <typename _dim3, typename _dim3_2>
__device__ nvfuser_index_t offset(const _dim3& idx, const _dim3_2& dim) {
  // Start from the slowest-varying component and fold in y, then x.
  nvfuser_index_t linear = idx.z;
  linear = linear * dim.y + idx.y;
  linear = linear * dim.x + idx.x;
  return linear;
}

// Builds a dim3 from dim in which every component whose flag is false is
// replaced by 1. Components whose flag is true are cast to unsigned.
template <bool X, bool Y, bool Z, typename _dim3>
__device__ dim3 maskedDims(const _dim3& dim) {
  const unsigned masked_x = X ? static_cast<unsigned>(dim.x) : 1U;
  const unsigned masked_y = Y ? static_cast<unsigned>(dim.y) : 1U;
  const unsigned masked_z = Z ? static_cast<unsigned>(dim.z) : 1U;
  return dim3{masked_x, masked_y, masked_z};
}

// Returns the product of the components of dim whose flags are true;
// components whose flags are false count as 1. The value is obtained by
// masking dim into a dim3 first and then taking size() of that.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
__device__ nvfuser_index_t maskedSize(const _dim3& dim) {
  const dim3 masked = maskedDims<X_BLOCK, Y_BLOCK, Z_BLOCK>(dim);
  return size(masked);
}

// Returns true when every component of idx whose flag is true equals 0.
// Components whose flag is false are ignored, so the result is true when
// all three flags are false.
template <bool X, bool Y, bool Z, typename _dim3>
__device__ bool maskedIsZero(const _dim3& idx) {
  return (!X || idx.x == 0) && (!Y || idx.y == 0) && (!Z || idx.z == 0);
}

// Returns true when every component of idx whose flag is true equals the
// matching component of dim minus one, i.e. idx sits at the last position
// along each participating dimension. Components whose flag is false are
// ignored, so the result is true when all three flags are false.
template <bool X, bool Y, bool Z, typename _dim3, typename _dim3_2>
__device__ bool maskedIsLast(const _dim3& idx, const _dim3_2& dim) {
  return (!X || idx.x == dim.x - 1) && (!Y || idx.y == dim.y - 1) &&
      (!Z || idx.z == dim.z - 1);
}

} // namespace index_utils
|