File: vector-contract-to-outerproduct-matvec-transforms.mlir

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (667 lines) | stat: -rw-r--r-- 35,592 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
/// Matvec operations:
///   b += A * x.
/// (b and x are 1-d vectors, A is a 2-d matrix). At the moment, three different
/// variants are tested:
///   * plain (no mask, fixed-width vectors),
///   * masked (mask + fixed-width vectors),
///   * scalable (mask + scalable vectors).
///
/// TODO: These tests were extracted from 2 different files. If you find the
/// formatting inconsistent, please update accordingly.

// Variant 1 indexing maps: iteration space (m, k); A is read as (m, k), x as
// (k), and the accumulator/result b as (m).
#matvec_accesses_1 = [
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]
// No combining `kind` is given here; the tests using this trait expect the
// lowered outerproducts to use #vector.kind<add>.
#matvec_trait_1 = {
  indexing_maps = #matvec_accesses_1,
  iterator_types = ["parallel", "reduction"]
}

// Same indexing maps and iterators as #matvec_trait_1, but combining with
// maxnumf instead of the add kind.
#matvecmax_trait = {
  indexing_maps = #matvec_accesses_1,
  iterator_types = ["parallel", "reduction"],
  kind = #vector.kind<maxnumf>
}

// Variant 2: iteration space (m, k); A is read as (k, m), i.e. with the
// reduction dimension leading.
#matvec_accesses_2 = [
  affine_map<(m, k) -> (k, m)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_2 = {
  indexing_maps = #matvec_accesses_2,
  iterator_types = ["parallel", "reduction"]
}

// Variant 3: iteration space (m, k); x is the LHS operand and A, read as
// (m, k), is the RHS operand.
#matvec_accesses_3 = [
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_3 = {
  indexing_maps = #matvec_accesses_3,
  iterator_types = ["parallel", "reduction"]
}

// Variant 4: iteration space (m, k); x is the LHS operand and A, read as
// (k, m), is the RHS operand.
#matvec_accesses_4 = [
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (k, m)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_4 = {
  indexing_maps = #matvec_accesses_4,
  iterator_types = ["parallel", "reduction"]
}

// Variants 5-8 use the same operand layouts as variants 1-4, respectively, but
// order the iteration space as (k, m), i.e. iterator_types =
// ["reduction", "parallel"]. Variants 5-7 are used by the `tmatvec_*` tests
// below.
//
// Variant 5: A is read as (m, k), x as (k), b as (m).
#matvec_accesses_5 = [
  affine_map<(k, m) -> (m, k)>,
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_5 = {
  indexing_maps = #matvec_accesses_5,
  iterator_types = ["reduction", "parallel"]
}

// Variant 6: A is read as (k, m), x as (k), b as (m).
#matvec_accesses_6 = [
  affine_map<(k, m) -> (k, m)>,
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_6 = {
  indexing_maps = #matvec_accesses_6,
  iterator_types = ["reduction", "parallel"]
}

// Variant 7: x is the LHS operand; A, read as (m, k), is the RHS operand.
#matvec_accesses_7 = [
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m, k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_7 = {
  indexing_maps = #matvec_accesses_7,
  iterator_types = ["reduction", "parallel"]
}

// Variant 8: x is the LHS operand; A, read as (k, m), is the RHS operand.
#matvec_accesses_8 = [
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (k, m)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_8 = {
  indexing_maps = #matvec_accesses_8,
  iterator_types = ["reduction", "parallel"]
}

// ============================================================================
//  Matvec 1 (plain + masked + scalable)
// ============================================================================
// Plain (unmasked, fixed-width) matvec using #matvec_trait_1: A is indexed as
// (m, k), x as (k) and b as (m). The expected lowering transposes A into
// (k, m) order. Then, for each k, it feeds row k of the transposed A and the
// scalar x[k] into a vector.outerproduct. The accumulator is threaded from b
// through each outerproduct in turn.
// CHECK-LABEL: func @matvec_mk_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @matvec_mk_k_m with k = 3. The mask has the
// iteration-space shape (m, k) = 2x3. The lowering is expected to transpose it
// to 3x2 and wrap each of the three outerproducts in a vector.mask that uses
// row k of the transposed mask.
// CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<2xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
                                %x: vector<3xf32>,
                                %b: vector<2xf32>,
                                %m: vector<2x3xi1>) -> vector<2xf32> {
  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}

// Same as @masked_matvec_mk_k_m, but the parallel dimension (m) is scalable
// ([2]) while the reduction dimension (k = 3) stays fixed. The transposed mask
// is therefore vector<3x[2]xi1>, and each masked outerproduct operates on
// vector<[2]xf32>.
// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
                                                      %x: vector<3xf32>,
                                                      %b: vector<[2]xf32>,
                                                      %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %0 : vector<[2]xf32>
}

// ============================================================================
//  Matvec 1 - max (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvecmax_trait. That trait shares the indexing maps of
// #matvec_trait_1 but sets kind = maxnumf. The expected lowering matches
// @matvec_mk_k_m, except that every vector.outerproduct carries
// #vector.kind<maxnumf>.
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
                             %x: vector<2xf32>,
                             %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @matvec_mk_k_m_max. The mask is handled as in
