File: triton-tut-softmax-kernel.ll

package info (click to toggle)
llvm-toolchain-21 1%3A21.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,245,064 kB
  • sloc: cpp: 7,619,731; ansic: 1,434,018; asm: 1,058,748; python: 252,740; f90: 94,671; objc: 70,685; lisp: 42,813; pascal: 18,401; sh: 8,601; ml: 5,111; perl: 4,720; makefile: 3,676; awk: 3,523; javascript: 2,409; xml: 892; fortran: 770
file content (221 lines) | stat: -rw-r--r-- 9,069 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
; This is an excerpt from the tutorial of the Triton language converted into
; LLVM IR via the Triton XPU backend and cleaned of irrelevant details.
; The only pass criterion is that spirv-val considers output valid.

; Ths particular case is related to translation of <1 x Ty> vectors.

; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}

define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) {
  %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0)
  %9 = trunc i64 %8 to i32
  %10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0)
  %11 = trunc i64 %10 to i32
  %12 = tail call spir_func i64 @_Z12get_local_idj(i32 0)
  %13 = trunc i64 %12 to i32
  %14 = and i32 %13, 255
  %15 = or disjoint i32 %14, 256
  %16 = or disjoint i32 %14, 512
  %17 = or disjoint i32 %14, 768
  %18 = icmp slt i32 %14, %5
  %19 = icmp slt i32 %15, %5
  %20 = icmp slt i32 %16, %5
  %21 = icmp slt i32 %17, %5
  %22 = icmp sgt i32 %4, %9
  br i1 %22, label %.lr.ph, label %._crit_edge

.lr.ph:                                           ; preds = %7
  %23 = lshr i64 %12, 5
  %24 = and i32 %13, 31
  %25 = zext nneg i32 %15 to i64
  %26 = zext nneg i32 %16 to i64
  %27 = zext nneg i32 %17 to i64
  %28 = and i64 %12, 255
  %29 = and i64 %23, 7
  %30 = icmp eq i32 %24, 0
  %31 = getelementptr float, ptr addrspace(3) %6, i64 %29
  %32 = icmp slt i32 %13, 8
  %sext = shl i64 %12, 32
  %33 = ashr exact i64 %sext, 30
  %34 = getelementptr i8, ptr addrspace(3) %6, i64 %33
  %35 = and i32 %13, 7
  %36 = icmp eq i32 %35, 0
  %37 = and i1 %32, %36
  br label %38

38:                                               ; preds = %.lr.ph, %123
  %39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ]
  %40 = mul i32 %39, %2
  %41 = sext i32 %40 to i64
  %42 = getelementptr float, ptr addrspace(1) %1, i64 %41
  %43 = getelementptr float, ptr addrspace(1) %42, i64 %25
  %44 = getelementptr float, ptr addrspace(1) %42, i64 %26
  %45 = getelementptr float, ptr addrspace(1) %42, i64 %27
  br i1 %18, label %46, label %49

46:                                               ; preds = %38
  %47 = getelementptr float, ptr addrspace(1) %42, i64 %28
  %48 = load <1 x float>, ptr addrspace(1) %47, align 4
  br label %49

49:                                               ; preds = %46, %38
  %50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ]
  %51 = extractelement <1 x float> %50, i64 0
  br i1 %19, label %52, label %54

52:                                               ; preds = %49
  %53 = load <1 x float>, ptr addrspace(1) %43, align 4
  br label %54

54:                                               ; preds = %52, %49
  %55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ]
  %56 = extractelement <1 x float> %55, i64 0
  br i1 %20, label %57, label %59

57:                                               ; preds = %54
  %58 = load <1 x float>, ptr addrspace(1) %44, align 4
  br label %59

59:                                               ; preds = %57, %54
  %60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ]
  %61 = extractelement <1 x float> %60, i64 0
  br i1 %21, label %62, label %64

62:                                               ; preds = %59
  %63 = load <1 x float>, ptr addrspace(1) %45, align 4
  br label %64

64:                                               ; preds = %62, %59
  %65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ]
  %66 = extractelement <1 x float> %65, i64 0
  tail call spir_func void @_Z7barrierj(i32 1)
  %67 = tail call float @llvm.maxnum.f32(float %51, float %56)
  %68 = tail call float @llvm.maxnum.f32(float %67, float %61)
  %69 = tail call float @llvm.maxnum.f32(float %68, float %66)
  %70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69)
  br i1 %30, label %71, label %72

