File: Decomposer.cpp

package info (click to toggle)
llvm-toolchain-19 1%3A19.1.7-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,998,520 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (131 lines) | stat: -rw-r--r-- 4,590 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "Decomposer.h"

#include "Clauses.h"
#include "Utils.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Semantics/semantics.h"
#include "flang/Tools/CrossToolHelpers.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <utility>
#include <variant>

using namespace Fortran;

namespace {
using namespace Fortran::lower::omp;

struct ConstructDecomposition {
  ConstructDecomposition(mlir::ModuleOp modOp,
                         semantics::SemanticsContext &semaCtx,
                         lower::pft::Evaluation &ev,
                         llvm::omp::Directive compound,
                         const List<Clause> &clauses)
      : semaCtx(semaCtx), mod(modOp), eval(ev) {
    tomp::ConstructDecompositionT decompose(getOpenMPVersionAttribute(modOp),
                                            *this, compound,
                                            llvm::ArrayRef(clauses));
    output = std::move(decompose.output);
  }

  // Given an object, return its base object if one exists.
  std::optional<Object> getBaseObject(const Object &object) {
    return lower::omp::getBaseObject(object, semaCtx);
  }

  // Return the iteration variable of the associated loop if any.
  std::optional<Object> getLoopIterVar() {
    if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
      return Object{symbol, /*designator=*/{}};
    return std::nullopt;
  }

  semantics::SemanticsContext &semaCtx;
  mlir::ModuleOp mod;
  lower::pft::Evaluation &eval;
  List<UnitConstruct> output;
};
} // namespace

/// Merge a run of consecutive unit constructs into a single unit construct
/// using tomp::ConstructCompositionT.
///
/// \param version OpenMP version, forwarded to ConstructCompositionT.
/// \param units   The constructs to merge.
/// \returns The merged construct produced by ConstructCompositionT.
static UnitConstruct mergeConstructs(uint32_t version,
                                     llvm::ArrayRef<UnitConstruct> units) {
  tomp::ConstructCompositionT compose(version, units);
  // `compose` goes out of scope right after this statement. Implicit move on
  // return does not apply to a data member, so move explicitly to avoid
  // copying the merged construct and its clause list.
  return std::move(compose.merged);
}

namespace Fortran::lower::omp {
/// Print a unit construct as its directive name followed by the names of its
/// clauses: a tab before the first clause and a single space before each
/// following one. Marked LLVM_DUMP_METHOD for use from a debugger.
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                               const UnitConstruct &uc) {
  os << llvm::omp::getOpenMPDirectiveName(uc.id);
  for (auto [index, clause] : llvm::enumerate(uc.clauses)) {
    os << (index == 0 ? '\t' : ' ');
    os << llvm::omp::getOpenMPClauseName(clause.id);
  }
  return os;
}

/// Build the queue of constructs to lower for a (possibly compound) directive.
///
/// The directive is first decomposed into its leaf constructs, with the
/// clauses distributed among them. The leafs are then regrouped according to
/// llvm::omp::getLeafOrCompositeConstructs. Each resulting unit covers
/// getLeafConstructsOrSelf(unit).size() consecutive leafs, and those leafs
/// are merged into a single queue entry.
///
/// \param modOp    Module whose OpenMP version attribute is used.
/// \param semaCtx  Semantics context used during decomposition.
/// \param eval     Evaluation of the directive (queried for the loop
///                 iteration variable during decomposition).
/// \param source   Source range of the directive. It is assigned to any clause
///                 left without a source when transferLocations reports that
///                 not all locations could be transferred.
/// \param compound The directive to decompose.
/// \param clauses  The clauses specified on the directive.
/// \returns One UnitConstruct per leaf-or-composite unit, in order.
ConstructQueue buildConstructQueue(
    mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
    Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
    llvm::omp::Directive compound, const List<Clause> &clauses) {

  List<UnitConstruct> constructs;

  ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
  assert(!decompose.output.empty() && "Construct decomposition failed");

  llvm::SmallVector<llvm::omp::Directive> loweringUnits;
  std::ignore =
      llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
  uint32_t version = getOpenMPVersionAttribute(modOp);

  llvm::ArrayRef<UnitConstruct> leafs(decompose.output);
  constructs.reserve(loweringUnits.size());

  size_t leafIndex = 0;
  for (llvm::omp::Directive dirId : loweringUnits) {
    size_t numLeafs = llvm::omp::getLeafConstructsOrSelf(dirId).size();

    // slice() asserts that [leafIndex, leafIndex + numLeafs) lies within the
    // decomposition output instead of silently reading past its end.
    llvm::ArrayRef<UnitConstruct> toMerge = leafs.slice(leafIndex, numLeafs);
    UnitConstruct &uc =
        constructs.emplace_back(mergeConstructs(version, toMerge));

    if (!transferLocations(clauses, uc.clauses)) {
      // If some clauses are left without source information, use the
      // directive's source.
      for (auto &clause : uc.clauses) {
        if (clause.source.empty())
          clause.source = source;
      }
    }
    leafIndex += numLeafs;
  }
  // Every decomposed leaf construct must belong to exactly one lowering unit.
  assert(leafIndex == leafs.size() &&
         "Leaf constructs left over after grouping into lowering units");

  return constructs;
}

/// Return true if `item` is the final entry of `queue`, i.e. no construct
/// follows it.
bool isLastItemInQueue(ConstructQueue::iterator item,
                       const ConstructQueue &queue) {
  return std::next(item) == queue.end();
}
} // namespace Fortran::lower::omp