File: lower-to-loops-using-interface.mlir

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (268 lines) | stat: -rw-r--r-- 13,814 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s

// Dynamically shaped matmul: expect an (m, n, k) loop nest whose innermost
// body loads one element from each operand and accumulates into the output.
func.func @gemm(%lhs : memref<?x?xf32>, %rhs : memref<?x?xf32>,
                %acc : memref<?x?xf32>) {
  linalg.matmul
      ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>)
      outs(%acc : memref<?x?xf32>)
  return
}
// CHECK-LABEL: func @gemm
//  CHECK-SAME:     %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[B:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ACC:[a-zA-Z0-9]+]]: memref<?x?xf32>
//   CHECK-DAG:   %[[ZERO:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[ONE:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[DIM_M:.+]] = memref.dim %[[A]], %[[ZERO]]
//   CHECK-DAG:   %[[DIM_K:.+]] = memref.dim %[[A]], %[[ONE]]
//   CHECK-DAG:   %[[DIM_N:.+]] = memref.dim %[[B]], %[[ONE]]
//       CHECK:   scf.for %[[I:[a-zA-Z0-9]+]] = %[[ZERO]] to %[[DIM_M]] step %[[ONE]]
//       CHECK:     scf.for %[[J:[a-zA-Z0-9]+]] = %[[ZERO]] to %[[DIM_N]] step %[[ONE]]
//       CHECK:       scf.for %[[K:[a-zA-Z0-9]+]] = %[[ZERO]] to %[[DIM_K]] step %[[ONE]]
//   CHECK-DAG:         %[[A_ELT:.+]] = memref.load %[[A]][%[[I]], %[[K]]]
//   CHECK-DAG:         %[[B_ELT:.+]] = memref.load %[[B]][%[[K]], %[[J]]]
//   CHECK-DAG:         %[[ACC_ELT:.+]] = memref.load %[[ACC]][%[[I]], %[[J]]]
//       CHECK:         %[[PROD:.+]] = arith.mulf %[[A_ELT]], %[[B_ELT]]
//       CHECK:         %[[SUM:.+]] = arith.addf %[[ACC_ELT]], %[[PROD]]
//       CHECK:         memref.store %[[SUM]], %[[ACC]][%[[I]], %[[J]]]

// -----

// Generic op whose payload reads the iteration indices through linalg.index;
// after lowering those reads become the loop induction variables, and the
// result is stored through the transposed (d1, d0) output map.
func.func @indexed_generic(%src : memref<200x300xi32>, %vec_d1 : memref<300xi16>,
    %vec_d0 : memref<200xi8>, %dst : memref<300x200xi64>) {
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%src, %vec_d1, %vec_d0 : memref<200x300xi32>, memref<300xi16>, memref<200xi8>)
      outs(%dst : memref<300x200xi64>) {
    ^bb0(%s : i32, %u : i16, %v : i8, %unused_out : i64):
      // Payload op order matters: the lowered body is matched in this order.
      %i = linalg.index 0 : index
      %i_as_i16 = arith.index_cast %i : index to i16
      %u_scaled = arith.muli %u, %i_as_i16 : i16
      %j = linalg.index 1 : index
      %j_as_i8 = arith.index_cast %j : index to i8
      %v_scaled = arith.muli %v, %j_as_i8 : i8
      %u_wide = arith.extsi %u_scaled : i16 to i32
      %v_wide = arith.extsi %v_scaled : i8 to i32
      %partial = arith.addi %u_wide, %v_wide : i32
      %total = arith.addi %partial, %s : i32
      %result = arith.extsi %total : i32 to i64
      linalg.yield %result : i64
    }
  return
}
// CHECK-LABEL: func @indexed_generic
//  CHECK-SAME:     %[[SRC:.+]]: memref<200x300xi32>
//  CHECK-SAME:     %[[VEC_D1:.+]]: memref<300xi16>
//  CHECK-SAME:     %[[VEC_D0:.+]]: memref<200xi8>
//  CHECK-SAME:     %[[DST:.+]]: memref<300x200xi64>
//   CHECK-DAG:   %[[LB:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[STEP:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[UB0:.+]] = arith.constant 200 : index
//   CHECK-DAG:   %[[UB1:.+]] = arith.constant 300 : index
//       CHECK:   scf.for %[[D0:[a-zA-Z0-9]+]] = %[[LB]] to %[[UB0]] step %[[STEP]]
//       CHECK:     scf.for %[[D1:[a-zA-Z0-9]+]] = %[[LB]] to %[[UB1]] step %[[STEP]]
//   CHECK-DAG:       %[[S:.+]] = memref.load %[[SRC]][%[[D0]], %[[D1]]]
//   CHECK-DAG:       %[[U:.+]] = memref.load %[[VEC_D1]][%[[D1]]]
//   CHECK-DAG:       %[[V:.+]] = memref.load %[[VEC_D0]][%[[D0]]]
//       CHECK:       %[[D0_I16:.+]] = arith.index_cast %[[D0]]
//       CHECK:       %[[U_SCALED:.+]] = arith.muli %[[U]], %[[D0_I16]]
//       CHECK:       %[[D1_I8:.+]] = arith.index_cast %[[D1]]
//       CHECK:       %[[V_SCALED:.+]] = arith.muli %[[V]], %[[D1_I8]]
//       CHECK:       %[[U_WIDE:.+]] = arith.extsi %[[U_SCALED]]
//       CHECK:       %[[V_WIDE:.+]] = arith.extsi %[[V_SCALED]]
//       CHECK:       %[[PARTIAL:.+]] = arith.addi %[[U_WIDE]], %[[V_WIDE]]
//       CHECK:       %[[TOTAL:.+]] = arith.addi %[[PARTIAL]], %[[S]]
//       CHECK:       %[[RESULT:.+]] = arith.extsi %[[TOTAL]]
//       CHECK:       memref.store %[[RESULT]], %[[DST]][%[[D1]], %[[D0]]]

