File: transform-op-matmul-to-outerproduct.mlir

package info (click to toggle)
llvm-toolchain-18 1:18.1.8-18
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,908,340 kB
  • sloc: cpp: 6,667,937; ansic: 1,440,452; asm: 883,619; python: 230,549; objc: 76,880; f90: 74,238; lisp: 35,989; pascal: 16,571; sh: 10,229; perl: 7,459; ml: 5,047; awk: 3,523; makefile: 2,987; javascript: 2,149; xml: 892; fortran: 649; cs: 573
file content (40 lines) | stat: -rw-r--r-- 3,119 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s

// Payload: a plain 3x3 f32 linalg.matmul on memrefs, accumulating into %acc.
// The transform script at the end of this file vectorizes the enclosing
// function and lowers the resulting vector contraction to outer products.
func.func @outerproduct_matmul(%lhs: memref<3x3xf32>,
                               %rhs: memref<3x3xf32>,
                               %acc: memref<3x3xf32>) {
  linalg.matmul
    ins(%lhs, %rhs : memref<3x3xf32>, memref<3x3xf32>)
    outs(%acc : memref<3x3xf32>)
  return
}

// CHECK-LABEL:   func.func @outerproduct_matmul(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
// CHECK:           %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
// CHECK:           %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
// CHECK:           %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
// CHECK:           vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
// CHECK:           return
// CHECK:         }

// Transform script run by -transform-interpreter: find the matmul, vectorize
// the function that contains it, then lower the vector contraction using the
// "outerproduct" strategy.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Handle to every linalg.matmul nested under the payload root.
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
      : (!transform.any_op) -> !transform.any_op
    // Walk up to the nearest isolated-from-above ancestor (here, the func.func).
    %func = transform.get_parent_op %matmul {isolated_from_above}
      : (!transform.any_op) -> !transform.any_op
    // Vectorize the function's children; the returned handle is the target
    // of the pattern application below.
    %vectorized = transform.structured.vectorize_children_and_apply_patterns %func
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %vectorized {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    } : !transform.any_op
    transform.yield
  }
}