// @masked_matvec_mk_k_m (transposed from 2x3 to 3x2, one row per outerproduct),
// but each masked outerproduct uses #vector.kind<maxnumf>.
// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max(
// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<2xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>,
                                    %x: vector<3xf32>,
                                    %b: vector<2xf32>,
                                    %m: vector<2x3xi1>) -> vector<2xf32> {
  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}

// Same as @masked_matvec_mk_k_m_max, but with a scalable parallel dimension
// ([2] along m) and a fixed reduction dimension (k = 3). The transposed mask is
// vector<3x[2]xi1>.
// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
                                                          %x: vector<3xf32>,
                                                          %b: vector<[2]xf32>,
                                                          %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %0 : vector<[2]xf32>
}

// ============================================================================
//  Matvec 2 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_2. The iteration space is (m, k), but A is
// indexed as (k, m), so its leading dimension is already the reduction
// dimension. The expected sequence extracts row k of A directly, with no
// transpose of A before it, and pairs it with the scalar x[k] in each
// vector.outerproduct.
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @matvec_km_k_m. The mask follows the
// iteration space (m, k) = 4x2, while A is stored as (k, m) = 2x4. Only the
// mask is therefore expected to be transposed. The negative directive checks
// that no transpose of A appears between the mask transpose and the masked
// outerproducts. One masked outerproduct is expected per k (k = 2).
// CHECK-LABEL: @masked_matvec_km_k_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// Same as @masked_matvec_km_k_m, but the parallel dimension (m) is scalable
// ([4]). Again only the mask is expected to be transposed, and the two masked
// outerproducts operate on vector<[4]xf32>.
// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 3 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_3: x is the LHS operand and A, indexed as
// (m, k), is the RHS operand. The expected lowering transposes A into (k, m)
// order. Each vector.outerproduct then takes row k of the transposed A as its
// vector operand and x[k] as its scalar operand. The accumulator is threaded
// from b through each outerproduct in turn.
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @matvec_k_mk_m. A is stored as (m, k) = 4x2,
// and the mask follows the iteration space (m, k) = 4x2. Both are expected to
// be transposed (A first, then the mask) before the two masked outerproducts.
// CHECK-LABEL: @masked_matvec_k_mk_m
// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
      vector.contract #matvec_trait_3 %x, %A, %b
        : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// Same as @masked_matvec_k_mk_m, but the parallel dimension (m) is scalable
// ([4]). Both A and the mask are still expected to be transposed, and the
// masked outerproducts operate on vector<[4]xf32>.
// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK:         vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
      vector.contract #matvec_trait_3 %x, %A, %b
        : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 4 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_4: x is the LHS operand and A, indexed as
// (k, m), is the RHS operand. Because A is already reduction-major, the
// expected sequence extracts row k of A directly, with no transpose of A, and
// pairs it with the scalar x[k].
// CHECK-LABEL: func @matvec_k_km_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_km_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked matvec using #matvec_trait_4 with a scalable parallel dimension
// ([4]). The mask follows the iteration space (m, k) = [4]x2, while A is
// stored as (k, m) = 2x[4]. Only the mask is expected to be transposed; the
// negative directive checks that no transpose of A follows it.
// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// Masked, fixed-width variant of @matvec_k_km_m. The mask follows the
// iteration space (m, k) = 4x2, while A is stored as (k, m) = 2x4. Only the
// mask is expected to be transposed, followed by two masked outerproducts.
// CHECK-LABEL: @masked_matvec_k_km_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// ============================================================================
//  Matvec 5 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_5: the iteration space is ordered (k, m)
// (["reduction", "parallel"]) and A is indexed as (m, k). The expected lowering
// matches @matvec_mk_k_m: transpose A, then one outerproduct per k using row k
// of the transposed A and x[k].
// CHECK-LABEL:   func.func @tmatvec_mk_k_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_mk_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_5 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @tmatvec_mk_k_m. The mask follows the
// iteration space (k, m) = 2x4, which is already reduction-major, so the mask
// is not expected to be transposed. A, stored as (m, k) = 4x2, is transposed.
// CHECK-LABEL: @masked_tmatvec_mk_k_m
// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_mk_k_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// Same as @masked_tmatvec_mk_k_m, but the parallel dimension (m) is scalable
// ([4]). A is still expected to be transposed and the mask is not.
// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 6 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_6: the iteration space is ordered (k, m)
// and A is indexed as (k, m). The expected sequence extracts row k of A
// directly, with no transpose of A, and pairs it with the scalar x[k].
// CHECK-LABEL:   func.func @tmatvec_km_k_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.outerproduct %[[VAL_3]], %[[VAL_4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.outerproduct %[[VAL_6]], %[[VAL_7]], %[[VAL_5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_km_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_6 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked, fixed-width variant of @tmatvec_km_k_m. A (k, m) = 2x4 and the mask
// (iteration space (k, m) = 2x4) are both already reduction-major, so neither
// is expected to be transposed before the two masked outerproducts.
// CHECK-LABEL: @masked_tmatvec_km_k_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_km_k_m(%A: vector<2x4xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b
      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// Same as @masked_tmatvec_km_k_m, but the parallel dimension (m) is scalable
