File: TosaTestPasses.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (208 lines) | stat: -rw-r--r-- 7,599 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
//===- TosaTestPasses.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test passes to exercise TOSA helper functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define PASS_NAME "tosa-test-quant-utils"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Test pattern that converts a quantized uint8 `tosa.negate` into a quantized
/// int8 one. Constructing the new element type invokes buildQTypeFromMinMax,
/// which is the helper this pattern exists to exercise. Extracted from the
/// TOSA legalization infrastructure.
///
/// The type is only used inside this file, so it lives in an anonymous
/// namespace (internal linkage), like the pass defined further down.
struct ConvertTosaNegateOp : public RewritePattern {
  /// Anchors the pattern on `tosa.negate` with a pattern benefit of 1.
  explicit ConvertTosaNegateOp(MLIRContext *context)
      : RewritePattern(tosa::NegateOp::getOperationName(), /*benefit=*/1,
                       context) {}

  /// Performs the uint8 -> int8 rewrite; defined out of line below.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;
};

} // namespace

/// Rewrites a `tosa.negate` that has a constant operand and a quantized uint8
/// result into a `tosa.const` + `tosa.negate` pair typed as quantized int8.
///
/// The match fails, leaving the IR untouched, unless all of these hold:
///   * the operand and the result are ranked tensors;
///   * the result element type is a UniformQuantizedType that is unsigned and
///     8 bits wide;
///   * the operand is a constant;
///   * buildQTypeFromMinMax manages to build the signed replacement type.
///
/// \param op       the matched operation; always a tosa::NegateOp because the
///                 pattern is anchored on that op name.
/// \param rewriter rewriter used to create the new ops and replace \p op.
/// \returns success() if \p op was replaced, failure() otherwise.
LogicalResult
ConvertTosaNegateOp::matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const {
  auto tosaNegateOp = cast<tosa::NegateOp>(op);

  // Skip if the input is not a ranked tensor.
  auto inputType =
      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
  if (!inputType)
    return rewriter.notifyMatchFailure(op, "input is not a ranked tensor");

  // Skip if the output is not a ranked tensor.
  auto outputType =
      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
  if (!outputType)
    return rewriter.notifyMatchFailure(op, "output is not a ranked tensor");

  // Skip if the output is not a per-tensor quantized type.
  auto outputElementType =
      dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
  if (!outputElementType)
    return rewriter.notifyMatchFailure(
        op, "output element type is not uniform quantized");

  // Skip if the output is not uint8.
  if (outputElementType.isSigned() ||
      outputElementType.getStorageTypeIntegralWidth() != 8)
    return rewriter.notifyMatchFailure(op, "output is not quantized uint8");

  // Only constant operands are rewritten. Check this before doing any of the
  // type construction below so a non-matching op bails out early.
  ElementsAttr inputElems;
  if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems)))
    return rewriter.notifyMatchFailure(op, "input is not a constant");

  // Real-valued range covered by the current storage type:
  //   real = (storage - zeroPoint) * scale
  const int64_t zeroPoint = outputElementType.getZeroPoint();
  const double scale = outputElementType.getScale();
  const double typeRangeMin =
      static_cast<double>(outputElementType.getStorageTypeMin() - zeroPoint) *
      scale;
  const double typeRangeMax =
      static_cast<double>(outputElementType.getStorageTypeMax() - zeroPoint) *
      scale;
  // A storage minimum of 1 (instead of 0) is treated as "narrow range".
  const bool narrowRange = outputElementType.getStorageTypeMin() == 1;

  // Build the signed quantized element type covering the same real range.
  auto dstElementType = buildQTypeFromMinMax(
      rewriter, outputElementType.getExpressedType(),
      rewriter.getF64FloatAttr(typeRangeMin),
      rewriter.getF64FloatAttr(typeRangeMax),
      rewriter.getI32IntegerAttr(
          outputElementType.getStorageTypeIntegralWidth()),
      /*filterQuantDim=*/0, /*isSigned=*/true,
      rewriter.getBoolAttr(narrowRange));
  // The helper hands back a null type when it cannot build the quantized
  // type; RankedTensorType::get must never be called with a null element type.
  if (!dstElementType)
    return rewriter.notifyMatchFailure(
        op, "failed to build the signed quantized element type");

  auto dstQConstType =
      RankedTensorType::get(outputType.getShape(), dstElementType);

  // Re-materialize the constant and the negate with the new signed type.
  auto newConstOp =
      rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
  auto newNegateOp = rewriter.create<tosa::NegateOp>(
      op->getLoc(), dstQConstType, newConstOp.getResult());

  rewriter.replaceOp(op, {newNegateOp.getResult()});
  return success();
}

