File: IBiF_matrix.cl

package info (click to toggle)
intel-graphics-compiler 1.0.12504.6-1%2Bdeb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 83,912 kB
  • sloc: cpp: 910,147; lisp: 202,655; ansic: 15,197; python: 4,025; yacc: 2,241; lex: 1,570; pascal: 244; sh: 104; makefile: 25
file content (260 lines) | stat: -rw-r--r-- 11,095 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
/*========================== begin_copyright_notice ============================

Copyright (C) 2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

/*
 * ARR_TO_VECn(type, arr): build a value of OpenCL vector type <type>n from
 * the first n elements of the array `arr` (n == 1 yields the scalar arr[0]).
 *
 * The `arr` argument is now actually used; previously the expansion always
 * referenced a variable literally named `wi_contrib`, regardless of what the
 * caller passed.
 */
#define ARR_TO_VEC8(type, arr) \
    (type##8)((arr)[0], (arr)[1], (arr)[2], (arr)[3], \
              (arr)[4], (arr)[5], (arr)[6], (arr)[7])

#define ARR_TO_VEC4(type, arr) \
    (type##4)((arr)[0], (arr)[1], (arr)[2], (arr)[3])

#define ARR_TO_VEC2(type, arr) \
    (type##2)((arr)[0], (arr)[1])

#define ARR_TO_VEC1(type, arr) \
    ((arr)[0])

/*
 * Statement macro: gathers M values of type contrib_type from `mem` into a
 * new local array `wi_contrib[M]`, one value per row.
 *
 * - `mem` is viewed as an array of contrib_type.
 * - `stride` must be a modifiable int; it is divided in place by
 *   sizeof(contrib_type) / sizeof(element_type).
 * - Row r is read from index (sub-group local id + r * stride).
 *
 * The macro declares `ptr`, `slid`, `pack_factor` and `wi_contrib` in the
 * caller's scope.
 */
#define LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, element_type, contrib_type, M) \
    int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
    int slid = get_sub_group_local_id(); \
    contrib_type *ptr = (contrib_type *)mem; \
    contrib_type wi_contrib[M]; \
    stride /= pack_factor; \
    for (int r = 0; r < M; r++) { \
        wi_contrib[r] = ptr[slid + r * stride]; \
    }

/* PackedA load i16 */
/*
 * PackedA 8x16 i16 load: reads one int per row for 8 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (short elements packed into int) and returns
 * the eight values as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_8x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedA 4x16 i16 load: reads one int per row for 4 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (short elements packed into int) and returns
 * the four values as an int4.
 */
INLINE int4 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_4x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 4)
    const int4 packed_rows = ARR_TO_VEC4(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedA 2x16 i16 load: reads one int per row for 2 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (short elements packed into int) and returns
 * the two values as an int2.
 */
INLINE int2 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_2x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 2)
    const int2 packed_rows = ARR_TO_VEC2(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedA 1x16 i16 load: reads a single int (short elements packed into int)
 * for the one row of the matrix and returns it.
 *
 * The row count passed to the load macro is 1 to match the matrix shape.
 * It used to be 2, which also dereferenced a second row (at +stride) whose
 * value was never used and which may lie outside the 1-row matrix.
 */
INLINE int __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_1x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 1)
    return ARR_TO_VEC1(int, wi_contrib);
}

/* PackedA load i8 */
/*
 * PackedA 8x32 i8 load: reads one int per row for 8 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (char elements packed into int) and returns
 * the eight values as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_8x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedA 4x32 i8 load: reads one int per row for 4 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (char elements packed into int) and returns
 * the four values as an int4.
 */
INLINE int4 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_4x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 4)
    const int4 packed_rows = ARR_TO_VEC4(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedA 2x32 i8 load: reads one int per row (char elements packed into
 * int) for the 2 rows of the matrix and returns them as an int2.
 *
 * The row count passed to the load macro is 2 to match the matrix shape.
 * It used to be 4, which also dereferenced rows 2 and 3; those values were
 * never used and the reads may lie outside the 2-row matrix.
 */
INLINE int2 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_2x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 2)
    return ARR_TO_VEC2(int, wi_contrib);
}

/*
 * PackedA 1x32 i8 load: reads a single int (char elements packed into int)
 * for the one row of the matrix and returns it.
 *
 * The row count passed to the load macro is 1 to match the matrix shape.
 * It used to be 4, which also dereferenced three further rows whose values
 * were never used and which may lie outside the 1-row matrix.
 */
