File: ReconcileUnrealizedCasts.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (134 lines) | stat: -rw-r--r-- 5,115 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Folds the DAGs of `unrealized_conversion_cast`s whose live exits produce
/// the same types that the root cast consumes.
/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
/// represent a no-op within the IR, and thus the initial input values can be
/// propagated.
/// The same does not hold for 'open' chains of casts, such as `A -> B -> C`.
/// In this last case there is no cycle among the types and thus the
/// conversion is incomplete. The same holds for 'closed' chains like
/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
/// operation.
/// Bifurcations (that is, when a chain starts in the middle of another one)
/// are also taken into consideration, and all the above considerations remain
/// valid.
/// Special corner cases such as dead casts or single casts with the same input
/// and output types are also covered. Casts without any user are simply
/// erased, whatever their types or number of results.
struct UnrealizedConversionCastPassthrough
    : public OpRewritePattern<UnrealizedConversionCastOp> {
  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                PatternRewriter &rewriter) const override {
    // Casts with at least one user that is not an unrealized cast. Their
    // results are replaced by the root input values.
    SmallVector<UnrealizedConversionCastOp> liveExitNodes;

    // Casts that are either unused or whose users are all unrealized casts,
    // in the order the traversal reached them. A cast is only pushed while
    // processing the cast producing its inputs, so it always appears after
    // that producer; walking this list backwards erases users first.
    SmallVector<UnrealizedConversionCastOp> nodesToErase;

    // Casts already processed. A multi-result cast may be used several times
    // by the same downstream cast, which must only be traversed once.
    DenseSet<UnrealizedConversionCastOp> visited;

    // Stack used for the depth-first traversal of the use-def DAG.
    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
    visitStack.push_back(op);

    // Classify every cast of the DAG before touching the IR, so that a match
    // failure leaves the IR unmodified.
    while (!visitStack.empty()) {
      UnrealizedConversionCastOp current = visitStack.pop_back_val();
      if (!visited.insert(current).second)
        continue;

      bool isLive = false;
      for (Operation *user : current->getUsers()) {
        auto other = dyn_cast<UnrealizedConversionCastOp>(user);
        if (!other) {
          isLive = true;
          continue;
        }
        if (other.getInputs() != current.getOutputs())
          return rewriter.notifyMatchFailure(
              op, "mismatching values propagation");

        // Continue traversing the DAG of unrealized casts.
        visitStack.push_back(other);
      }

      if (!isLive) {
        // Dead or intermediate cast: it disappears together with the DAG.
        nodesToErase.push_back(current);
        continue;
      }

      // A live cast can only be folded if its results have the same types as
      // the root inputs. If this is the case (e.g. `{A -> B, B -> A}`, but
      // also `{A -> A}`), then the cycle is just a no-op and the inputs can be
      // forwarded. If it's not (e.g. `{A -> B, B -> C}`, `{A -> B}`), then the
      // cast chain is incomplete.
      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
      if (!isCycle)
        return rewriter.notifyMatchFailure(op,
                                           "live unrealized conversion cast");
      liveExitNodes.push_back(current);
    }

    // Copy the root inputs before rewriting: `op` itself may be replaced or
    // erased below, after which its operand list must not be accessed.
    SmallVector<Value> rootInputs = llvm::to_vector(op.getInputs());

    // Replace the live exit nodes with the root input values. The cycle check
    // above guarantees matching result counts and types.
    for (UnrealizedConversionCastOp exitNode : liveExitNodes)
      rewriter.replaceOp(exitNode, rootInputs);

    // Erase the remaining casts of the DAG, users before their producers.
    for (UnrealizedConversionCastOp castOp : llvm::reverse(nodesToErase))
      rewriter.eraseOp(castOp);

    return success();
  }
};

/// Pass that runs the unrealized-cast folding pattern as a partial conversion
/// in which `unrealized_conversion_cast` is illegal; the pass signals failure
/// when that conversion does not succeed.
struct ReconcileUnrealizedCasts
    : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
  ReconcileUnrealizedCasts() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // No unrealized cast is allowed to remain once the conversion has run.
    ConversionTarget target(*ctx);
    target.addIllegalOp<UnrealizedConversionCastOp>();

    RewritePatternSet castFoldingPatterns(ctx);
    populateReconcileUnrealizedCastsPatterns(castFoldingPatterns);

    LogicalResult result = applyPartialConversion(
        getOperation(), target, std::move(castFoldingPatterns));
    if (succeeded(result))
      return;
    signalPassFailure();
  }
};

} // namespace

/// Adds the pattern that folds no-op DAGs of unrealized conversion casts to
/// `patterns`, using the context the set was created with.
void mlir::populateReconcileUnrealizedCastsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UnrealizedConversionCastPassthrough>(context);
}

/// Creates a new instance of the pass that reconciles unrealized conversion
/// casts; ownership is transferred to the caller.
std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ReconcileUnrealizedCasts>();
  return pass;
}