71:                                               ; preds = %64
  store float %70, ptr addrspace(3) %31, align 4
  br label %72

72:                                               ; preds = %71, %64
  tail call spir_func void @_Z7barrierj(i32 1)
  br i1 %32, label %74, label %.thread1

.thread1:                                         ; preds = %72
  %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float poison, i32 8)
  br label %78

74:                                               ; preds = %72
  %75 = load float, ptr addrspace(3) %34, align 4
  %76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8)
  br i1 %37, label %77, label %78

77:                                               ; preds = %74
  store float %76, ptr addrspace(3) %34, align 4
  br label %78

78:                                               ; preds = %.thread1, %77, %74
  tail call spir_func void @_Z7barrierj(i32 1)
  %79 = load float, ptr addrspace(3) %6, align 4
  %80 = fsub float %51, %79
  %81 = fsub float %56, %79
  %82 = fsub float %61, %79
  %83 = fsub float %66, %79
  %84 = fmul float %80, 0x3FF7154760000000
  %85 = tail call float @llvm.exp2.f32(float %84)
  %86 = fmul float %81, 0x3FF7154760000000
  %87 = tail call float @llvm.exp2.f32(float %86)
  %88 = fmul float %82, 0x3FF7154760000000
  %89 = tail call float @llvm.exp2.f32(float %88)
  %90 = fmul float %83, 0x3FF7154760000000
  %91 = tail call float @llvm.exp2.f32(float %90)
  tail call spir_func void @_Z7barrierj(i32 1)
  %92 = fadd float %85, %87
  %93 = fadd float %89, %92
  %94 = fadd float %91, %93
  %95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94)
  br i1 %30, label %96, label %97

96:                                               ; preds = %78
  store float %95, ptr addrspace(3) %31, align 4
  br label %97

97:                                               ; preds = %96, %78
  tail call spir_func void @_Z7barrierj(i32 1)
  br i1 %32, label %99, label %.thread

.thread:                                          ; preds = %97
  %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float poison, i32 8)
  br label %103

99:                                               ; preds = %97
  %100 = load float, ptr addrspace(3) %34, align 4
  %101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8)
  br i1 %37, label %102, label %103

102:                                              ; preds = %99
  store float %101, ptr addrspace(3) %34, align 4
  br label %103

103:                                              ; preds = %.thread, %102, %99
  tail call spir_func void @_Z7barrierj(i32 1)
  %104 = load float, ptr addrspace(3) %6, align 4
  %105 = fdiv float %87, %104
  %106 = fdiv float %89, %104
  %107 = fdiv float %91, %104
  %108 = mul i32 %39, %3
  %109 = sext i32 %108 to i64
  %110 = getelementptr float, ptr addrspace(1) %0, i64 %109
  %111 = getelementptr float, ptr addrspace(1) %110, i64 %25
  %112 = getelementptr float, ptr addrspace(1) %110, i64 %26
  %113 = getelementptr float, ptr addrspace(1) %110, i64 %27
  br i1 %18, label %114, label %117

114:                                              ; preds = %103
  %115 = fdiv float %85, %104
  %116 = getelementptr float, ptr addrspace(1) %110, i64 %28
  store float %115, ptr addrspace(1) %116, align 4
  br label %117

117:                                              ; preds = %114, %103
  br i1 %19, label %118, label %119

118:                                              ; preds = %117
  store float %105, ptr addrspace(1) %111, align 4
  br label %119

119:                                              ; preds = %118, %117
  br i1 %20, label %120, label %121

120:                                              ; preds = %119
  store float %106, ptr addrspace(1) %112, align 4
  br label %121

121:                                              ; preds = %120, %119
  br i1 %21, label %122, label %123

122:                                              ; preds = %121
  store float %107, ptr addrspace(1) %113, align 4
  br label %123

123:                                              ; preds = %122, %121
  %124 = add i32 %39, %11
  %125 = icmp slt i32 %124, %4
  br i1 %125, label %38, label %._crit_edge

._crit_edge:                                      ; preds = %123, %7
  ret void
}

declare float @llvm.maxnum.f32(float, float)
declare spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, float, i32)
declare spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32, i32, float)
declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32, i32, float, i32)
declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, float)
declare spir_func void @_Z7barrierj(i32)
declare spir_func i64 @_Z12get_local_idj(i32)
declare spir_func i64 @_Z14get_num_groupsj(i32)
declare spir_func i64 @_Z12get_group_idj(i32)
declare float @llvm.exp2.f32(float)