File: vector-unroll-options.mlir

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (294 lines) | stat: -rw-r--r-- 18,521 bytes parent folder | download | duplicates (10)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1"  | FileCheck %s --check-prefix=ORDER
// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=0,3,1,2" | FileCheck %s --check-prefix=BATCHED

// Matmul-style contraction used as input for the unrolling patterns.
// The indexing maps read %lhs as (i, k), %rhs as (j, k) and the
// accumulator/result as (i, j); i and j are parallel iterators and k is the
// reduction iterator.
// The expectations that follow this function match eight unrolled contracts
// on 4x2 operand slices accumulating into 4x4 slices, once for the default
// invocation and once for the invocation that passes unroll-order=2,0,1
// (where the slices along k are visited outermost).
func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
                          %init : vector<8x8xf32>) -> vector<8x8xf32> {
  // Single 8x4 * 8x4 -> 8x8 contraction; this is the op that gets unrolled.
  %0 = vector.contract
         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                           affine_map<(i, j, k) -> (j, k)>,
                           affine_map<(i, j, k) -> (i, j)>],
          iterator_types = ["parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
  return %0 : vector<8x8xf32>
}
// CHECK-LABEL: func @vector_contract_f32
// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [0, 4]
//       CHECK:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [4, 4]
//       CHECK:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   return

// ORDER-LABEL: func @vector_contract_f32
// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [0, 4]
//       ORDER:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [4, 4]
//       ORDER:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   return



// Same contraction structure as the f32 case (maps (i, k), (j, k), (i, j);
// parallel, parallel, reduction) but on f16 with 8x8 operands.
// The expectations that follow this function match eight unrolled contracts
// whose operands and accumulators are all vector<4x4xf16>, i.e. the f16
// element type ends up with a different slice shape than the f32 test.
func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                          %init : vector<8x8xf16>) -> vector<8x8xf16> {
  // Single 8x8 * 8x8 -> 8x8 contraction; this is the op that gets unrolled.
  %0 = vector.contract
         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                           affine_map<(i, j, k) -> (j, k)>,
                           affine_map<(i, j, k) -> (i, j)>],
          iterator_types = ["parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
  return %0 : vector<8x8xf16>
}
// CHECK-LABEL: func @vector_contract_f16
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   return

// Elementwise fused multiply-add on 4x4xf32 vectors.
// The expectation that follows this function requires four vector.fma ops on
// vector<2x2xf32> after unrolling.
func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> {
  %0 = vector.fma %a, %b, %c: vector<4x4xf32>
  return %0 : vector<4x4xf32>
}
//   CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>

// Add-reduction of a 4x6xf32 vector along dimension 1 into a 4xf32 result,
// seeded with the accumulator %acc.
// The expectations that follow this function match six multi_reduction ops on
// 2x2 slices: for each pair of rows, the three column slices are chained
// through the accumulator operand, and the two 2-element results are inserted
// into a zero-initialized vector<4xf32>.
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
//       CHECK:   %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//       CHECK:   %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
//       CHECK:   return %[[V2]] : vector<4xf32>


// Add-reduction of an 8-element f32 vector to a scalar.
// The expectations that follow this function match four reductions over
// 2-element slices (offsets 0, 2, 4, 6) whose scalar results are combined
// left to right with arith.addf.
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
  %0 = vector.reduction <add>, %v : vector<8xf32> into f32
  return %0 : f32
}
// CHECK-LABEL: func @vector_reduction(
//  CHECK-SAME:     %[[v:.*]]: vector<8xf32>
//       CHECK:   %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
//       CHECK:   %[[r0:.*]] = vector.reduction <add>, %[[s0]]
//       CHECK:   %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
//       CHECK:   %[[r1:.*]] = vector.reduction <add>, %[[s1]]
//       CHECK:   %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
//       CHECK:   %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
//       CHECK:   %[[r2:.*]] = vector.reduction <add>, %[[s2]]
//       CHECK:   %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
//       CHECK:   %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
//       CHECK:   %[[r3:.*]] = vector.reduction <add>, %[[s3]]
//       CHECK:   %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
//       CHECK:   return %[[add3]]

// 4-D transpose with permutation [0, 2, 3, 1]: 2x4x3x8 -> 2x3x8x4.
// The expectations that follow this function match eight transposes of
// 1x2x3x4 slices into 1x3x4x2 slices, each inserted into a zero-initialized
// vector<2x3x8x4xf32> at the permuted offsets.
// NOTE(review): the symbol is spelled "tranpose" (missing 's'). The label
// directive right after this function matches the same spelling, so the two
// must be renamed together if this is ever corrected.
func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
  %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
  return %t : vector<2x3x8x4xf32>
}
// CHECK-LABEL: func @vector_tranpose
//       CHECK:   %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>

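// NOTE(review): the separator below only splits the file when mlir-opt is run
// with -split-input-file; none of the invocations at the top of this file
// pass that flag, so the whole file is currently processed as one module.
// -----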

// Contraction with four iterators (d0, d1, d2, c0): three parallel and one
// reduction. d0 appears in the lhs, rhs and accumulator maps; d1 indexes the
// lhs and the result, d2 indexes the rhs and the result, and c0 is the
// reduced dimension shared by lhs and rhs.
// The expectations that follow this function require exactly sixteen unrolled
// contracts, both for the default invocation and for the invocation that
// passes unroll-order=0,3,1,2.
func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf32>, %init: vector<8x8x8xf32>) -> vector<8x8x8xf32> {
  // Single 8x8x4 * 8x8x4 -> 8x8x8 contraction; this is the op that gets unrolled.
  %0 = vector.contract
         {indexing_maps = [affine_map<(d0,d1,d2,c0) -> (d0,d1,c0)>,
                           affine_map<(d0,d1,d2,c0) -> (d0,d2,c0)>,
                           affine_map<(d0,d1,d2,c0) -> (d0,d1,d2)>],
          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x8x4xf32>, vector<8x8x4xf32> into vector<8x8x8xf32>
  return %0 : vector<8x8x8xf32>
}


//    CHECK-LABEL: vector_contract_batched
// CHECK-COUNT-16: vector.contract
//      CHECK-NOT: vector.contract
//          CHECK: return

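// NOTE(review): no invocation at the top of this file passes
// --check-prefix=UNROLL, so the four directives below are never evaluated.
// Either add an invocation that enables this prefix or remove them.
//    UNROLL-LABEL: vector_contract_batched
//  UNROLL-COUNT-1: vector.contract
//      UNROLL-NOT: vector.contract
//          UNROLL: return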


//    BATCHED-LABEL: vector_contract_batched
// BATCHED-COUNT-16: vector.contract
//      BATCHED-NOT: vector.contract
//          BATCHED: return