INLINE int __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_1x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 1)
    return ARR_TO_VEC1(int, wi_contrib);
}

/* PackedA load i16 SG16 */
/*
 * SG16 PackedA 8x16 i16 load: reads one short per row for 8 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (element and contribution types are both
 * short, so the stride is used as given) and returns them as a short8.
 */
INLINE short8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_8x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, short, 8)
    const short8 packed_rows = ARR_TO_VEC8(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 4x16 i16 load: reads one short per row for 4 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (element and contribution types are both
 * short, so the stride is used as given) and returns them as a short4.
 */
INLINE short4 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_4x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, short, 4)
    const short4 packed_rows = ARR_TO_VEC4(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 2x16 i16 load: reads one short per row for 2 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (element and contribution types are both
 * short, so the stride is used as given) and returns them as a short2.
 */
INLINE short2 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_2x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, short, 2)
    const short2 packed_rows = ARR_TO_VEC2(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 1x16 i16 load: reads a single short for the one row of the
 * matrix and returns it.
 *
 * The row count passed to the load macro is 1 to match the matrix shape.
 * It used to be 2, which also dereferenced a second row (at +stride) whose
 * value was never used and which may lie outside the 1-row matrix.
 */
INLINE short __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_1x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, short, 1)
    return ARR_TO_VEC1(short, wi_contrib);
}

/* PackedA load i8 SG16 */
/*
 * SG16 PackedA 8x32 i8 load: reads one short per row for 8 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (char elements packed into short) and returns
 * them as a short8.
 */
INLINE short8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_8x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, short, 8)
    const short8 packed_rows = ARR_TO_VEC8(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 4x32 i8 load: reads one short per row for 4 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (char elements packed into short) and returns
 * them as a short4.
 */
INLINE short4 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_4x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, short, 4)
    const short4 packed_rows = ARR_TO_VEC4(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 2x32 i8 load: reads one short per row for 2 rows through
 * LOAD_PACKED_A_FROM_ROW_MAJOR (char elements packed into short) and returns
 * them as a short2.
 */
INLINE short2 __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_2x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, short, 2)
    const short2 packed_rows = ARR_TO_VEC2(short, wi_contrib);
    return packed_rows;
}

/*
 * SG16 PackedA 1x32 i8 load: reads a single short (char elements packed into
 * short) for the one row of the matrix and returns it.
 *
 * The row count passed to the load macro is 1 to match the matrix shape.
 * It used to be 2, which also dereferenced a second row (at +stride) whose
 * value was never used and which may lie outside the 1-row matrix.
 */
INLINE short __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_1x32_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, short, 1)
    return ARR_TO_VEC1(short, wi_contrib);
}

/*
 * Statement macro: gathers N consecutive ints from `mem` into a new local
 * array `wi_contrib[N]`. `mem` is viewed as an array of int and element i is
 * read from index (i + sub-group local id * stride).
 *
 * N now sizes the array and bounds the loop; it was previously ignored in
 * favour of a hard-coded 8 (every caller in this file passes 8, so their
 * behaviour is unchanged). The unused `pack_factor` local has been dropped.
 * `element_type` is accepted for signature compatibility but not used.
 *
 * NOTE(review): unlike LOAD_PACKED_A_FROM_ROW_MAJOR, `stride` is not divided
 * by a pack factor here - confirm the unit callers are expected to pass.
 *
 * The macro declares `ptr`, `slid` and `wi_contrib` in the caller's scope.
 */
#define LOAD_PACKED_B_FROM_COL_MAJOR(mem, stride, element_type, N) \
    int *ptr = (int *)mem; \
    int slid = get_sub_group_local_id(); \
    int wi_contrib[N]; \
    for (int i = 0; i < N; i++) \
        wi_contrib[i] = *(ptr + i + (slid * stride));

