1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON
// This implements a 2D multisize tiling with target sizes [3, 10].
// The op is first split along dimension 0 at the point computed by
// multitile_sizes and each part is tiled with its own size; the resulting
// tiled ops are then split and tiled the same way along dimension 1.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  %generic = transform.structured.match ops{["linalg.generic"]} in %root : (!transform.any_op) -> !transform.any_op
  // As used below: #0 tiles the part before the split point, #1 tiles the
  // part after it, and #2 is the split point itself.
  %rows:3 = transform.structured.multitile_sizes %generic { dimension = 0, target_size = 3 } : (!transform.any_op) -> !transform.any_op
  %cols:3 = transform.structured.multitile_sizes %generic { dimension = 1, target_size = 10 } : (!transform.any_op) -> !transform.any_op
  // Dimension 0: split, then tile each part with its size.
  %row_parts:2 = transform.structured.split %generic after %rows#2 { dimension = 0 } : !transform.any_op, !transform.any_op
  %row_lo:2 = transform.structured.tile %row_parts#0 [%rows#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
  %row_hi:2 = transform.structured.tile %row_parts#1 [%rows#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
  %row_tiles = transform.merge_handles %row_lo#0, %row_hi#0 : !transform.any_op
  // Repeat the column-size handles once per op held by %row_tiles.
  %cols_rep:3 = transform.replicate num(%row_tiles) %cols#0, %cols#1, %cols#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
  // Dimension 1: split every row tile, then tile each part (dim 0 untiled).
  %col_parts:2 = transform.structured.split %row_tiles after %cols_rep#2 { dimension = 1 } : !transform.any_op, !transform.any_op
  transform.structured.tile %col_parts#0 [0, %cols_rep#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.structured.tile %col_parts#1 [0, %cols_rep#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
// External callee used by the linalg.generic body below: it receives the
// input element and the two linalg.index values of the current iteration.
func.func private @elem(%elt: f32, %i: index, %j: index) -> f32
// Without canonicalization, tile sizes are computed dynamically as affine maps.
// NOCANON-LABEL: @two_d
// NOCANON-COUNT-8: affine.apply
// NOCANON: scf.for
// CHECK-LABEL: @two_d
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
// Payload: an element-wise generic over a 10x34 tensor whose body forwards
// each input element and its (i, j) iteration indices to @elem.
func.func @two_d(%arg0: tensor<10x34xf32>,
%arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
%0 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]
}
ins(%arg0: tensor<10x34xf32>)
outs(%arg1: tensor<10x34xf32>) {
^bb0(%0: f32, %1: f32):
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
linalg.yield %call_res : f32
} -> tensor<10x34xf32>
// 2D multi-size tiling should produce four quadrants with sizes
// (2, 8), (2, 9), (3, 8), (3, 9)
// respectively, and in this order.
// Per the slices matched below: the 10 rows split into 4 rows tiled by 2
// followed by 6 rows tiled by 3, and the 34 columns split into 16 columns
// tiled by 8 followed by 18 columns tiled by 9.
// Check the full code for the first quadrant, the data flow for the second
// quadrant and only the overall code structure for the remaining quadrants.
// The canonicalizer is able to recover static shapes for the linalg.generic
// instances; use those to differentiate the quadrants.
// CHECK: %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
// CHECK: %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
// CHECK: %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
// CHECK: %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
// CHECK: %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
// CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
// CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
// CHECK: scf.yield %[[RESPARTIAL]]
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
// CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
// CHECK: %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
// CHECK: scf.yield %[[INSERTED_3]]
// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK-COUNT-2: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[RESULT:.+]] = tensor.insert_slice
// CHECK: return %[[RESULT]]
return %0 : tensor<10x34xf32>
}
// -----
// Same 2D multisize tiling as the first script, but multitile_sizes returns
// its results as !transform.param<i64> values instead of op handles.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  %generic = transform.structured.match ops{["linalg.generic"]} in %root : (!transform.any_op) -> !transform.any_op
  // As used below: #0 tiles the part before the split point, #1 tiles the
  // part after it, and #2 is the split point itself.
  %rows:3 = transform.structured.multitile_sizes %generic { dimension = 0, target_size = 3 } : (!transform.any_op) -> !transform.param<i64>
  %cols:3 = transform.structured.multitile_sizes %generic { dimension = 1, target_size = 10 } : (!transform.any_op) -> !transform.param<i64>
  // Dimension 0: split, then tile each part with its size.
  %row_parts:2 = transform.structured.split %generic after %rows#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
  %row_lo:2 = transform.structured.tile %row_parts#0 [%rows#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
  %row_hi:2 = transform.structured.tile %row_parts#1 [%rows#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
  %row_tiles = transform.merge_handles %row_lo#0, %row_hi#0 : !transform.any_op
  // Repeat the column-size parameters once per op held by %row_tiles.
  %cols_rep:3 = transform.replicate num(%row_tiles) %cols#0, %cols#1, %cols#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
  // Dimension 1: split every row tile, then tile each part (dim 0 untiled).
  %col_parts:2 = transform.structured.split %row_tiles after %cols_rep#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
  transform.structured.tile %col_parts#0 [0, %cols_rep#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
  transform.structured.tile %col_parts#1 [0, %cols_rep#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
}
// External callee used by the linalg.generic body below: it receives the
// input element and the two linalg.index values of the current iteration.
func.func private @elem(%elt: f32, %i: index, %j: index) -> f32
// Even without canonicalization, tile sizes can be computed statically thanks
// to parameters.
// The label names the exact function so it cannot silently bind to another
// function whose name merely starts with "@two_d".
// NOCANON-LABEL: @two_d_param
// NOCANON-NOT: affine.apply
// NOCANON: scf.for
// CHECK-LABEL: @two_d_param
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
// Payload: an element-wise generic over a 10x34 tensor whose body forwards
// each input element and its (i, j) iteration indices to @elem.
func.func @two_d_param(%arg0: tensor<10x34xf32>,
%arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
%0 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]
}
ins(%arg0: tensor<10x34xf32>)
outs(%arg1: tensor<10x34xf32>) {
^bb0(%0: f32, %1: f32):
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
linalg.yield %call_res : f32
} -> tensor<10x34xf32>
// The expected structure is the same as with handle-typed sizes: four
// quadrants with tile sizes (2, 8), (2, 9), (3, 8), (3, 9), in this order.
// The first quadrant is matched in full, the second by data flow, and the
// remaining two by overall structure only.
// CHECK: %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
// CHECK: %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
// CHECK: %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
// CHECK: %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
// CHECK: %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
// CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
// CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
// CHECK: scf.yield %[[RESPARTIAL]]
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
// CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
// CHECK: %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
// CHECK: scf.yield %[[INSERTED_3]]
// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK-COUNT-2: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[RESULT:.+]] = tensor.insert_slice
// CHECK: return %[[RESULT]]
return %0 : tensor<10x34xf32>
}
|