File: TestBackwardDataFlowAnalysis.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (140 lines) | stat: -rw-r--r-- 5,170 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow analysis -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::dataflow;

namespace {

/// This lattice represents, for a given value, the set of memory resources that
/// this value, or anything derived from this value, is potentially written to.
struct WrittenTo : public AbstractSparseLattice {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
  using AbstractSparseLattice::AbstractSparseLattice;

  /// Prints the recorded resources as a space-separated list in square
  /// brackets, e.g. `[a b]`.
  void print(raw_ostream &os) const override {
    os << "[";
    // getValue() yields a StringRef, avoiding the std::string that str()
    // would allocate for every element; the printed text is identical.
    llvm::interleave(
        writes, os, [&](const StringAttr &a) { os << a.getValue(); }, " ");
    os << "]";
  }

  /// Adds every resource in `newWrites` to this lattice.
  /// Returns ChangeResult::Change iff at least one resource was not already
  /// recorded, ChangeResult::NoChange otherwise.
  ChangeResult addWrites(const SetVector<StringAttr> &newWrites) {
    // size_t, not int: SetVector::size() returns size_t.
    size_t sizeBefore = writes.size();
    writes.insert(newWrites.begin(), newWrites.end());
    return writes.size() == sizeBefore ? ChangeResult::NoChange
                                       : ChangeResult::Change;
  }

  /// Meets this lattice with `other` by unioning `other`'s writes into it.
  /// `other` must be a WrittenTo lattice (assumed: the solver only meets
  /// lattices of the analysis' own lattice type).
  ChangeResult meet(const AbstractSparseLattice &other) override {
    // A base-to-derived downcast is what static_cast is for; unlike
    // reinterpret_cast it applies any needed pointer adjustment.
    const auto &rhs = static_cast<const WrittenTo &>(other);
    return addWrites(rhs.writes);
  }

  /// The resources this value (or anything derived from it) is written to,
  /// kept in insertion order.
  SetVector<StringAttr> writes;
};

/// Sparse backward dataflow analysis: walking the dataflow graph from uses back
/// to definitions, it attaches to each value every memory resource that the
/// value, or anything computed from it, is eventually written to.
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
public:
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  /// Resets `lattice` to the exit state, which for this analysis is the empty
  /// set of writes.
  void setToExitState(WrittenTo *lattice) override {
    lattice->writes.clear();
  }

  /// Transfer function for `op`: moves write information from the lattices
  /// of its `results` onto the lattices of its `operands` (defined out of
  /// line).
  void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
                      ArrayRef<const WrittenTo *> results) override;

  /// Hook for branch operands; defined out of line.
  void visitBranchOperand(OpOperand &operand) override;
};

/// Transfer function for a single operation, propagating write information
/// backwards from results to operands.
///
/// - `memref.store`: operand #0 is marked as written to the resource named by
///   the op's `tag_name` string attribute (stores without one are ignored).
/// - Any other op: every result is assumed to depend on every operand, so each
///   operand's lattice is met with each result's lattice.
void WrittenToAnalysis::visitOperation(Operation *op,
                                       ArrayRef<WrittenTo *> operands,
                                       ArrayRef<const WrittenTo *> results) {
  if (isa<memref::StoreOp>(op)) {
    // Without a `tag_name` there is no resource to record. Skipping avoids
    // inserting a null StringAttr, which WrittenTo::print cannot safely print.
    auto tag = op->getAttrOfType<StringAttr>("tag_name");
    if (!tag)
      return;
    SetVector<StringAttr> newWrites;
    newWrites.insert(tag);
    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
    return;
  }

  // By default, every result of an op depends on every operand.
  for (const WrittenTo *result : results) {
    for (WrittenTo *operand : operands)
      meet(operand, *result);
    // Presumably registers `op` to be re-visited whenever this result's
    // lattice changes (MLIR DataFlowFramework addDependency contract).
    addDependency(const_cast<WrittenTo *>(result), op);
  }
}

/// Records a branch operand as written to a synthetic resource named
/// "brancharg<N>", where N is the operand's index in its owning operation.
void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
  WrittenTo *state = getLatticeElement(operand.get());

  auto *ctx = operand.getOwner()->getContext();
  unsigned operandIdx = operand.getOperandNumber();
  auto branchTag = StringAttr::get(ctx, "brancharg" + Twine(operandIdx));

  SetVector<StringAttr> branchWrites;
  branchWrites.insert(branchTag);

  ChangeResult changed = state->addWrites(branchWrites);
  propagateIfChanged(state, changed);
}

} // end anonymous namespace

namespace {
struct TestWrittenToPass
    : public PassWrapper<TestWrittenToPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)

  StringRef getArgument() const override { return "test-written-to"; }

  /// Runs WrittenToAnalysis, loaded alongside DeadCodeAnalysis and
  /// SparseConstantPropagation, over the root operation. For every op carrying
  /// a `tag` string attribute, it then prints to stdout the WrittenTo lattice
  /// of each of that op's operands and results.
  /// Signals pass failure if the solver fails.
  void runOnOperation() override {
    // Named `root` so it is not shadowed by the per-op walk callback below.
    Operation *root = getOperation();

    SymbolTableCollection symbolTable;

    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<WrittenToAnalysis>(symbolTable);
    if (failed(solver.initializeAndRun(root)))
      return signalPassFailure();

    raw_ostream &os = llvm::outs();

    // Prints one line per value in `values`: ` <label> #<index>: [<writes>]`.
    // Shared by the operand and result listings, which are formatted alike.
    auto printLattices = [&](StringRef label, auto values) {
      for (auto [index, value] : llvm::enumerate(values)) {
        const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(value);
        assert(writtenTo && "expected a sparse lattice");
        os << " " << label << " #" << index << ": ";
        writtenTo->print(os);
        os << "\n";
      }
    };

    root->walk([&](Operation *op) {
      auto tag = op->getAttrOfType<StringAttr>("tag");
      if (!tag)
        return;
      os << "test_tag: " << tag.getValue() << ":\n";
      printLattices("operand", op->getOperands());
      printLattices("result", op->getResults());
    });
  }
};
} // end anonymous namespace

namespace mlir::test {
/// Registers TestWrittenToPass (command-line argument `test-written-to`) by
/// constructing a PassRegistration for it.
void registerTestWrittenToPass() {
  PassRegistration<TestWrittenToPass>();
}
} // namespace mlir::test