/*
 * PackedB column-major 16x8 i16 load: gathers this work-item's eight ints
 * through LOAD_PACKED_B_FROM_COL_MAJOR and returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_ColumnMajor_16x8_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_B_FROM_COL_MAJOR(mem, stride, short, 8)
    const int8 packed_col = ARR_TO_VEC8(int, wi_contrib);
    return packed_col;
}

/*
 * PackedB column-major 32x8 i8 load: gathers this work-item's eight ints
 * through LOAD_PACKED_B_FROM_COL_MAJOR and returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_ColumnMajor_32x8_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_B_FROM_COL_MAJOR(mem, stride, char, 8)
    const int8 packed_col = ARR_TO_VEC8(int, wi_contrib);
    return packed_col;
}

/*
 * PackedB (already packed layout) 16x8 i16 load: reuses the row-major
 * PackedA loader to read eight ints (short elements packed into int) and
 * returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_16x8_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedB (already packed layout) 32x8 i8 load: reuses the row-major
 * PackedA loader to read eight ints (char elements packed into int) and
 * returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_32x8_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedB (already packed layout) 16x16 i16 load: reuses the row-major
 * PackedA loader to read eight ints (short elements packed into int) and
 * returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_16x16_i16_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, short, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/*
 * PackedB (already packed layout) 32x16 i8 load: reuses the row-major
 * PackedA loader to read eight ints (char elements packed into int) and
 * returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_32x16_i8_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, char, int, 8)
    const int8 packed_rows = ARR_TO_VEC8(int, wi_contrib);
    return packed_rows;
}

/* Load accumulator is a special case of load packed A, both are row major: */
/*
 * Accumulator 8x8 i32 row-major load: reads one int per row for 8 rows via
 * the PackedA row-major loader (element and contribution types are both int,
 * so the stride is used as given) and returns them as an int8.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_8x8_i32_v8i8_pi32_i32(char *mem, int stride) {
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, int, int, 8)
    const int8 acc_rows = ARR_TO_VEC8(int, wi_contrib);
    return acc_rows;
}

/* Experimental new implementation:
 * JOINT_MATRIX_USE_BLOCK_OPS selects between two code paths in the 8x16 i32
 * accumulator load and store builtins below.
 * NOTE(review): as written, a value of 0 selects the
 * intel_sub_group_block_read/write path and a non-zero value selects the
 * per-work-item macro path, which looks inverted relative to the macro's
 * name - confirm intent before flipping it. */
#define JOINT_MATRIX_USE_BLOCK_OPS 0

/*
 * Accumulator 8x16 i32 row-major load: returns this work-item's eight ints,
 * one per row, with consecutive rows `stride` ints apart.
 *
 * Two implementations are selected by JOINT_MATRIX_USE_BLOCK_OPS:
 *  - non-zero: per-work-item gather via LOAD_PACKED_A_FROM_ROW_MAJOR;
 *  - zero:     one intel_sub_group_block_read per row.
 * NOTE(review): the flag polarity looks inverted relative to its name
 * (0 selects the block-op path) - confirm intent.
 *
 * In the block-read path the row base pointer is passed as-is.
 * intel_sub_group_block_read already indexes by sub-group local id, so the
 * previous `+ slid` offset was redundant and made the pointer differ between
 * work-items; this now matches the matching 8x16 store, which passes the
 * plain row base to intel_sub_group_block_write.
 */
INLINE int8 __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_8x16_i32_v8i8_pi32_i32(char *mem, int stride) {
#if JOINT_MATRIX_USE_BLOCK_OPS
    LOAD_PACKED_A_FROM_ROW_MAJOR(mem, stride, int, int, 8)
    return ARR_TO_VEC8(int, wi_contrib);
#else
    __global uint *ptr = (__global uint *)mem;
    int8 result;

    result.s0 = (int) intel_sub_group_block_read(ptr + (0 * stride));
    result.s1 = (int) intel_sub_group_block_read(ptr + (1 * stride));
    result.s2 = (int) intel_sub_group_block_read(ptr + (2 * stride));
    result.s3 = (int) intel_sub_group_block_read(ptr + (3 * stride));
    result.s4 = (int) intel_sub_group_block_read(ptr + (4 * stride));
    result.s5 = (int) intel_sub_group_block_read(ptr + (5 * stride));
    result.s6 = (int) intel_sub_group_block_read(ptr + (6 * stride));
    result.s7 = (int) intel_sub_group_block_read(ptr + (7 * stride));

    return result;
#endif
}

/*
 * Statement macro: scatters the first M elements of the caller's `row`
 * variable into `dst`, one value per row.
 *
 * - `dst` is viewed as an array of contrib_t. The expansion now uses the
 *   `dst` parameter; it previously referenced a variable literally named
 *   `mem`, which only worked because every caller happened to pass `mem`.
 * - `stride` must be a modifiable int; it is divided in place by
 *   sizeof(contrib_t) / sizeof(elem_t).
 * - row[i] is written to index (i * stride + sub-group local id).
 * - `N` is accepted but not used.
 *
 * `row` is taken from the enclosing scope, and the macro declares `ptr`,
 * `slid` and `pack_factor` there.
 */
