File: Simplifications.cpp

package info (click to toggle)
llvm-toolchain-19 1:19.1.7-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,998,520 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (122 lines) | stat: -rw-r--r-- 4,571 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <numeric>
#include <utility>

namespace mlir {
namespace mesh {

void populateSimplificationPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  // Adds the all-reduce endomorphism simplification for the op type carried by
  // `opTag`. The tag is a default-constructed (null) op handle; only its static
  // type is used, to pick the template argument.
  auto addEndomorphism = [&patterns](auto opTag, ReductionKind kind) {
    using ElementwiseOp = decltype(opTag);
    populateAllReduceEndomorphismSimplificationPatterns<ElementwiseOp>(
        patterns, kind);
  };

  // Sum reductions: floating-point and integer addition.
  addEndomorphism(arith::AddFOp(), ReductionKind::Sum);
  addEndomorphism(arith::AddIOp(), ReductionKind::Sum);

  // Min reductions: floating-point, signed and unsigned integer minimum.
  addEndomorphism(arith::MinimumFOp(), ReductionKind::Min);
  addEndomorphism(arith::MinSIOp(), ReductionKind::Min);
  addEndomorphism(arith::MinUIOp(), ReductionKind::Min);

  // Max reductions: floating-point, signed and unsigned integer maximum.
  addEndomorphism(arith::MaximumFOp(), ReductionKind::Max);
  addEndomorphism(arith::MaxSIOp(), ReductionKind::Max);
  addEndomorphism(arith::MaxUIOp(), ReductionKind::Max);

  // TODO: add simplifications for all-gather and other collectives.

  // Symbol-table-aware folds (e.g. static mesh shape queries).
  populateFoldingPatterns(patterns, symbolTableCollection);
}

namespace {

// Rewrites a MeshShapeOp so that every queried mesh axis with a static size is
// replaced by an index `arith.constant`. Only the dynamic axes, if any, are
// still queried, through a new, narrower MeshShapeOp.
//
// This cannot live in the op's fold method or in a DialectFoldInterface:
// resolving the referenced mesh needs a SymbolTableCollection that caches
// symbol tables, and such a cache may be invalidated by passes that modify the
// referenced MeshOp ops.
struct MeshShapeFolder
    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
  using OpRewritePatternWithSymbolTableCollection::
      OpRewritePatternWithSymbolTableCollection;

  LogicalResult matchAndRewrite(MeshShapeOp op,
                                PatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

    // Resolve the mesh symbol referenced by the op; nothing to do without it.
    MeshOp meshOp =
        symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
            op.getOperation(), op.getMeshAttr());
    if (!meshOp)
      return failure();

    auto meshShape = meshOp.getShape();

    // An op with no explicit axes queries every mesh axis, in ascending order.
    SmallVector<MeshAxis> allAxes;
    ArrayRef<MeshAxis> queriedAxes = op.getAxes();
    if (queriedAxes.empty()) {
      allAxes.resize(meshOp.getRank());
      std::iota(allAxes.begin(), allAxes.end(), 0);
      queriedAxes = allAxes;
    }

    auto isDynamicAxis = [&meshShape](MeshAxis axis) {
      return ShapedType::isDynamic(meshShape[axis]);
    };
    // No queried axis has a static size, so there is nothing to fold.
    if (llvm::all_of(queriedAxes, isDynamicAxis))
      return failure();

    // replacements[k] is the value that replaces result k of `op`.
    SmallVector<Value> replacements(op->getNumResults());
    // Axes still queried at runtime, and the original result slot of each.
    SmallVector<MeshAxis> dynamicAxes;
    SmallVector<size_t> dynamicResultPositions;

    for (auto [position, axis] : llvm::enumerate(queriedAxes)) {
      auto axisSize = meshShape[axis];
      if (!ShapedType::isDynamic(axisSize)) {
        // Static axis: materialize its size directly as an index constant.
        replacements[position] = builder.create<arith::ConstantOp>(
            builder.getIndexAttr(axisSize));
        continue;
      }
      // Dynamic axis: defer it to the narrower replacement query below.
      dynamicResultPositions.push_back(position);
      dynamicAxes.push_back(axis);
    }

    if (!dynamicAxes.empty()) {
      MeshShapeOp dynamicQuery =
          builder.create<MeshShapeOp>(meshOp.getSymName(), dynamicAxes);
      // Route each result of the narrower query back to its original slot.
      for (unsigned idx = 0, e = dynamicQuery->getNumResults(); idx < e; ++idx)
        replacements[dynamicResultPositions[idx]] =
            dynamicQuery->getResult(idx);
    }

    rewriter.replaceOp(op, replacements);
    return success();
  }
};

} // namespace

void populateFoldingPatterns(RewritePatternSet &patterns,
                             SymbolTableCollection &symbolTableCollection) {
  // The folder shares the caller's symbol table cache; the cache is owned by
  // the caller, not by the pattern set.
  MLIRContext *context = patterns.getContext();
  patterns.add<MeshShapeFolder>(symbolTableCollection, context);
}

} // namespace mesh
} // namespace mlir