// Lit test for the scf-parallel-loop-fusion pass on scf.parallel loops.
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(scf-parallel-loop-fusion))' -split-input-file | FileCheck %s
// Two adjacent 2-D scf.parallel loops with identical bounds (0 to 2, step 1
// in both dimensions) and empty bodies. The FileCheck directives after this
// function expect exactly one scf.parallel to remain in the output.
func.func @fuse_empty_loops() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// First loop: body contains only the terminator.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
// Second loop: same iteration space, also empty.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
return
}
// CHECK-LABEL: func @fuse_empty_loops
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: scf.yield
// CHECK: }
// CHECK-NOT: scf.parallel
// -----
// Producer/consumer pair over the same 2x2 iteration space.
// The first loop stores B[i, j] + C[i, j] into the temporary %sum; the second
// loop reads %sum at the same [i, j] and stores %sum[i, j] * A[i, j] into
// %result. The FileCheck directives after this function expect both bodies,
// in order, inside a single scf.parallel.
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Temporary buffer written by the first loop and read by the second.
%sum = memref.alloc() : memref<2x2xf32>
// Loop 1: sum[i, j] = B[i, j] + C[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.yield
}
// Loop 2: result[i, j] = sum[i, j] * A[i, j]; reads %sum with the same
// indices loop 1 used to write it.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.yield
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
// -----
// Chain of three loops over the same 100x10 iteration space, linked through
// two temporaries:
//   1. broadcast_rhs[i, j] = rhs[i]
//   2. diff[i, j]          = lhs[i, j] - broadcast_rhs[i, j]
//   3. result[i, j]        = exp(diff[i, j])
// The FileCheck directives after this function expect all three bodies, in
// order, inside a single scf.parallel.
func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
%result: memref<100x10xf32>) {
%c100 = arith.constant 100 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Temporaries connecting loop 1 -> loop 2 and loop 2 -> loop 3.
%broadcast_rhs = memref.alloc() : memref<100x10xf32>
%diff = memref.alloc() : memref<100x10xf32>
// Loop 1: copy the 1-D rhs value for row i into every column j.
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%rhs_elem = memref.load %rhs[%i] : memref<100xf32>
memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
scf.yield
}
// Loop 2: elementwise subtraction, reading %broadcast_rhs at [i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
%broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
%diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
scf.yield
}
// Loop 3: elementwise exponential of %diff, written to %result.
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
%exp_elem = math.exp %diff_elem : f32
memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
scf.yield
}
memref.dealloc %broadcast_rhs : memref<100x10xf32>
memref.dealloc %diff : memref<100x10xf32>
return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
// CHECK: [[C100:%.*]] = arith.constant 100 : index
// CHECK: [[C10:%.*]] = arith.constant 10 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc()
// CHECK: [[DIFF:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// CHECK: memref.dealloc [[BROADCAST_RHS]]
// CHECK: memref.dealloc [[DIFF]]
// -----
// Negative case: the FIRST of two adjacent loops contains a nested
// scf.parallel. The FileCheck directives after this function expect three
// scf.parallel ops to remain (outer, nested, and the second outer loop).
func.func @do_not_fuse_nested_ploop1() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// First loop: its body holds another scf.parallel.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
scf.yield
}
// Second loop: same bounds, empty body.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop1
// CHECK: scf.parallel
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: mirror of the previous nested-loop test, with the nested
// scf.parallel in the SECOND of the two adjacent loops. The FileCheck
// directives after this function expect three scf.parallel ops to remain.
func.func @do_not_fuse_nested_ploop2() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// First loop: empty body.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
// Second loop: its body holds another scf.parallel.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop2
// CHECK: scf.parallel
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: the two adjacent loops have a different number of induction
// variables (2-D followed by 1-D). The FileCheck directives after this
// function expect two scf.parallel ops to remain.
func.func @do_not_fuse_loops_unmatching_num_loops() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// 2-D loop over (%i, %j).
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
// 1-D loop over (%i) only.
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: two loops with identical bounds and empty bodies are
// separated by a memref.alloc (the "side-effecting op" of the test name).
// The FileCheck directives after this function expect two scf.parallel ops
// to remain.
func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
// Allocation placed between the two loops; %buffer is otherwise unused.
%buffer = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: both loops are 2-D and start at 0, but their upper bounds
// and steps differ (0 to 4 step 2 versus 0 to 2 step 1). The FileCheck
// directives after this function expect two scf.parallel ops to remain.
func.func @do_not_fuse_loops_unmatching_iteration_space() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
// Upper bound 4, step 2 in both dimensions.
scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
scf.yield
}
// Upper bound 2, step 1 in both dimensions.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: the first loop WRITES %common_buf at [i, j], while the
// second loop READS it at [i + 1, j], so the read indices do not match the
// write indices. The FileCheck directives after this function expect two
// scf.parallel ops to remain.
func.func @do_not_fuse_unmatching_write_read_patterns(
%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%common_buf = memref.alloc() : memref<2x2xf32>
// Loop 1: common_buf[i, j] = B[i, j] + C[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
scf.yield
}
// Loop 2: reads %common_buf at a row index shifted by one.
// NOTE(review): for i = 1 the shifted index is 2 on a 2x2 buffer; this IR is
// only fed through the pass by this test, not executed.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%k = arith.addi %i, %c1 : index
%sum_elem = memref.load %common_buf[%k, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.yield
}
memref.dealloc %common_buf : memref<2x2xf32>
return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: the first loop READS %common_buf at [i, j], while the
// second loop WRITES it at the transposed position [j, i]. The second loop
// also reads %sum at [i + 1, j], whereas the first loop wrote it at [i, j].
// The FileCheck directives after this function expect two scf.parallel ops
// to remain.
func.func @do_not_fuse_unmatching_read_write_patterns(
%A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%sum = memref.alloc() : memref<2x2xf32>
// Loop 1: sum[i, j] = B[i, j] + common_buf[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %common_buf[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.yield
}
// Loop 2: common_buf[j, i] = sum[i + 1, j] * A[i, j]; note the swapped
// store indices into the buffer that loop 1 read from.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%k = arith.addi %i, %c1 : index
%sum_elem = memref.load %sum[%k, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
scf.yield
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Negative case: the second loop loads from a memref value (%A) that is
// defined inside its own body as a subview of %buffer, rather than outside
// the loops. The FileCheck directives after this function expect two
// scf.parallel ops to remain.
func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%buffer = memref.alloc() : memref<2x2xf32>
// First loop: empty body.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.yield
}
// Second loop: builds a dynamically-shaped strided subview of %buffer on
// each iteration and loads one element from it; the loaded value is unused.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
: memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
scf.yield
}
return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
// CHECK: scf.parallel
// CHECK: scf.parallel
// -----
// Same producer/consumer pair as in @fuse_two, but both 2-D loops sit inside
// the body of an outer 1-D scf.parallel over %k. The FileCheck directives
// after this function expect the outer loop plus a single inner scf.parallel
// containing both bodies in order.
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Temporary shared by the two inner loops; allocated outside the outer loop.
%sum = memref.alloc() : memref<2x2xf32>
// Outer loop: %k is not used by either inner loop.
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
// Inner loop 1: sum[i, j] = B[i, j] + C[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.yield
}
// Inner loop 2: result[i, j] = sum[i, j] * A[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.yield
}
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[SUM]]