# Exercises the Python builders for the vector transform-dialect pattern ops
# and verifies the printed IR with FileCheck.
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector
def run_apply_patterns(f):
    """Run ``f`` inside a freshly built apply-patterns op and print the IR.

    A new module is created that holds a transform sequence op; the body of
    that sequence holds a single ``ApplyPatternsOp`` built on the sequence's
    body target.  ``f`` is called with the insertion point set to the
    ``patterns`` region of that op.  Afterwards a ``TEST: <function name>``
    banner and the whole module are printed to stdout.

    Intended for use as a decorator: ``f`` is executed immediately, at
    decoration time, and is returned unchanged.
    """
    with Context():
        with Location.unknown():
            ir_module = Module.create()
            with InsertionPoint(ir_module.body):
                failure_mode = transform.FailurePropagationMode.Propagate
                root_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(failure_mode, [], root_type)
                with InsertionPoint(seq_op.body):
                    patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(patterns_op.patterns):
                        f()
                    # Emitted after the apply-patterns op, still inside the
                    # sequence body.
                    transform.YieldOp()
            # The banner gives each test's output a distinct marker line.
            print("\nTEST:", f.__name__)
            print(ir_module)
    return f
@run_apply_patterns
def non_configurable_patterns():
    """Build each vector pattern op that is constructed without arguments.

    The ops are created one after another; the directive comment above each
    call names the op expected in the printed IR, so the order of the calls
    and of the comments must stay in sync.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    vector.ApplyCastAwayVectorLeadingOneDimPatternsOp()
    # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
    vector.ApplyRankReducingSubviewPatternsOp()
    # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
    vector.ApplyTransferPermutationPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_broadcast
    vector.ApplyLowerBroadcastPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masks
    vector.ApplyLowerMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masked_transfers
    vector.ApplyLowerMaskedTransfersPatternsOp()
    # CHECK: transform.apply_patterns.vector.materialize_masks
    vector.ApplyMaterializeMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_outerproduct
    vector.ApplyLowerOuterProductPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_gather
    vector.ApplyLowerGatherPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_scan
    vector.ApplyLowerScanPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_shape_cast
    vector.ApplyLowerShapeCastPatternsOp()
@run_apply_patterns
def configurable_patterns():
    """Build the vector pattern ops that take integer/boolean keyword options.

    Each call passes explicit option values, and the directive comments above
    it list the same values as they are expected to appear in the printed IR.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    vector.ApplyLowerTransferPatternsOp(max_transfer_rank=4)
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    vector.ApplyTransferToScfPatternsOp(max_transfer_rank=3, full_unroll=True)
@run_apply_patterns
def enum_configurable_patterns():
    """Build the vector pattern ops whose options are enum-valued.

    For each op the no-argument form is built first, followed by one call per
    explicitly passed enum value.  The directive comments above each call
    give the op name expected in the printed IR and, where the source marks
    the value as printed, the expected textual spelling of the enum.

    The function opens with the same label and ``apply_patterns`` anchor
    directives as the other tests in this file, so that its expectations are
    matched only against the output printed under its own ``TEST:`` banner.
    """
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )
|