File: rewrite-as-constant.mlir

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (158 lines) | stat: -rw-r--r-- 6,066 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s

// Transform script for this split: run the tensor rewrite_as_constant
// patterns (without the `aggressive` option) over every func.func found in
// the payload module.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %payload : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.tensor.rewrite_as_constant
    } : !transform.op<"func.func">
    transform.yield
  }
}

// A tensor.generate whose body yields the same value (defined outside the
// region) for every index is replaced by a splat arith.constant.
// CHECK-LABEL: func @tensor_generate_constant(
//       CHECK:   %[[SPLAT:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
//       CHECK:   return %[[SPLAT]]
func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
  %five = arith.constant 5.0 : f32
  %gen = tensor.generate {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %five : f32
  } : tensor<2x3x5xf32>
  return %gen : tensor<2x3x5xf32>
}

// Padding a constant source by amounts given as constant index values is
// folded into one dense constant. The pad's dynamic result type is restored
// with a tensor.cast of the folded static constant.
//         CHECK-LABEL: func @pad_of_ints(
//               CHECK: %[[FOLDED:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 6, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 8, 9, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
//               CHECK: %[[DYN:.*]] = tensor.cast %[[FOLDED]] : tensor<4x4xi32> to tensor<?x?xi32>
//               CHECK: return %[[DYN]]
func.func @pad_of_ints() -> tensor<?x?xi32> {
  %src = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %one = arith.constant 1 : index
  %padded = tensor.pad %src low[%one, %one] high[%one, %one] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>
  return %padded : tensor<?x?xi32>
}

// f32 variant with static (attribute) padding amounts. The pad's result type
// is already static, so the folded constant is returned directly with no
// tensor.cast.
//         CHECK-LABEL: func @pad_of_floats(
//               CHECK: %[[FOLDED:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xf32>
//               CHECK: return %[[FOLDED]]
func.func @pad_of_floats() -> tensor<4x4xf32> {
  %src = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
  %zero = arith.constant 0.0 : f32
  %padded = tensor.pad %src low[1, 1] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : f32
  } : tensor<2x2xf32> to tensor<4x4xf32>
  return %padded : tensor<4x4xf32>
}

// Padding only on the high side (low[0, 0]). The source values stay in the
// top-left corner of the folded constant.
//         CHECK-LABEL: func @pad_of_ints_no_low_dims(
//               CHECK: %[[FOLDED:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [6, 7, 0],
// CHECK-SAME{LITERAL}:     [8, 9, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
//               CHECK: return %[[FOLDED]]
func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
  %src = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %padded = tensor.pad %src low[0, 0] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>
  return %padded : tensor<3x3xi32>
}

// Padding only on the low side (high[0, 0]). The source values end up in the
// bottom-right corner of the folded constant.
//         CHECK-LABEL: func @pad_of_ints_no_high_dims(
//               CHECK: %[[FOLDED:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 6, 7],
// CHECK-SAME{LITERAL}:     [0, 8, 9]
// CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
//               CHECK: return %[[FOLDED]]
func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
  %src = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %padded = tensor.pad %src low[1, 1] high[0, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>
  return %padded : tensor<3x3xi32>
}

// The pad's constant source is also returned, so it has a second use. With
// the patterns' default options the pad is expected to stay un-folded: it must
// still consume the original 2x2 constant, and that constant must still be
// returned as the second result.
//         CHECK-LABEL: func @pad_multi_use_do_not_fold(
//               CHECK: %[[init:.+]] = arith.constant dense<
// CHECK-SAME{LITERAL}:     [[6, 7], [8, 9]]> : tensor<2x2xi32>
//               CHECK: %[[pad:.+]] = tensor.pad %[[init]]
//               CHECK: return %[[pad]], %[[init]]
func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  // Second use of %init: this is what blocks the non-aggressive fold.
  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}

// -----

// Transform script for this split: run the tensor rewrite_as_constant
// patterns with the `aggressive` option over every func.func found in the
// payload module.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %payload : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.tensor.rewrite_as_constant aggressive
    } : !transform.op<"func.func">
    transform.yield
  }
}

// With the `aggressive` option the pad is folded even though its constant
// source is also returned. The folded 4x4 constant (cast back to the dynamic
// type) replaces the pad. The original 2x2 constant must survive as the second
// result, so the captured %[[init]] is checked on the return.
//         CHECK-LABEL: func @pad_aggressive_fold(
//               CHECK: %[[init:.*]] = arith.constant dense<7> : tensor<2x2xi32>
//               CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
//               CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
//               CHECK: return %[[cast]], %[[init]]
func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<7> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  // %init has a second use here; aggressive mode folds the pad regardless.
  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}