File: reshape_control_fusion.mlir

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (62 lines) | stat: -rw-r--r-- 3,060 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s

// Producer-side case. The CHECK lines below expect the tensor.collapse_shape to
// survive and the linalg.generic to keep its 2-D indexing maps, i.e. the
// control function is expected to refuse fusion-by-expansion here.
// NOTE(review): presumably because the reshape's source is a function argument
// rather than a linalg op -- confirm against the test pass's control function.
func.func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
  %rows = tensor.dim %collapsed, %zero : tensor<?x?xf32>
  %cols = tensor.dim %collapsed, %one : tensor<?x?xf32>
  %empty = tensor.empty(%rows, %cols) : tensor<?x?xf32>
  // Elementwise add of the collapsed 2-D tensor and %arg1 indexed by d1.
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%collapsed, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
      outs(%empty : tensor<?x?xf32>) {
  ^bb0(%in : f32, %vec_elem : f32, %out : f32):
    %add = arith.addf %in, %vec_elem : f32
    linalg.yield %add : f32
  } -> tensor<?x?xf32>
  return %sum : tensor<?x?xf32>
}
//  CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[BCAST:.+]] = affine_map<(d0, d1) -> (d1)>
//      CHECK: func @control_producer_reshape_fusion
// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME:   %[[VEC:[a-zA-Z0-9_]+]]: tensor<?xf32>
//  CHECK-DAG:   %[[ZERO:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[ONE:.+]] = arith.constant 1 : index
//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SRC]]
// CHECK-SAME:       {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
//      CHECK:   %[[SUM:.+]] = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[IDENTITY]], #[[BCAST]], #[[IDENTITY]]]
// CHECK-SAME:       ins(%[[COLLAPSED]], %[[VEC]] : tensor<?x?xf32>, tensor<?xf32>)
//      CHECK:   return %[[SUM]]

// -----

// Consumer-side case. The CHECK lines below expect the tensor.expand_shape to
// be folded into its producer: the fill-like linalg.generic is rewritten to
// produce tensor<1x?x?xf32> directly (a single 3-D identity map), and that
// result becomes the outs operand of linalg.batch_matmul, whose ins operands
// remain the original function arguments.
// NOTE(review): presumably the control function accepts this fusion because
// the expand_shape's only use is a linalg init operand -- confirm against the
// test pass's control function.
func.func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %cst = arith.constant 0.0 : f32
  // Sizes of the 2-D init: dim 1 of %arg0 and dim 2 of %arg1.
  %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
  %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  // Fill %init with 0.0: the body ignores its block argument and yields %cst.
  %fill = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<?x?xf32>) {
      ^bb0(%arg2: f32):
        linalg.yield %cst : f32
      } -> tensor<?x?xf32>
  // Reshape under test: its only use is the init (outs) of the batch_matmul.
  %0 = tensor.expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
      outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
  return %1 : tensor<1x?x?xf32>
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//      CHECK: func @control_consumer_reshape_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x?x?xf32>
//      CHECK:   %[[FILL:.+]] = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[MAP]]]
// CHECK-SAME:       outs(%{{.+}} : tensor<1x?x?xf32>)
//      CHECK:   linalg.batch_matmul
// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK-SAME:       outs(%[[FILL]] : tensor<1x?x?xf32>)