File: OneToNFuncConversions.cpp

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (130 lines) | stat: -rw-r--r-- 4,583 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//===-- OneToNFuncConversions.cpp - Func 1:N type conversion ----*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The patterns in this file are heavily inspired by (and copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
// type conversions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

using namespace mlir;
using namespace mlir::func;

namespace {

/// 1:N type conversion pattern for `CallOp`.
///
/// Builds a replacement `CallOp` whose result types are the converted types of
/// the adaptor's result mapping and whose operands are the adaptor's flattened
/// operands, then copies the attributes of the original op onto it.
class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
public:
  using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;

  /// Returns failure, leaving `op` untouched, when neither the operand mapping
  /// nor the result mapping contains a non-identity conversion. Otherwise
  /// replaces `op` with the newly created call and returns success.
  LogicalResult
  matchAndRewrite(CallOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();

    // Bail out unless at least one operand or result type actually changes.
    if (!(adaptor.getOperandMapping().hasNonIdentityConversion() ||
          resultMapping.hasNonIdentityConversion()))
      return failure();

    // The call is created from types and operands only; its attributes
    // (presumably including the callee symbol) are taken over from the
    // original op afterwards.
    auto convertedCall =
        rewriter.create<CallOp>(op->getLoc(), resultMapping.getConvertedTypes(),
                                adaptor.getFlatOperands());
    convertedCall->setAttrs(op->getAttrs());

    // Hand the new results to the rewriter together with the result mapping.
    rewriter.replaceOp(op, convertedCall->getResults(), resultMapping);
    return success();
  }
};

/// 1:N type conversion pattern for `FuncOp`.
///
/// Computes type mappings for the function's argument and result types,
/// updates the function type in place and, for functions with a body, applies
/// the argument mapping to the signature of the first block of that body.
class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
public:
  using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;

  /// Returns failure if either type mapping cannot be computed or if both
  /// mappings consist of identity conversions only; returns success after the
  /// signature has been updated.
  ///
  /// NOTE(review): only the function type and the entry block signature are
  /// touched here; argument/result attributes are not visibly remapped when
  /// the number of arguments changes -- confirm whether callers rely on that.
  LogicalResult
  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto *converter = getTypeConverter<OneToNTypeConverter>();

    // Mapping from the original argument types to their converted types.
    OneToNTypeMapping argumentMapping(op.getArgumentTypes());
    LogicalResult argumentStatus =
        converter->computeTypeMapping(op.getArgumentTypes(), argumentMapping);
    if (failed(argumentStatus))
      return failure();

    // Mapping from the original result types to their converted types.
    OneToNTypeMapping funcResultMapping(op.getResultTypes());
    LogicalResult resultStatus =
        converter->computeTypeMapping(op.getResultTypes(), funcResultMapping);
    if (failed(resultStatus))
      return failure();

    // Leave the function alone unless an argument or result type changes.
    if (!(argumentMapping.hasNonIdentityConversion() ||
          funcResultMapping.hasNonIdentityConversion()))
      return failure();

    // Install the converted function type on the existing op.
    auto convertedSignature = FunctionType::get(
        rewriter.getContext(), argumentMapping.getConvertedTypes(),
        funcResultMapping.getConvertedTypes());
    rewriter.updateRootInPlace(
        op, [&op, &convertedSignature] { op.setType(convertedSignature); });

    // External functions have no body, hence no block arguments to convert.
    if (op.isExternal())
      return success();

    // Convert the arguments of the body's first block to match the new type.
    Block &entryBlock = op.getBody().front();
    rewriter.applySignatureConversion(&entryBlock, argumentMapping);
    return success();
  }
};

/// 1:N type conversion pattern for `ReturnOp`.
///
/// Swaps the operands of the return op for the adaptor's flattened operands,
/// modifying the op in place rather than creating a new one.
class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
public:
  using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;

  /// Returns failure when the operand mapping has no non-identity conversion;
  /// otherwise updates the operands of `op` and returns success.
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    // Only rewrite returns with at least one operand whose type changes.
    const bool operandsChange =
        adaptor.getOperandMapping().hasNonIdentityConversion();
    if (!operandsChange)
      return failure();

    // Perform the operand swap inside the rewriter's in-place update hook.
    auto installFlatOperands = [&] {
      op->setOperands(adaptor.getFlatOperands());
    };
    rewriter.updateRootInPlace(op, installFlatOperands);
    return success();
  }
};

} // namespace

namespace mlir {

/// Adds the 1:N type conversion patterns defined in this file (for `CallOp`,
/// `FuncOp` and `ReturnOp`, in that order) to `patterns`. Each pattern is
/// constructed with `typeConverter` and the context of `patterns`.
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ConvertTypesInFuncCallOp, ConvertTypesInFuncFuncOp,
               ConvertTypesInFuncReturnOp>(typeConverter, context);
}

} // namespace mlir