#define STORE_PACK_A_ROW_MAJOR(dst, stride, elem_t, contrib_t, M, N) \
    contrib_t *ptr = (contrib_t *)(dst); \
    int slid = get_sub_group_local_id(); \
    int pack_factor = sizeof (contrib_t) / sizeof (elem_t); \
    stride = stride / pack_factor; \
    for (int i = 0; i < M; i++) \
        ptr[i * stride + slid] = row[i];

/*
 * PackedA 8x32 i8 row-major store: writes the eight ints of `row` to `mem`
 * viewed as int, one per matrix row. The incoming stride is divided by the
 * number of chars per int before indexing; row r goes to index
 * (r * int_stride + sub-group local id).
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_8x32_i8_pi64_v8i8(char *mem, int8 row, int stride) {
    int *out = (int *)mem;
    const int lane = get_sub_group_local_id();
    const int chars_per_int = sizeof (int) / sizeof (char);
    const int int_stride = stride / chars_per_int;
    for (int r = 0; r < 8; r++) {
        out[r * int_stride + lane] = row[r];
    }
}

/*
 * PackedA 8x16 i16 row-major store: writes the eight ints of `row` to `mem`
 * viewed as int, one per matrix row. The incoming stride is divided by the
 * number of shorts per int before indexing; row r goes to index
 * (r * int_stride + sub-group local id).
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_8x16_i16_pi64_v8i8(char *mem, int8 row, int stride) {
    int *out = (int *)mem;
    const int lane = get_sub_group_local_id();
    const int shorts_per_int = sizeof (int) / sizeof (short);
    const int int_stride = stride / shorts_per_int;
    for (int r = 0; r < 8; r++) {
        out[r * int_stride + lane] = row[r];
    }
}

/*
 * Statement macro: scatters the first M elements of the caller's `col`
 * variable into `dst`.
 *
 * - `dst` is viewed as an array of contrib_t. The expansion now uses the
 *   `dst` parameter; it previously referenced a variable literally named
 *   `mem`, which only worked because every caller happened to pass `mem`.
 * - `stride` must be a modifiable int; it is divided in place by
 *   sizeof(contrib_t) / sizeof(elem_t).
 * - col[i] is written to index (i * stride + sub-group local id).
 * - `N` is accepted but not used.
 *
 * `col` is taken from the enclosing scope, and the macro declares `ptr`,
 * `slid` and `pack_factor` there.
 */
#define STORE_PACK_B_COL_MAJOR(dst, stride, elem_t, contrib_t, M, N) \
    contrib_t *ptr = (contrib_t *)(dst); \
    int slid = get_sub_group_local_id(); \
    int pack_factor = sizeof (contrib_t) / sizeof (elem_t); \
    stride = stride / pack_factor; \
    for (int i = 0; i < M; i++) \
        ptr[i * stride + slid] = col[i];

/*
 * PackedB 16x8 i16 store: writes the eight ints of `col` to `mem` through
 * STORE_PACK_B_COL_MAJOR (short elements packed into int).
 *
 * The function is declared to return int8 but previously fell off the end
 * without returning a value. It now returns `col` so the result is defined
 * while the signature stays unchanged.
 * NOTE(review): the other store builtins in this file return void; the int8
 * return type here is probably unintended - confirm against callers.
 */
INLINE int8 __builtin_spriv_OpJointMatrixStoreINTEL_PackedB_PackedB_16x8_i16_pi64_v8i8(char *mem, int8 col, int stride) {
    STORE_PACK_B_COL_MAJOR(mem, stride, short, int, 8, 16)
    return col;
}

/*
 * PackedB 32x8 i8 store: writes the eight ints of `col` to `mem` through
 * STORE_PACK_B_COL_MAJOR (char elements packed into int).
 *
 * The function is declared to return int8 but previously fell off the end
 * without returning a value. It now returns `col` so the result is defined
 * while the signature stays unchanged.
 * NOTE(review): the other store builtins in this file return void; the int8
 * return type here is probably unintended - confirm against callers.
 */
INLINE int8 __builtin_spriv_OpJointMatrixStoreINTEL_PackedB_PackedB_32x8_i8_pi64_v8i8(char *mem, int8 col, int stride) {
    STORE_PACK_B_COL_MAJOR(mem, stride, char, int, 8, 32)
    return col;
}

