1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
|
//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace tosa {
// Defining GEN_PASS_DEF_TOSAINFERSHAPES before including the generated pass
// header pulls in the definition of the pass base class
// (tosa::impl::TosaInferShapesBase) that TosaInferShapes derives from.
#define GEN_PASS_DEF_TOSAINFERSHAPES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir
using namespace mlir;
using namespace mlir::tosa;
namespace {
// Returns true when `user` can consume an operand whose type has been refined
// in place: the op has a dialect and is either a TOSA op or implements
// InferTypeOpInterface / InferShapedTypeOpInterface. Ops without a dialect are
// never considered refinable.
// Uses that fail this check are rewired through a cast back to the original
// type when the inferred types are committed.
bool canBeRefined(Operation *user) {
  Dialect *owningDialect = user->getDialect();
  if (owningDialect == nullptr)
    return false;

  const bool isTosaOp =
      owningDialect->getTypeID() == TypeID::get<TosaDialect>();
  if (isTosaOp)
    return true;

  return isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}
// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// rolled back.
//
// This class encapsulates the state information needed to perform the roll back
// process or to commit to the final changes.
class TypeModificationState {
public:
TypeModificationState() = default;
~TypeModificationState() {
// Ensure the recorded modifications are either committed or rolled back.
assert(oldTypes.empty() && "unhandled type modifications");
}
// Update the state of the value and record the old type.
void setType(Value value, Type type) {
if (value.getType() != type) {
oldTypes.emplace_back(value, value.getType());
value.setType(type);
}
}
// Roll back changes made to the types in the IR by setting all the affected
// values to their old types.
void rollBack() {
for (auto [value, type] : oldTypes)
value.setType(type);
oldTypes.clear();
}
// Commit the changes to the types in the IR.
// This requires inserting tensor.cast operations to mediate the newly
// inferred result types with users that do not support type inference.
void commit() {
// For each use whose type changed, cast the value with the new type back to
// the old type.
for (auto [value, oldType] : oldTypes) {
tensor::CastOp castedValue;
for (auto &use : value.getUses()) {
if (canBeRefined(use.getOwner()))
continue;
// Cache the cast to avoid generating duplicates
if (!castedValue) {
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
castedValue = builder.create<tensor::CastOp>(oldType, value);
}
use.set(castedValue);
}
}
oldTypes.clear();
}
private:
// A record of each value whose type was updated along with that value's
// previous type.
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};
void propagateShapesInRegion(Region ®ion, TypeModificationState &state);
// Propagates the types of an IfOp's forwarded operands into the entry-block
// arguments of each of its regions (recording changes in `state`), then
// propagates shapes through those regions. Does nothing if `op` is not an IfOp.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

  for (Region &region : op.getRegions()) {
    // A region without blocks has nothing to refine, and front() would be
    // invalid on it.
    if (region.empty())
      continue;
    Block &frontBlock = region.front();

    // Operand 0 is not forwarded (presumably the branch condition); operands
    // 1..N map onto block arguments 0..N-1. Bail out on any other layout.
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    // Give each block argument the shape of its ranked operand while keeping
    // the block argument's own element type.
    for (unsigned i = 1, e = op.getNumOperands(); i < e; ++i) {
      auto inferredTy = dyn_cast<ShapedType>(op.getOperand(i).getType());
      BlockArgument blockArg = frontBlock.getArgument(i - 1);
      auto oldType = dyn_cast<ShapedType>(blockArg.getType());
      // Only shaped, ranked operands carry shape information to forward.
      if (!inferredTy || !oldType || !inferredTy.hasRank())
        continue;
      state.setType(blockArg, oldType.clone(inferredTy.getShape()));
    }

    // Refine each block argument with the join of the operand's knowledge and
    // the block argument's current knowledge.
    for (unsigned i = 0, e = frontBlock.getNumArguments(); i < e; ++i) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
    }

    propagateShapesInRegion(region, state);
  }
}
// Propagates shapes into the regions of a WhileOp.
//
// The body region (region 1) is re-analysed speculatively until its
// block-argument types reach a fixed point. Each round the new argument types
// are the meet of the current argument types and the tosa.yield operand types,
// and the round's speculative changes are rolled back afterwards. The
// fixed-point types are then applied, through `state`, to the entry-block
// arguments of every region, and shapes are propagated through those regions.
// Does nothing if `op` is not a WhileOp.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  // Region 1 is the loop body. Bail out, rather than index out of range, if its
  // entry block does not take exactly one argument per operand.
  Region &bodyRegion = op.getRegion(1);
  if (bodyRegion.empty() ||
      bodyRegion.front().getNumArguments() != argTypes.size())
    return;
  Block &bodyBlock = bodyRegion.front();

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Speculatively set types on the body's block args.
    for (unsigned i = 0, e = argTypes.size(); i < e; ++i)
      localState.setType(bodyBlock.getArgument(i), argTypes[i]);

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion, localState);

    // Find all the tosa yield types and verify there is a single one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (Block &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);
    assert(yieldOps.size() == 1 && "missing or non-unique yield op");

    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (Type ty : argTypes)
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));

    for (auto yieldOp : yieldOps) {
      // This should never happen. Check before indexing into yieldTypeInfo so
      // a mismatched operand count cannot read out of bounds, and undo the
      // speculative changes so localState is empty when it is destroyed.
      if (yieldOp->getNumOperands() != argTypes.size()) {
        localState.rollBack();
        op.emitWarning("has a tosa.yield with the incorrect number of operands");
        return;
      }
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (unsigned i = 0, e = yieldTypeInfo.size(); i < e; ++i) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the algorithm.
    localState.rollBack();
  }

  // We now set the block arguments of every region according to the
  // fixed-point types and propagate shapes through them, recording the changes
  // in the caller's state.
  for (Region &region : op.getRegions()) {
    for (unsigned i = 0, e = argTypes.size(); i < e; ++i)
      state.setType(region.front().getArgument(i), argTypes[i]);
    propagateShapesInRegion(region, state);
  }
}
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
for (auto &block : region) {
for (Operation &op : block) {
if (op.getDialect() != tosaDialect)
continue;
propagateShapesToTosaIf(op, state);
propagateShapesToTosaWhile(op, state);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
if (!shapeInterface)
continue;
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(
op.getContext(), op.getLoc(), op.getOperands(),
op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
ShapedTypeComponents predictedShape = std::get<1>(it);
// Determine the knowledge based on the output type.
// TODO: should also query WIP type probably
Type resultTy = result.getType();
auto currentKnowledge =
ValueKnowledge::getKnowledgeFromType(resultTy);
// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}
// Compute the new type based on the joined version.
auto newKnowledge =
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
if (!newKnowledge)
continue;
// Set new type
state.setType(result, newKnowledge.getType());
}
}
}
}
}
/// Function-level pass that propagates inferred shapes across TOSA operations,
/// descending into the regions of if/while operations, and then commits the
/// refined types (inserting casts for users that cannot accept them).
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
public:
  void runOnOperation() override {
    TypeModificationState typeChanges;
    propagateShapesInRegion(getOperation().getBody(), typeChanges);
    // Make the inferred types permanent; this also empties the log before
    // typeChanges is destroyed.
    typeChanges.commit();
  }
};
} // namespace
/// Creates a new instance of the TOSA shape-inference pass.
std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
  std::unique_ptr<Pass> pass = std::make_unique<TosaInferShapes>();
  return pass;
}
|