File: TestTensorCopyInsertion.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (78 lines) | stat: -rw-r--r-- 3,352 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
//===- TestTensorCopyInsertion.cpp - Bufferization Analysis -----*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
/// This pass runs One-Shot Analysis and inserts copies for all OpOperands that
/// were decided to bufferize out-of-place. After running this pass, a
/// bufferization can write to buffers directly (without making copies) and no
/// longer has to care about potential read-after-write conflicts.
///
/// Note: By default, all newly inserted tensor copies/allocs (i.e., newly
/// created `bufferization.alloc_tensor` ops) that do not escape the block are
/// annotated with `escape = false`. If `create-deallocs` is unset, all newly
/// inserted tensor copies/allocs are annotated with `escape = true`. In that
/// case, they are not getting deallocated when bufferizing the IR.
struct TestTensorCopyInsertionPass
    : public PassWrapper<TestTensorCopyInsertionPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorCopyInsertionPass)

  TestTensorCopyInsertionPass() = default;
  /// Copy constructor used when the pass is cloned. Only the `PassWrapper` base
  /// is copy-constructed here; the `Option` members below are default
  /// initialized. NOTE(review): option values are presumably transferred by
  /// the pass infrastructure after cloning — confirm against `Pass::clone`.
  TestTensorCopyInsertionPass(const TestTensorCopyInsertionPass &pass)
      : PassWrapper(pass) {}

  /// Inserted tensor copies are ops of the bufferization dialect (e.g.
  /// `bufferization.alloc_tensor`), so that dialect must be available.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<bufferization::BufferizationDialect>();
  }
  StringRef getArgument() const final { return "test-tensor-copy-insertion"; }
  StringRef getDescription() const final {
    return "Module pass to test Tensor Copy Insertion";
  }

  /// Runs One-Shot Analysis on the module and inserts tensor copies for all
  /// OpOperands that must bufferize out-of-place. Signals pass failure if
  /// `insertTensorCopies` fails.
  void runOnOperation() override {
    // Forward the command-line options to the One-Shot Bufferization options.
    bufferization::OneShotBufferizationOptions options;
    options.allowReturnAllocs = allowReturnAllocs;
    options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
    options.createDeallocs = createDeallocs;
    // Clearing the default memory space means no fallback is available, so
    // memory spaces must be inferred (see `must-infer-memory-space`).
    if (mustInferMemorySpace)
      options.defaultMemorySpace = std::nullopt;
    if (failed(bufferization::insertTensorCopies(getOperation(), options)))
      signalPassFailure();
  }

  Option<bool> allowReturnAllocs{
      *this, "allow-return-allocs",
      llvm::cl::desc("Allows returning/yielding new allocations from a block."),
      llvm::cl::init(false)};
  Option<bool> bufferizeFunctionBoundaries{
      *this, "bufferize-function-boundaries",
      llvm::cl::desc("Bufferize function boundaries."), llvm::cl::init(false)};
  Option<bool> createDeallocs{
      *this, "create-deallocs",
      llvm::cl::desc("Specify if new allocations should be deallocated."),
      llvm::cl::init(true)};
  Option<bool> mustInferMemorySpace{
      *this, "must-infer-memory-space",
      llvm::cl::desc(
          "The memory space of all memref types must always be inferred. If "
          "unset, a default memory space of 0 is used."),
      llvm::cl::init(false)};
};
} // namespace

namespace mlir::test {
void registerTestTensorCopyInsertionPass() {
  PassRegistration<TestTensorCopyInsertionPass>();
}
} // namespace mlir::test