File: RuntimeOpVerification.cpp

package info (click to toggle)
llvm-toolchain-19 1%3A19.1.7-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, trixie
  • size: 1,998,520 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (135 lines) | stat: -rw-r--r-- 5,974 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

namespace mlir {
namespace linalg {
namespace {
/// External model implementing `RuntimeVerifiableOpInterface` for the linalg
/// structured op `T`.
///
/// Verify that the runtime sizes of the operands to linalg structured ops are
/// compatible with the runtime sizes inferred by composing the loop ranges with
/// the linalg op's indexing maps. This is similar to the verifier except that
/// here we insert IR to perform the verification at runtime.
///
/// When any loop of the op has a zero trip count the op does not iterate at
/// all, so "last index = end - 1" is meaningless (it is -1). In that case the
/// generated assertions are forced to pass instead of reporting a spurious
/// negative-index / size-mismatch failure.
template <typename T>
struct StructuredOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          StructuredOpInterface<T>, T> {
  /// Emit the runtime checks for `op` through `builder`, tagging every created
  /// operation with `loc`.
  ///
  /// For every operand and every operand dimension two `cf.assert` ops are
  /// generated:
  ///   1. the smallest index accessed along the dimension is non-negative;
  ///   2. the inferred dimension size (largest accessed index + 1) is
  ///      compatible with the actual size of the operand.
  ///
  /// \param op      the operation to verify; must be a `LinalgOp`.
  /// \param builder builder used to create the verification IR.
  /// \param loc     location attached to the created operations.
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto linalgOp = llvm::cast<LinalgOp>(op);

    SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
    auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);

    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);

    // Compute an i1 flag that is true iff at least one loop has a zero trip
    // count (start >= end). This has to happen before `ends` is decremented
    // below. The flag stays null when the op has no loops at all.
    Value isDegenerate;
    for (auto [start, end] : llvm::zip(starts, ends)) {
      Value startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
      Value endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
      Value emptyLoop = builder.createOrFold<index::CmpOp>(
          loc, index::IndexCmpPredicate::SGE, startValue, endValue);
      isDegenerate =
          isDegenerate
              ? builder.createOrFold<arith::OrIOp>(loc, isDegenerate, emptyLoop)
              : emptyLoop;
    }

    // Relax an assertion condition so that it trivially holds whenever the
    // iteration domain is degenerate.
    auto passIfDegenerate = [&](Value condition) -> Value {
      if (!isDegenerate)
        return condition;
      return builder.createOrFold<arith::OrIOp>(loc, condition, isDegenerate);
    };

    // Subtract one from the loop ends before composing with the indexing map
    transform(ends, ends.begin(), [&](OpFoldResult end) {
      auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
      return builder.createOrFold<index::SubOp>(loc, endValue, one);
    });

    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      // Indices accessed by the first and by the last loop iteration.
      auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
          builder, loc, indexingMap, starts);
      auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
          builder, loc, indexingMap, ends);

      for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
        auto startIndex =
            getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
        auto endIndex =
            getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);

        // Generate:
        //   minIndex = min(startIndex, endIndex)
        //   assert(minIndex >= 0 || isDegenerate)
        // To ensure we do not generate a negative index. We take the minimum of
        // the start and end indices in order to handle reverse loops such as
        // `affine_map<(i) -> (3 - i)>`
        auto min =
            builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
        Value nonNegative = builder.createOrFold<index::CmpOp>(
            loc, index::IndexCmpPredicate::SGE, min, zero);
        auto negativeMsg = RuntimeVerifiableOpInterface::generateErrorMessage(
            linalgOp, "unexpected negative result on dimension #" +
                          std::to_string(dim) + " of input/output operand #" +
                          std::to_string(opOperand.getOperandNumber()));
        builder.createOrFold<cf::AssertOp>(loc, passIfDegenerate(nonNegative),
                                           negativeMsg);

        // Generate:
        //   inferredDimSize = max(startIndex, endIndex) + 1
        //   actualDimSize = dim(operand)
        //   assert(inferredDimSize <= actualDimSize || isDegenerate)
        // To ensure that we do not index past the bounds of the operands.
        auto max =
            builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);

        auto inferredDimSize =
            builder.createOrFold<index::AddOp>(loc, max, one);

        auto actualDimSize =
            createOrFoldDimOp(builder, loc, opOperand.get(), dim);

        // Similar to the verifier, when the affine expression in the indexing
        // map is complicated, we just check that the inferred dimension sizes
        // are in the boundary of the operands' size. Being more precise than
        // that is difficult.
        auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
                             ? index::IndexCmpPredicate::EQ
                             : index::IndexCmpPredicate::SLE;

        Value sizeMatches = builder.createOrFold<index::CmpOp>(
            loc, predicate, inferredDimSize, actualDimSize);
        auto sizeMsg = RuntimeVerifiableOpInterface::generateErrorMessage(
            linalgOp, "dimension #" + std::to_string(dim) +
                          " of input/output operand #" +
                          std::to_string(opOperand.getOperandNumber()) +
                          " is incompatible with inferred dimension size");
        builder.createOrFold<cf::AssertOp>(loc, passIfDegenerate(sizeMatches),
                                           sizeMsg);
      }
    }
  }
};

/// Attach the `StructuredOpInterface` external model to each op type in
/// `OpTs`, using the context pointed to by `ctx`.
template <typename... OpTs>
void attachInterface(MLIRContext *ctx) {
  MLIRContext &context = *ctx;
  // Fold over the comma operator: one attachInterface call per op type.
  (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(context), ...);
}
} // namespace
} // namespace linalg
} // namespace mlir

/// Add an extension for `LinalgDialect` to `registry` that attaches the
/// runtime-verification external model to every op named in the generated
/// structured-op list and loads the dialects whose ops the verification code
/// may create.
void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // The unary `+` turns the captureless lambda into a plain function pointer.
  auto attachModels = +[](MLIRContext *context, LinalgDialect *) {
    // GET_OP_LIST makes the included file expand to the comma-separated list
    // of op classes, which becomes the template argument list.
    attachInterface<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(context);

    // Load additional dialects of which ops may get created.
    context->loadDialect<affine::AffineDialect, arith::ArithDialect,
                         cf::ControlFlowDialect, index::IndexDialect,
                         tensor::TensorDialect>();
  };
  registry.addExtension(attachModels);
}