File: ExpandPatterns.cpp

package info (click to toggle)
llvm-toolchain-16 1%3A16.0.6-15~deb11u2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,634,820 kB
  • sloc: cpp: 6,179,261; ansic: 1,216,205; asm: 741,319; python: 196,614; objc: 75,325; f90: 49,640; lisp: 32,396; pascal: 12,286; sh: 9,394; perl: 7,442; ml: 5,494; awk: 3,523; makefile: 2,723; javascript: 1,206; xml: 886; fortran: 581; cs: 573
file content (112 lines) | stat: -rw-r--r-- 4,714 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//===- ExpandPatterns.cpp - Code to expand math ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of the math tanh and ctlz ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

/// Expands tanh op into
///   1) (1 - exp(-2x)) / (1 + exp(-2x)), if x >= 0
///   2) (exp(2x) - 1) / (exp(2x) + 1),   if x < 0
///
/// Both quotients are emitted unconditionally and an `arith.select` on the
/// ordered comparison `x >= 0` picks between them, so the selected quotient is
/// always the one whose exponential has a non-positive argument.
///
/// Only scalar floating-point operands are expanded. For any other operand
/// type (e.g. vectors or tensors of floats) the pattern reports a match
/// failure and leaves the op untouched, because the constants below are built
/// as `FloatAttr`s, which require a scalar float type.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
  Type floatType = op.getOperand().getType();
  if (!floatType.isa<FloatType>())
    return rewriter.notifyMatchFailure(op,
                                       "expected scalar float operand type");

  Location loc = op.getLoc();
  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
  auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
  Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
  Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);

  // Case 1: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
  Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
  Value expNeg2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
  Value posDividend = rewriter.create<arith::SubFOp>(loc, one, expNeg2x);
  Value posDivisor = rewriter.create<arith::AddFOp>(loc, one, expNeg2x);
  Value positiveRes =
      rewriter.create<arith::DivFOp>(loc, posDividend, posDivisor);

  // Case 2: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
  Value exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
  Value negDividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
  Value negDivisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
  Value negativeRes =
      rewriter.create<arith::DivFOp>(loc, negDividend, negDivisor);

  // tanh(x) = x >= 0 ? positiveRes : negativeRes
  auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
  Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
  Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                                op.getOperand(), zero);
  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
                                               negativeRes);
  return success();
}

/// Expands ctlz (count leading zeros) into an `scf.while` loop.
///
/// The loop carries three values: the remaining input, a running count that
/// starts at the operand's bit width, and a constant zero. While the input is
/// non-zero it is logically shifted right by one and the count is decremented.
/// When the input reaches zero the count equals the number of leading zero
/// bits of the original operand (the full bit width for a zero operand).
///
/// Only scalar integer operands are expanded. For any other operand type
/// (e.g. vectors or tensors of integers) the pattern reports a match failure
/// and leaves the op untouched, because the bit-width query and the
/// `IntegerAttr` constants below require a scalar integer type.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
                                   PatternRewriter &rewriter) {
  auto operand = op.getOperand();
  auto elementTy = operand.getType();
  auto resultTy = op.getType();
  Location loc = op.getLoc();

  if (!elementTy.isa<IntegerType>())
    return rewriter.notifyMatchFailure(op,
                                       "expected scalar integer operand type");

  int bitWidth = elementTy.getIntOrFloatBitWidth();
  auto zero =
      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
  auto leadingZeros = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(elementTy, bitWidth));

  // Loop-carried values: (remaining input, leading-zero count, zero constant).
  SmallVector<Value> operands = {operand, leadingZeros, zero};
  SmallVector<Type> types = {elementTy, elementTy, elementTy};

  auto whileOp = rewriter.create<scf::WhileOp>(
      loc, types, operands,
      [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
        // The conditional block of the while loop: continue while input != 0.
        Value input = args[0];
        Value zero = args[2];

        Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
            beforeLoc, arith::CmpIPredicate::ne, input, zero);
        beforeBuilder.create<scf::ConditionOp>(beforeLoc, inputNotZero, args);
      },
      [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
        // The body of the while loop: shift right until reaching a value of 0.
        Value input = args[0];
        Value leadingZeros = args[1];

        auto one = afterBuilder.create<arith::ConstantOp>(
            afterLoc, IntegerAttr::get(elementTy, 1));
        auto shifted =
            afterBuilder.create<arith::ShRUIOp>(afterLoc, resultTy, input, one);
        auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
            afterLoc, resultTy, leadingZeros, one);

        afterBuilder.create<scf::YieldOp>(
            afterLoc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
      });

  rewriter.setInsertionPointAfter(whileOp);
  // Result #1 of the loop is the final leading-zero count.
  rewriter.replaceOp(op, whileOp->getResult(1));
  return success();
}

/// Adds the function-based rewrite `convertCtlzOp` to `patterns`.
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
  patterns.add(&convertCtlzOp);
}

/// Adds the function-based rewrite `convertTanhOp` to `patterns`.
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
  patterns.add(&convertTanhOp);
}