// -----

// 2-D NHWC/HWCF convolution with strides [1, 2] and dilations [3, 4]: the input
// row/column indices are expected as affine.apply ops that combine the output
// position (scaled by the stride) with the filter offset (scaled by the dilation).
func.func @conv_strides_and_dilation(%input : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>,
  %output : memref<?x?x?x?xf32>) {
  linalg.conv_2d_nhwc_hwcf {
      strides = dense<[1, 2]> : tensor<2xi64>,
      dilations = dense<[3, 4]> : tensor<2xi64>}
      ins(%input, %filter : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
      outs(%output : memref<?x?x?x?xf32>)
  return
}
//  CHECK-DAG:  #[[ROW_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
//  CHECK-DAG:  #[[COL_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
//       CHECK: func @conv_strides_and_dilation(
//  CHECK-SAME:     %[[IMG:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[FLT:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[RES:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//   CHECK-DAG:   %[[IDX0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[IDX1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[IDX2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[IDX3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[BATCH:.+]] = memref.dim %[[IMG]], %[[IDX0]]
//   CHECK-DAG:   %[[CHANS:.+]] = memref.dim %[[IMG]], %[[IDX3]]
//   CHECK-DAG:   %[[KH:.+]] = memref.dim %[[FLT]], %[[IDX0]]
//   CHECK-DAG:   %[[KW:.+]] = memref.dim %[[FLT]], %[[IDX1]]
//   CHECK-DAG:   %[[FOUT:.+]] = memref.dim %[[FLT]], %[[IDX3]]
//   CHECK-DAG:   %[[OH:.+]] = memref.dim %[[RES]], %[[IDX1]]
//   CHECK-DAG:   %[[OW:.+]] = memref.dim %[[RES]], %[[IDX2]]
//       CHECK:   scf.for %[[N_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[BATCH]] step %[[IDX1]]
//       CHECK:     scf.for %[[OH_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[OH]] step %[[IDX1]]
//       CHECK:       scf.for %[[OW_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[OW]] step %[[IDX1]]
//       CHECK:         scf.for %[[F_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[FOUT]] step %[[IDX1]]
//       CHECK:           scf.for %[[KH_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[KH]] step %[[IDX1]]
//       CHECK:             scf.for %[[KW_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[KW]] step %[[IDX1]]
//       CHECK:               scf.for %[[C_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[CHANS]] step %[[IDX1]]
//   CHECK-DAG:                 %[[ROW:.+]] = affine.apply #[[ROW_MAP]](%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]], %[[KH_IV]], %[[KW_IV]], %[[C_IV]])
//   CHECK-DAG:                 %[[COL:.+]] = affine.apply #[[COL_MAP]](%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]], %[[KH_IV]], %[[KW_IV]], %[[C_IV]])
//   CHECK-DAG:                 %[[IMG_ELT:.+]] = memref.load %[[IMG]][%[[N_IV]], %[[ROW]], %[[COL]], %[[C_IV]]]
//   CHECK-DAG:                 %[[FLT_ELT:.+]] = memref.load %[[FLT]][%[[KH_IV]], %[[KW_IV]], %[[C_IV]], %[[F_IV]]]
//   CHECK-DAG:                 %[[RES_ELT:.+]] = memref.load %[[RES]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]]]
//       CHECK:                 %[[PROD:.+]] = arith.mulf %[[IMG_ELT]], %[[FLT_ELT]]
//       CHECK:                 %[[SUM:.+]] = arith.addf %[[RES_ELT]], %[[PROD]]
//       CHECK:                 memref.store %[[SUM]], %[[RES]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]]]

// -----

// Max pooling with strides [1, 2] and dilations [3, 4]. The window operand is
// only queried for its extents, so the only loads expected are of the input
// and the output element.
func.func @pool_strides_and_dilation(%input : memref<?x?x?x?xf32>, %window : memref<?x?xf32>,
  %output : memref<?x?x?x?xf32>) {
  linalg.pooling_nhwc_max {
      strides = dense<[1, 2]> : tensor<2xi64>,
      dilations = dense<[3, 4]> : tensor<2xi64>}
      ins(%input, %window : memref<?x?x?x?xf32>, memref<?x?xf32>)
      outs(%output : memref<?x?x?x?xf32>)
  return
}
//  CHECK-DAG:  #[[ROW_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
//  CHECK-DAG:  #[[COL_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
//       CHECK: func @pool_strides_and_dilation
//  CHECK-SAME:     %[[IMG:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[WIN:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[RES:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//   CHECK-DAG:   %[[IDX0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[IDX1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[IDX2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[IDX3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[BATCH:.+]] = memref.dim %[[IMG]], %[[IDX0]]
//   CHECK-DAG:   %[[CHANS:.+]] = memref.dim %[[IMG]], %[[IDX3]]
//   CHECK-DAG:   %[[KH:.+]] = memref.dim %[[WIN]], %[[IDX0]]
//   CHECK-DAG:   %[[KW:.+]] = memref.dim %[[WIN]], %[[IDX1]]
//   CHECK-DAG:   %[[OH:.+]] = memref.dim %[[RES]], %[[IDX1]]
//   CHECK-DAG:   %[[OW:.+]] = memref.dim %[[RES]], %[[IDX2]]
//       CHECK:   scf.for %[[N_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[BATCH]] step %[[IDX1]]
//       CHECK:     scf.for %[[OH_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[OH]] step %[[IDX1]]
//       CHECK:       scf.for %[[OW_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[OW]] step %[[IDX1]]
//       CHECK:         scf.for %[[C_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[CHANS]] step %[[IDX1]]
//       CHECK:           scf.for %[[KH_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[KH]] step %[[IDX1]]
//       CHECK:             scf.for %[[KW_IV:[a-zA-Z0-9]+]] = %[[IDX0]] to %[[KW]] step %[[IDX1]]
//   CHECK-DAG:               %[[ROW:.+]] = affine.apply #[[ROW_MAP]](%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]], %[[KH_IV]], %[[KW_IV]])
//   CHECK-DAG:               %[[COL:.+]] = affine.apply #[[COL_MAP]](%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]], %[[KH_IV]], %[[KW_IV]])
//   CHECK-DAG:               %[[IMG_ELT:.+]] = memref.load %[[IMG]][%[[N_IV]], %[[ROW]], %[[COL]], %[[C_IV]]]
//   CHECK-DAG:               %[[RES_ELT:.+]] = memref.load %[[RES]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]]]
//       CHECK:               %[[MAX:.+]] = arith.maxf %[[RES_ELT]], %[[IMG_ELT]]
//       CHECK:               memref.store %[[MAX]], %[[RES]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]]]

// -----

// Elementwise linalg.map: a single loop over the 64 elements that adds the two
// inputs and stores the sum (the output is only written, never loaded).
func.func @map(%x: memref<64xf32>,
    %y: memref<64xf32>, %z: memref<64xf32>) {
  linalg.map ins(%x, %y : memref<64xf32>, memref<64xf32>)
             outs(%z : memref<64xf32>)
    (%xi: f32, %yi: f32) {
      %s = arith.addf %xi, %yi : f32
      linalg.yield %s : f32
    }
  return
}
// CHECK-LABEL: func.func @map(
// CHECK-SAME:    %[[X:[a-zA-Z0-9]+]]: memref<64xf32>,
// CHECK-SAME:    %[[Y:[a-zA-Z0-9]+]]: memref<64xf32>,
// CHECK-SAME:    %[[Z:[a-zA-Z0-9]+]]: memref<64xf32>) {

// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK:       %[[X_ELT:.*]] = memref.load %[[X]][%[[IV]]]
// CHECK:       %[[Y_ELT:.*]] = memref.load %[[Y]][%[[IV]]]
// CHECK:       %[[SUM:.*]] = arith.addf %[[X_ELT]], %[[Y_ELT]]
// CHECK:       memref.store %[[SUM]], %[[Z]][%[[IV]]]

// -----

// Transpose 16x32x64 -> 32x64x16 with permutation [1, 2, 0].
// linalg.transpose indexes its init (output) operand with the identity map and
// its input with the inverse permutation. The lowered loop nest therefore walks
// the output shape (32, 64, 16), and out[i][j][k] is read from in[k][i][j].
func.func @transpose(%arg0: memref<16x32x64xf32>,
                     %arg1: memref<32x64x16xf32>) {
  linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
                   outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
  return
}
// CHECK-LABEL: func.func @transpose(
// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[I:[a-zA-Z0-9]+]] = %[[C0]] to %[[C32]] step %[[C1]] {
// CHECK:       scf.for %[[J:[a-zA-Z0-9]+]] = %[[C0]] to %[[C64]] step %[[C1]] {
// CHECK:         scf.for %[[K:[a-zA-Z0-9]+]] = %[[C0]] to %[[C16]] step %[[C1]] {
// CHECK:           %[[ELEM:.*]] = memref.load %[[IN]][%[[K]], %[[I]], %[[J]]]
// CHECK:           memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]

// -----

// Sum-reduction over dimension 1: the loop nest spans the full 16x32x64 input,
// and each step accumulates into the 16x64 output element with dim 1 dropped.
func.func @reduce(%src: memref<16x32x64xf32>,
                  %dst: memref<16x64xf32>) {
  linalg.reduce ins(%src : memref<16x32x64xf32>)
                outs(%dst : memref<16x64xf32>) dimensions = [1]
    (%elem: f32, %partial: f32) {
      %next = arith.addf %elem, %partial : f32
      linalg.yield %next : f32
    }
  return
}
// CHECK-LABEL: func.func @reduce(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<16x64xf32>

// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[E16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[E32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[E64:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[D0:.*]] = %[[LB]] to %[[E16]] step %[[STEP]] {
// CHECK:       scf.for %[[D1:.*]] = %[[LB]] to %[[E32]] step %[[STEP]] {
// CHECK:         scf.for %[[D2:.*]] = %[[LB]] to %[[E64]] step %[[STEP]] {
// CHECK:           %[[SRC_ELT:.*]] = memref.load %[[SRC]][%[[D0]], %[[D1]], %[[D2]]]
// CHECK:           %[[DST_ELT:.*]] = memref.load %[[DST]][%[[D0]], %[[D2]]]
// CHECK:           %[[SUM:.*]] = arith.addf %[[SRC_ELT]], %[[DST_ELT]]
// CHECK:           memref.store %[[SUM]], %[[DST]][%[[D0]], %[[D2]]]

// -----

// Broadcast along dimension 1: iterate over the 8x16x32 output and read the
// 8x32 input with the broadcast index dropped.
func.func @broadcast(%src: memref<8x32xf32>,
                     %dst: memref<8x16x32xf32>) {
  linalg.broadcast
      ins(%src : memref<8x32xf32>)
      outs(%dst : memref<8x16x32xf32>)
      dimensions = [1]
  return
}
// CHECK-LABEL: func.func @broadcast(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<8x32xf32>,
// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<8x16x32xf32>

// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[E8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[E16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[E32:.*]] = arith.constant 32 : index

// CHECK:     scf.for %[[D0:.*]] = %[[LB]] to %[[E8]] step %[[STEP]] {
// CHECK:       scf.for %[[D1:.*]] = %[[LB]] to %[[E16]] step %[[STEP]] {
// CHECK:         scf.for %[[D2:.*]] = %[[LB]] to %[[E32]] step %[[STEP]] {
// CHECK:           %[[VAL:.*]] = memref.load %[[SRC]][%[[D0]], %[[D2]]]
// CHECK:           memref.store %[[VAL]], %[[DST]][%[[D0]], %[[D1]], %[[D2]]]