namespace {

/// Test pattern that changes the result type of a quantized `tosa.conv2d` and
/// appends a `tosa.rescale` after it. Building the rescale op requires
/// invoking computeMultiplierAndShift, which is the helper this pattern exists
/// to exercise. Taken from the TOSA legalization infrastructure.
///
/// The type is only used inside this file, so it lives in an anonymous
/// namespace (internal linkage), like the pass defined further down.
struct ConvertTosaConv2DOp : public RewritePattern {
  /// Anchors the pattern on `tosa.conv2d` with a pattern benefit of 1.
  explicit ConvertTosaConv2DOp(MLIRContext *context)
      : RewritePattern(tosa::Conv2DOp::getOperationName(), /*benefit=*/1,
                       context) {}

  /// Performs the conv2d + rescale rewrite; defined out of line below.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;
};

} // namespace

/// Replaces a `tosa.conv2d` whose input, weight and result are all uniform
/// quantized ranked tensors with a convolution producing i32, followed by a
/// `tosa.rescale` back to the original result type.
///
/// The rescale parameters come from computeMultiplierAndShift applied to
/// (inputScale * weightScale) / outputScale.
///
/// \returns success() if \p op was replaced; failure() if any of the three
///          types is not a ranked tensor or not uniform quantized.
LogicalResult
ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const {
  auto conv = cast<tosa::Conv2DOp>(op);

  // Input, weight and result must all be ranked tensors.
  auto inputTy = dyn_cast<mlir::RankedTensorType>(conv.getInput().getType());
  auto weightTy = dyn_cast<mlir::RankedTensorType>(conv.getWeight().getType());
  auto resultTy = dyn_cast<mlir::RankedTensorType>(conv.getResult().getType());
  if (!inputTy || !weightTy || !resultTy)
    return failure();

  // ... and all three must carry a uniform quantized element type.
  auto inputQ =
      dyn_cast<mlir::quant::UniformQuantizedType>(inputTy.getElementType());
  auto weightQ =
      dyn_cast<mlir::quant::UniformQuantizedType>(weightTy.getElementType());
  auto resultQ =
      dyn_cast<mlir::quant::UniformQuantizedType>(resultTy.getElementType());
  if (!inputQ || !weightQ || !resultQ)
    return failure();

  // Factor that takes the product of input and weight scales back to the
  // result's scale, expressed as an integer multiplier plus a shift.
  const double rescaleFactor =
      (inputQ.getScale() * weightQ.getScale()) / resultQ.getScale();
  int32_t multiplier = 0;
  int32_t shift = 0;
  computeMultiplierAndShift(rescaleFactor, multiplier, shift, 32);

  // Same convolution as before, but with a plain 32-bit integer result.
  auto accumulatorTy =
      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(32));
  auto accumulatorConv = rewriter.create<tosa::Conv2DOp>(
      op->getLoc(), accumulatorTy, conv.getInput(), conv.getWeight(),
      conv.getBias(), conv.getPadAttr(), conv.getStrideAttr(),
      conv.getDilationAttr());

  // Rescale the i32 result to the original quantized result type, using a
  // zero input zero point and the result type's own zero point.
  // NOTE(review): the three trailing booleans are presumably scale32,
  // double_round and per_channel - confirm against the tosa.rescale op
  // definition.
  const int64_t resultZeroPoint = resultQ.getZeroPoint();
  auto rescale = rewriter.create<tosa::RescaleOp>(
      op->getLoc(), resultTy, accumulatorConv.getResult(),
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(resultZeroPoint),
      rewriter.getDenseI32ArrayAttr({multiplier}),
      rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(true),
      rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));

  rewriter.replaceOp(op, {rescale.getResult()});
  return success();
}

namespace {

/// Test pass operating on func::FuncOp. Its pass argument is PASS_NAME
/// ("tosa-test-quant-utils"); the actual work is done in runOnOperation().
struct TosaTestQuantUtilAPI
    : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI)

  /// Pass entry point; defined out of line.
  void runOnOperation() override;

  /// Human-readable summary of the pass.
  StringRef getDescription() const final {
    return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
  }

  /// Pass argument string.
  StringRef getArgument() const final { return PASS_NAME; }
};

void TosaTestQuantUtilAPI::runOnOperation() {
  auto *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  auto func = getOperation();

  patterns.add<ConvertTosaNegateOp>(ctx);
  patterns.add<ConvertTosaConv2DOp>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

namespace mlir {
void registerTosaTestQuantUtilAPIPass() {
  PassRegistration<TosaTestQuantUtilAPI>();
}
} // namespace mlir