//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
// The byte-code is constructed from the PDL Interpreter dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_BYTECODE_H_
#define MLIR_REWRITE_BYTECODE_H_
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
} // namespace pdl_interp
namespace detail {
class PDLByteCode;
/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
/// entries. ByteCodeAddr refers to the size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
using OwningOpRange = llvm::OwningArrayRef<Operation *>;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//
/// All of the data pertaining to a specific pattern within the bytecode.
class PDLByteCodePattern : public Pattern {
public:
  static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
                                   PDLPatternConfigSet *configSet,
                                   ByteCodeAddr rewriterAddr);

  /// Return the bytecode address of the rewriter for this pattern.
  ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }

  /// Return the configuration set for this pattern, or null if there is none.
  PDLPatternConfigSet *getConfigSet() const { return configSet; }

private:
  template <typename... Args>
  PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
                     Args &&...patternArgs)
      : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
        configSet(configSet) {}

  /// The address of the rewriter for this pattern.
  ByteCodeAddr rewriterAddr;

  /// The optional config set for this pattern.
  PDLPatternConfigSet *configSet;
};
//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//
/// This class contains the mutable state of a bytecode instance. This allows
/// for a bytecode instance to be cached and reused across various different
/// threads/drivers.
class PDLByteCodeMutableState {
public:
  /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
  /// to the position of the pattern within the range returned by
  /// `PDLByteCode::getPatterns`.
  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);

  /// Clean up any allocated state after a match/rewrite has been completed.
  /// This method should be called regardless of whether the match+rewrite was
  /// a success or not.
  void cleanupAfterMatchAndRewrite();

private:
  /// Allow access to data fields.
  friend class PDLByteCode;

  /// The mutable block of memory used during the matching and rewriting phases
  /// of the bytecode.
  std::vector<const void *> memory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of operations. These are always stored by
  /// owning references, because at no point in the execution of the bytecode
  /// do we get an indexed range (view) of operations.
  std::vector<OwningOpRange> opRangeMemory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of types.
  std::vector<TypeRange> typeRangeMemory;
  /// A set of type ranges that have been allocated by the bytecode interpreter
  /// to provide a guaranteed lifetime.
  std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of values.
  std::vector<ValueRange> valueRangeMemory;
  /// A set of value ranges that have been allocated by the bytecode
  /// interpreter to provide a guaranteed lifetime.
  std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;

  /// The current index of ranges being iterated over for each level of nesting.
  /// These are always maintained at 0 for the loops that are not active, so we
  /// do not need to have a separate initialization phase for each loop.
  std::vector<unsigned> loopIndex;

  /// The up-to-date benefits of the patterns held by the bytecode. The order
  /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
  std::vector<PatternBenefit> currentPatternBenefits;
};
//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//
/// The bytecode class is also the interpreter. Contains the bytecode itself,
/// the static info, addresses of the rewriter functions, the interpreter
/// memory buffer, and the execution context.
class PDLByteCode {
public:
  /// Each successful match returns a MatchResult, which contains information
  /// necessary to execute the rewriter and indicates the originating pattern.
  struct MatchResult {
    MatchResult(Location loc, const PDLByteCodePattern &pattern,
                PatternBenefit benefit)
        : location(loc), pattern(&pattern), benefit(benefit) {}
    MatchResult(const MatchResult &) = delete;
    MatchResult &operator=(const MatchResult &) = delete;
    MatchResult(MatchResult &&other) = default;
    MatchResult &operator=(MatchResult &&) = default;

    /// The location of operations to be replaced.
    Location location;
    /// Memory values defined in the matcher that are passed to the rewriter.
    SmallVector<const void *> values;
    /// Memory used for the range input values.
    SmallVector<TypeRange, 0> typeRangeValues;
    SmallVector<ValueRange, 0> valueRangeValues;

    /// The originating pattern that was matched. This is always non-null, but
    /// represented with a pointer to allow for assignment.
    const PDLByteCodePattern *pattern;
    /// The current benefit of the pattern that was matched.
    PatternBenefit benefit;
  };

  /// Create a ByteCode instance from the given module containing operations in
  /// the PDL interpreter dialect.
  PDLByteCode(ModuleOp module,
              SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
              const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
              llvm::StringMap<PDLConstraintFunction> constraintFns,
              llvm::StringMap<PDLRewriteFunction> rewriteFns);

  /// Return the patterns held by the bytecode.
  ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }

  /// Initialize the given state such that it can be used to execute the current
  /// bytecode.
  void initializeMutableState(PDLByteCodeMutableState &state) const;

  /// Run the pattern matcher on the given root operation, collecting the
  /// matched patterns in `matches`.
  void match(Operation *op, PatternRewriter &rewriter,
             SmallVectorImpl<MatchResult> &matches,
             PDLByteCodeMutableState &state) const;

  /// Run the rewriter of the given pattern that was previously matched in
  /// `match`. Returns failure if an error was encountered during the rewrite.
  LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
                        PDLByteCodeMutableState &state) const;

private:
  /// Execute the given bytecode starting at the provided instruction `inst`.
  /// `matches` is an optional field provided when this function is executed in
  /// a matching context.
  void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
                       PDLByteCodeMutableState &state,
                       SmallVectorImpl<MatchResult> *matches) const;

  /// The set of pattern configs referenced within the bytecode.
  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;

  /// A vector containing pointers to uniqued data. The storage is intentionally
  /// opaque such that we can store a wide range of data types. The types of
  /// data stored here include:
  ///  * Attribute, OperationName, Type
  std::vector<const void *> uniquedData;

  /// A vector containing the generated bytecode for the matcher.
  SmallVector<ByteCodeField, 64> matcherByteCode;

  /// A vector containing the generated bytecode for all of the rewriters.
  SmallVector<ByteCodeField, 64> rewriterByteCode;

  /// The set of patterns contained within the bytecode.
  SmallVector<PDLByteCodePattern, 32> patterns;

  /// A set of user defined functions invoked via PDL.
  std::vector<PDLConstraintFunction> constraintFunctions;
  std::vector<PDLRewriteFunction> rewriteFunctions;

  /// The maximum memory index used by a value.
  ByteCodeField maxValueMemoryIndex = 0;

  /// The maximum number of different types of ranges.
  ByteCodeField maxOpRangeCount = 0;
  ByteCodeField maxTypeRangeCount = 0;
  ByteCodeField maxValueRangeCount = 0;

  /// The maximum number of nested loops.
  ByteCodeField maxLoopLevel = 0;
};
} // namespace detail
} // namespace mlir
#endif // MLIR_REWRITE_BYTECODE_H_