1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
|
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
// Matrix-vector contraction: y(m) += A(m, k) * x(k), reducing over k.
#matvec_accesses = [
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Row-major matrix-matrix contraction: C(m, n) += A(m, k) * B(k, n),
// reducing over k.
#matmat_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Same row-major layout as #matmat_accesses, kept under its own alias for the
// single-step (K = 1) matmul_0 tests.
#matmat_accesses_0 = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d2, d1)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]
#matmat_trait_0 = {
  indexing_maps = #matmat_accesses_0,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// Masked mat-vec contraction. The 2x3 mask is expected to be transposed to
// 3x2, and one 2-element mask row is extracted for each of the three unrolled
// outer products, each of which is wrapped in its own vector.mask.
// CHECK-LABEL: func.func @masked_extract_contract2(
// CHECK-SAME:    %[[MAT:.*]]: vector<2x3xf32>,
// CHECK-SAME:    %[[VEC:.*]]: vector<3xf32>,
// CHECK-SAME:    %[[ACC:.*]]: vector<2xf32>,
// CHECK-SAME:    %[[MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK:         %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK:         %[[M0:.*]] = vector.extract %[[MASK_T]][0] : vector<3x2xi1>
// CHECK:         vector.mask %[[M0]] { vector.outerproduct
// CHECK:         %[[M1:.*]] = vector.extract %[[MASK_T]][1] : vector<3x2xi1>
// CHECK:         vector.mask %[[M1]] { vector.outerproduct
// CHECK:         %[[M2:.*]] = vector.extract %[[MASK_T]][2] : vector<3x2xi1>
// CHECK:         vector.mask %[[M2]] { vector.outerproduct
func.func @masked_extract_contract2(%mat: vector<2x3xf32>, %vec: vector<3xf32>,
                                    %acc: vector<2xf32>,
                                    %mask: vector<2x3xi1>) -> vector<2xf32> {
  %res = vector.mask %mask {
    vector.contract #matvec_trait %mat, %vec, %acc
      : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  } : vector<2x3xi1> -> vector<2xf32>
  return %res : vector<2xf32>
}
// Masked row-major matmul with K = 5. The 3x7x5 mask is expected to be
// permuted with [2, 0, 1] so K leads, and each of the five unrolled outer
// products is masked with the matching 3x7 slice.
// CHECK-LABEL: func.func @masked_extract_contract4(
// CHECK-SAME:    %[[LHS:.*]]: vector<3x5xf32>,
// CHECK-SAME:    %[[RHS:.*]]: vector<5x7xf32>,
// CHECK-SAME:    %[[ACC:.*]]: vector<3x7xf32>,
// CHECK-SAME:    %[[MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
// CHECK:         %[[MASK_T:.*]] = vector.transpose %[[MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
// CHECK:         %[[M0:.*]] = vector.extract %[[MASK_T]][0] : vector<5x3x7xi1>
// CHECK:         %[[P0:.*]] = vector.mask %[[M0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:         %[[M1:.*]] = vector.extract %[[MASK_T]][1] : vector<5x3x7xi1>
// CHECK:         %[[P1:.*]] = vector.mask %[[M1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:         %[[M2:.*]] = vector.extract %[[MASK_T]][2] : vector<5x3x7xi1>
// CHECK:         %[[P2:.*]] = vector.mask %[[M2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:         %[[M3:.*]] = vector.extract %[[MASK_T]][3] : vector<5x3x7xi1>
// CHECK:         %[[P3:.*]] = vector.mask %[[M3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:         %[[M4:.*]] = vector.extract %[[MASK_T]][4] : vector<5x3x7xi1>
// CHECK:         %[[P4:.*]] = vector.mask %[[M4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
func.func @masked_extract_contract4(%lhs: vector<3x5xf32>, %rhs: vector<5x7xf32>,
                                    %acc: vector<3x7xf32>,
                                    %mask: vector<3x7x5xi1>) -> vector<3x7xf32> {
  %res = vector.mask %mask {
    vector.contract #matmat_trait %lhs, %rhs, %acc
      : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32>
  } : vector<3x7x5xi1> -> vector<3x7xf32>
  return %res : vector<3x7xf32>
}
// Row-major matmul with K = 4. A is transposed once, and the contraction
// unrolls into four chained outer products. Each product consumes row i of A^T
// and row i of B and accumulates into the previous result.
// The label is anchored on "(" because "@matmul" is also a prefix of every
// later @matmul_* test in this file.
// CHECK-LABEL: func @matmul(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
//
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
//
// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32>
// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
//
// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32>
// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
//
// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32>
// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
//
// CHECK: return %[[c3]] : vector<2x3xf32>
func.func @matmul(%lhs: vector<2x4xf32>, %rhs: vector<4x3xf32>,
                  %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait %lhs, %rhs, %acc
    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// Single reduction step (K = 1). The expected output is one transpose of A,
// one extract per operand and a single outer product into the accumulator.
// The label is anchored on "(" so it cannot also match the @matmul_0_mixed
// test, whose name shares this prefix.
// CHECK-LABEL: func @matmul_0(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
func.func @matmul_0(%lhs: vector<2x1xf32>, %rhs: vector<1x3xf32>,
                    %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_0 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// Mixed-precision variant of the K = 1 matmul. The f16 operand rows are
// expected to be widened with arith.extf to the f32 accumulator type before
// the outer product.
// CHECK-LABEL: func @matmul_0_mixed
// CHECK-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf16>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// CHECK:         %[[LHS_ROW:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf16>
// CHECK:         %[[RHS_ROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf16>
// CHECK:         %[[LHS_EXT:.*]] = arith.extf %[[LHS_ROW]] : vector<2xf16> to vector<2xf32>
// CHECK:         %[[RHS_EXT:.*]] = arith.extf %[[RHS_ROW]] : vector<3xf16> to vector<3xf32>
// CHECK:         %[[OUT:.*]] = vector.outerproduct %[[LHS_EXT]], %[[RHS_EXT]], %[[ACC]]
// CHECK:         return %[[OUT]] : vector<2x3xf32>
func.func @matmul_0_mixed(%lhs: vector<2x1xf16>, %rhs: vector<1x3xf16>,
                          %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_0 %lhs, %rhs, %acc
    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// Both operands carry K as their trailing dimension:
//   C(m, n) += A(m, k) * B(n, k)
// The expected lowering transposes A and then B before the outer product.
#matmat_accesses_1 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_1 = {
  indexing_maps = #matmat_accesses_1,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_1
// CHECK-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// CHECK:         %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0]
// CHECK:         %[[LHS_ROW:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// CHECK:         %[[RHS_ROW:.*]] = vector.extract %[[RHS_T]][0] : vector<1x3xf32>
// CHECK:         %[[OUT:.*]] = vector.outerproduct %[[LHS_ROW]], %[[RHS_ROW]], %[[ACC]]
// CHECK:         return %[[OUT]] : vector<2x3xf32>
func.func @matmul_1(%lhs: vector<2x1xf32>, %rhs: vector<3x1xf32>,
                    %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_1 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// Both operands are K-major and the result is row-major:
//   C(m, n) += A(k, m) * B(k, n)
// The expected output extracts the leading row of each operand directly.
#matmat_accesses_2 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_2 = {
  indexing_maps = #matmat_accesses_2,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_2
// CHECK-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK:         %[[LHS_ROW:.*]] = vector.extract %[[LHS]][0] : vector<1x2xf32>
// CHECK:         %[[RHS_ROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// CHECK:         %[[OUT:.*]] = vector.outerproduct %[[LHS_ROW]], %[[RHS_ROW]], %[[ACC]]
// CHECK:         return %[[OUT]] : vector<2x3xf32>
func.func @matmul_2(%lhs: vector<1x2xf32>, %rhs: vector<1x3xf32>,
                    %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_2 %lhs, %rhs, %acc
    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// K-major LHS, K-trailing RHS, row-major result:
//   C(m, n) += A(k, m) * B(n, k)
// Only B is expected to be transposed.
#matmat_accesses_3 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_3 = {
  indexing_maps = #matmat_accesses_3,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_3
// CHECK-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// CHECK:         %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0]
// CHECK:         %[[LHS_ROW:.*]] = vector.extract %[[LHS]][0] : vector<1x2xf32>
// CHECK:         %[[RHS_ROW:.*]] = vector.extract %[[RHS_T]][0] : vector<1x3xf32>
// CHECK:         %[[OUT:.*]] = vector.outerproduct %[[LHS_ROW]], %[[RHS_ROW]], %[[ACC]]
// CHECK:         return %[[OUT]] : vector<2x3xf32>
func.func @matmul_3(%lhs: vector<1x2xf32>, %rhs: vector<3x1xf32>,
                    %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_3 %lhs, %rhs, %acc
    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
// Row-major operands with a transposed (n x m) result:
//   C(n, m) += A(m, k) * B(k, n)
// A is expected to be transposed, and the outer product takes the B row first
// so that it yields an n x m tile.
#matmat_accesses_4 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_4 = {
  indexing_maps = #matmat_accesses_4,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_4
// CHECK-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x2xf32>
// CHECK:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// CHECK:         %[[RHS_ROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// CHECK:         %[[LHS_ROW:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// CHECK:         %[[OUT:.*]] = vector.outerproduct %[[RHS_ROW]], %[[LHS_ROW]], %[[ACC]]
// CHECK:         return %[[OUT]] : vector<3x2xf32>
func.func @matmul_4(%lhs: vector<2x1xf32>, %rhs: vector<1x3xf32>,
                    %acc: vector<3x2xf32>) -> vector<3x2xf32> {
  %res = vector.contract #matmat_trait_4 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %res : vector<3x2xf32>
}
// Transposed (n x m) result with both operands K-major:
//   C(n, m) += A(k, m) * B(k, n)
// The expected output extracts the leading row of each operand directly,
// without a transpose, and feeds the B row first to the outer product.
// (Previously this test duplicated the @matmul_4 layout exactly.)
#matmat_accesses_5 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_5 = {
  indexing_maps = #matmat_accesses_5,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_5
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// CHECK-DAG: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x2xf32>
func.func @matmul_5(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
// Transposed (n x m) result with K as the trailing dimension of both operands:
//   C(n, m) += A(m, k) * B(n, k)
// Both operands are expected to be transposed, in either order, and the B row
// is fed first to the outer product.
// (Previously this test duplicated the @matmul_4 layout exactly.)
#matmat_accesses_6 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_6 = {
  indexing_maps = #matmat_accesses_6,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_6
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// CHECK-DAG: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-DAG: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// CHECK-DAG: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x2xf32>
func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<3x1xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
// Transposed (n x m) result with a K-major LHS and a K-trailing RHS:
//   C(n, m) += A(k, m) * B(n, k)
// Only B is expected to be transposed, and the B row is fed first to the
// outer product.
// (Previously this test duplicated the @matmul_4 layout exactly.)
#matmat_accesses_7 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_7 = {
  indexing_maps = #matmat_accesses_7,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @matmul_7
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK-DAG: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
// CHECK-DAG: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x2xf32>
func.func @matmul_7(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<3x1xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
// Driver for the transform-dialect interpreter. It applies the vector
// contraction-lowering patterns with the "outerproduct" strategy to every
// func.func in the module.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  %funcs = transform.structured.match ops{["func.func"]} in %root
    : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %funcs {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  } : !transform.any_op
}
|