File: GpuDefs.cuh

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (158 lines) | stat: -rw-r--r-- 3,833 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#ifndef CAFFE2_UTILS_GPU_DEFS_H_
#define CAFFE2_UTILS_GPU_DEFS_H_

#include <cuda_runtime.h>

namespace caffe2 {

// Compile-time warp width, for use in unrolling and code generation.
// The value is selected by the USE_ROCM build macro: ROCm builds take it from
// HIP's `warpSize`, every other build hard-codes 32.

#if defined(USE_ROCM)
constexpr int kWarpSize = warpSize;   // = 64 (Defined in hip_runtime.h)
#else
constexpr int kWarpSize = 32;
#endif // USE_ROCM

//
// Interfaces to PTX instructions for which there appears to be no
// intrinsic
//

// Bit-field extract / insert helpers, specialized per operand width.
//
//   getBitfield(val, pos, len)            -> the `len` bits of `val` starting
//                                            at bit `pos`, right-aligned.
//   setBitfield(val, toInsert, pos, len)  -> `val` with the `len` bits starting
//                                            at bit `pos` replaced by the low
//                                            `len` bits of `toInsert`.
//
// Non-ROCm builds issue the PTX `bfe` / `bfi` instructions through inline asm.
// ROCm builds use a plain C++ fallback written to follow the same rules as
// the PTX instructions:
//   * only the low 8 bits of `pos` and `len` are considered;
//   * `len` greater than or equal to the operand width selects every bit from
//     `pos` upwards;
//   * bit positions beyond the most significant bit are ignored, so `pos`
//     greater than or equal to the operand width extracts 0 / inserts nothing.
// The fallback never shifts by an amount >= the operand width, which would be
// undefined behaviour in C++.
//
// The primary template is intentionally empty; only the specializations below
// provide the two static functions.
template <typename T>
struct Bitfield {};

// 32-bit specialization.
template <>
struct Bitfield<unsigned int> {
  // Extracts `len` bits of `val` starting at bit `pos`.
  static __device__ __forceinline__
  unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(USE_ROCM)
    pos &= 0xff;
    len &= 0xff;

    // Nothing of `val` lies at or above bit 32.
    if (pos >= 32) {
      return 0u;
    }
    // `1u << 32` is undefined, so a full-width field gets an explicit mask.
    const unsigned int m = (len >= 32) ? ~0u : ((1u << len) - 1u);
    return (val >> pos) & m;
#else
    unsigned int ret;
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
    return ret;
#endif // USE_ROCM
  }

  // Returns `val` with `len` bits at `pos` overwritten by the low bits of
  // `toInsert`.
  static __device__ __forceinline__
  unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(USE_ROCM)
    pos &= 0xff;
    len &= 0xff;

    // The whole field lies beyond the MSB: nothing to insert.
    if (pos >= 32) {
      return val;
    }
    unsigned int m = (len >= 32) ? ~0u : ((1u << len) - 1u);
    toInsert &= m;
    // Bits shifted past the MSB are dropped, i.e. the field is clipped.
    toInsert <<= pos;
    m <<= pos;

    return (val & ~m) | toInsert;
#else
    unsigned int ret;
    asm("bfi.b32 %0, %1, %2, %3, %4;" :
        "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
    return ret;
#endif // USE_ROCM
  }
};

// 64-bit specialization.
template <>
struct Bitfield<unsigned long long int> {
  // Extracts `len` bits of `val` starting at bit `pos`.
  static __device__ __forceinline__
  unsigned long long int getBitfield(unsigned long long int val, int pos, int len) {
#if defined(USE_ROCM)
    pos &= 0xff;
    len &= 0xff;

    // Nothing of `val` lies at or above bit 64.
    if (pos >= 64) {
      return 0ull;
    }
    // The mask must be built in 64 bits (`1ull`), otherwise fields of 32 bits
    // or more are truncated; `1ull << 64` is undefined, hence the special case.
    const unsigned long long int m = (len >= 64) ? ~0ull : ((1ull << len) - 1ull);
    return (val >> pos) & m;
#else
    unsigned long long int ret;
    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
    return ret;
#endif // USE_ROCM
  }

  // Returns `val` with `len` bits at `pos` overwritten by the low bits of
  // `toInsert`.
  static __device__ __forceinline__
  unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
#if defined(USE_ROCM)
    pos &= 0xff;
    len &= 0xff;

    // The whole field lies beyond the MSB: nothing to insert.
    if (pos >= 64) {
      return val;
    }
    unsigned long long int m = (len >= 64) ? ~0ull : ((1ull << len) - 1ull);
    toInsert &= m;
    // Bits shifted past the MSB are dropped, i.e. the field is clipped.
    toInsert <<= pos;
    m <<= pos;

    return (val & ~m) | toInsert;
#else
    unsigned long long int ret;
    asm("bfi.b64 %0, %1, %2, %3, %4;" :
        "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
    return ret;
#endif // USE_ROCM
  }
};

// Returns the calling thread's lane index.
// Non-ROCm builds read the PTX %laneid special register through inline asm;
// ROCm builds forward to __lane_id().
__device__ __forceinline__ int getLaneId() {
#if !defined(USE_ROCM)
  int lane;
  asm("mov.s32 %0, %%laneid;" : "=r"(lane) );
  return lane;
#else
  return __lane_id();
#endif // !USE_ROCM
}

#if defined(USE_ROCM)
// 64-bit mask with one bit set for every lane index strictly below the
// caller's lane index (bits [0, laneId)).
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
  const int lane = getLaneId();
  return (1ull << lane) - 1ull;
}

// 64-bit mask with one bit set for every lane index up to and including the
// caller's lane index (bits [0, laneId]).
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
  // Start from all ones and shift out the bits above the caller's lane.
  const auto totalBits = sizeof(std::uint64_t) * CHAR_BIT;
  const auto keptBits = getLaneId() + 1;
  return UINT64_MAX >> (totalBits - keptBits);
}

// 64-bit mask with one bit set for every lane index strictly above the
// caller's lane index: the complement of getLaneMaskLe().
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
  const unsigned long long int atOrBelow = getLaneMaskLe();
  // An empty "less-or-equal" mask is passed through unchanged rather than
  // complemented.
  if (atOrBelow == 0ull) {
    return atOrBelow;
  }
  return ~atOrBelow;
}

// 64-bit mask with one bit set for every lane index at or above the caller's
// lane index: the complement of getLaneMaskLt().
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
  return ~getLaneMaskLt();
}
#else
// Returns the value of the PTX %lanemask_lt special register, read through
// inline asm.
__device__ __forceinline__ unsigned getLaneMaskLt() {
  unsigned lanesBelow;
  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(lanesBelow));
  return lanesBelow;
}

// Returns the value of the PTX %lanemask_le special register, read through
// inline asm.
__device__ __forceinline__ unsigned getLaneMaskLe() {
  unsigned lanesAtOrBelow;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(lanesAtOrBelow));
  return lanesAtOrBelow;
}

// Returns the value of the PTX %lanemask_gt special register, read through
// inline asm.
__device__ __forceinline__ unsigned getLaneMaskGt() {
  unsigned lanesAbove;
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(lanesAbove));
  return lanesAbove;
}

// Returns the value of the PTX %lanemask_ge special register, read through
// inline asm.
__device__ __forceinline__ unsigned getLaneMaskGe() {
  unsigned lanesAtOrAbove;
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(lanesAtOrAbove));
  return lanesAtOrAbove;
}
#endif // USE_ROCM

}  // namespace caffe2

#endif  // CAFFE2_UTILS_GPU_DEFS_H_