// ([4]). Neither A nor the mask is expected to be transposed.
// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b
      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 7 (plain + masked + scalable)
// ============================================================================
// Plain matvec using #matvec_trait_7: the iteration space is ordered (k, m),
// x is the LHS operand and A, indexed as (m, k), is the RHS operand. The
// expected lowering transposes A, then emits one outerproduct per k using row
// k of the transposed A and x[k].
// CHECK-LABEL:   func.func @tmatvec_k_mk_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_k_mk_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked trait 7: %A is 4x2, i.e. laid out as (m, k), so the matrix is
// transposed while the mask is used as-is; the fixed k = 2 reduction becomes
// two masked outer products.
// CHECK-LABEL: @masked_tmatvec_k_mk_m
// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[ACC:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MSK:.+]]: vector<2x4xi1>
// CHECK:         vector.transpose %[[MAT]]
// CHECK-NOT:     vector.transpose %[[MSK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
func.func @masked_tmatvec_k_mk_m(
    %A: vector<4x2xf32>, %x: vector<2xf32>, %b: vector<4xf32>,
    %mask: vector<2x4xi1>) -> vector<4xf32> {
  %result = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %result : vector<4xf32>
}

// Masked trait 7 with a scalable parallel (m) dimension: %A is [4]x2, i.e.
// (m, k), so it is still transposed, the mask is not, and the fixed k = 2
// reduction yields two masked outer products on scalable vectors.
// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[ACC:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MSK:.+]]: vector<2x[4]xi1>
// CHECK:         vector.transpose %[[MAT]]
// CHECK-NOT:     vector.transpose %[[MSK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(
    %A: vector<[4]x2xf32>, %x: vector<2xf32>, %b: vector<[4]xf32>,
    %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  %result = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %result : vector<[4]xf32>
}

// ============================================================================
//  Matvec 8 (plain + masked + scalable)
// ============================================================================
// NOTE: this test's name does not follow the "k_km_m" naming used by the
// masked #matvec_trait_8 tests that follow it, although it exercises the same
// trait. Those tests show the trait reads %A as (k, m) (%A is 2x4 with k = 2
// and m = 4), so each row of %A is already an m-vector: no transpose is
// expected, one outer product is emitted per k, and the last one is returned.
// CHECK-LABEL: func @tmatvec_m_mk_k
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK-NOT: vector.transpose
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: return %[[T8]] : vector<2xf32>
func.func @tmatvec_m_mk_k(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Masked trait 8: %A is 2x4, i.e. already laid out as (k, m), so neither the
// matrix nor the mask is transposed; the fixed k = 2 reduction becomes two
// masked outer products.
// CHECK-LABEL: @masked_tmatvec_k_km_m
// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[ACC:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MSK:.+]]: vector<2x4xi1>
// CHECK-NOT:     vector.transpose %[[MAT]]
// CHECK-NOT:     vector.transpose %[[MSK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
func.func @masked_tmatvec_k_km_m(
    %A: vector<2x4xf32>, %x: vector<2xf32>, %b: vector<4xf32>,
    %mask: vector<2x4xi1>) -> vector<4xf32> {
  %result = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %result : vector<4xf32>
}

// Masked trait 8 with a scalable parallel (m) dimension: %A is 2x[4], i.e.
// (k, m), so neither the matrix nor the mask is transposed, and the fixed
// k = 2 reduction yields two masked outer products on scalable vectors.
// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[ACC:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MSK:.+]]: vector<2x[4]xi1>
// CHECK-NOT:     vector.transpose %[[MAT]]
// CHECK-NOT:     vector.transpose %[[MSK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(
    %A: vector<2x[4]xf32>, %x: vector<2xf32>, %b: vector<[4]xf32>,
    %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  %result = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %result : vector<[4]xf32>
}

// The reduction dimension (size [3]) is scalable here. Unrolling a scalable
// reduction dimension is not supported, so the lowering must bail out and
// leave the vector.contract in place.
// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
// CHECK:         vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
func.func @masked_extract_contract2_scalable_reduction_dim(
    %arg0: vector<[2]x[3]xf32>, %arg1: vector<[3]xf32>, %arg2: vector<[2]xf32>,
    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
  %result = vector.mask %m {
    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
      : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
  } : vector<[2]x[3]xi1> -> vector<[2]xf32>
  return %result : vector<[2]xf32>
}

// ============================================================================
//  TD sequence
// ============================================================================
// Entry point for the transform interpreter: collects every func.func nested
// in the payload root and applies the vector contraction-lowering patterns
// with the "outerproduct" strategy to each of them.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%payload_root : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload_root
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    } : !transform.op<"func.func">
    transform.yield
  }
}