/*
 * PackedB 16x16 i16 store: writes the eight ints of `col` to `mem` through
 * STORE_PACK_B_COL_MAJOR (short elements packed into int).
 *
 * The function is declared to return int8 but previously fell off the end
 * without returning a value. It now returns `col` so the result is defined
 * while the signature stays unchanged.
 * NOTE(review): the other store builtins in this file return void; the int8
 * return type here is probably unintended - confirm against callers.
 */
INLINE int8 __builtin_spriv_OpJointMatrixStoreINTEL_PackedB_PackedB_16x16_i16_pi64_v8i8(char *mem, int8 col, int stride) {
    STORE_PACK_B_COL_MAJOR(mem, stride, short, int, 8, 16)
    return col;
}

/*
 * Statement macro: writes the first M elements of the caller's `row`
 * variable into `dst` viewed as an array of int. row[r] goes to index
 * (sub-group local id + r * stride).
 *
 * `row` is taken from the enclosing scope, and the macro declares `ptr` and
 * `slid` there.
 */
#define STORE_ACC_ROW_MAJOR(dst, stride, M) \
    int slid = get_sub_group_local_id(); \
    int *ptr = (int *)dst; \
    for (int r = 0; r < M; ++r) { \
        *(ptr + r * stride + slid) = row[r]; \
    }


/*
 * Statement macro: for each of the first M elements of the caller's `row`
 * variable, issues one intel_sub_group_block_write of (uint) row[r] at
 * `dst` (viewed as __global uint) offset by r * stride.
 *
 * `row` is taken from the enclosing scope, and the macro declares `ptr`
 * there.
 */
#define STORE_BLOCK_ACC_ROW_MAJOR(dst, stride, M) \
    __global uint *ptr = (__global uint *)dst; \
    for (int r = 0; r < M; ++r) { \
        intel_sub_group_block_write(&ptr[r * stride], (uint) row[r]); \
    }

/*
 * Accumulator 8x8 i32 row-major store: writes the eight ints of `row` to
 * `mem` viewed as int; row r goes to index (sub-group local id + r * stride).
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_8x8_i32_pi64_v8i8(char *mem, int8 row, int stride) {
    int *out = (int *)mem;
    const int lane = get_sub_group_local_id();
    for (int r = 0; r < 8; r++) {
        out[lane + r * stride] = row[r];
    }
}

/*
 * Accumulator 8x16 i32 row-major store: writes the eight ints of `row` to
 * `mem`, one per row, with consecutive rows `stride` uints apart.
 *
 * Two implementations are selected by JOINT_MATRIX_USE_BLOCK_OPS:
 *  - non-zero: per-work-item scatter via STORE_ACC_ROW_MAJOR;
 *  - zero:     one intel_sub_group_block_write per row.
 * NOTE(review): the flag polarity looks inverted relative to its name
 * (0 selects the block-op path) - confirm intent.
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_8x16_i32_pi64_v8i8(char *mem, int8 row, int stride) {
#if JOINT_MATRIX_USE_BLOCK_OPS
    STORE_ACC_ROW_MAJOR(mem, stride, 8)
#else
    __global uint *out = (__global uint *)mem;
    for (int r = 0; r < 8; r++) {
        intel_sub_group_block_write(out + r * stride, (uint) row[r]);
    }
#endif
}

/*
 * Statement macro: writes the first M elements of the caller's `row`
 * variable into `mem` viewed as an array of int. row[r] goes to index
 * (sub-group local id * stride + r).
 *
 * `row` is taken from the enclosing scope, and the macro declares `ptr` and
 * `slid` there.
 */
#define STORE_ACC_COL_MAJOR(mem, stride, M) \
    int slid = get_sub_group_local_id(); \
    int *ptr = (int *)mem; \
    for (int r = 0; r < M; ++r) { \
        *(ptr + slid * stride + r) = row[r]; \
    }

/*
 * Accumulator 8x8 i32 column-major store: writes the eight ints of `row` to
 * `mem` viewed as int; element r goes to index
 * (sub-group local id * stride + r).
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_ColumnMajor_8x8_i32_pi64_v8i8(char *mem, int8 row, int stride) {
    int *out = (int *)mem;
    const int lane = get_sub_group_local_id();
    for (int r = 0; r < 8; r++) {
        out[lane * stride + r] = row[r];
    }
}