// File: FuseContinuations.cpp
//
// Package info:
//   intel-graphics-compiler 1.0.12504.6-1+deb12u1
//   - links: PTS, VCS
//   - area: main
//   - in suites: bookworm
//   - size: 83,912 kB
//   - sloc: cpp: 910,147; lisp: 202,655; ansic: 15,197; python: 4,025; yacc: 2,241; lex: 1,570; pascal: 244; sh: 104; makefile: 25
// File content: 124 lines | stat: -rw-r--r-- 4,226 bytes
/*========================== begin_copyright_notice ============================

Copyright (C) 2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//===----------------------------------------------------------------------===//
//
// We will sometimes see shaders where multiple continuations have exactly the
// same code in them.  For example, a common TraceRay() use is to make the call
// at the end of a shader.  This may result in a continuation that just releases
// the stack ID (for a raygen shader) or just does a merge (for a closest-hit,
// say).  Here, we try to merge continuations that are identical to help
// STS/BTD pack more lanes into an invocation by having fewer
// continuations to deal with.
//
//===----------------------------------------------------------------------===//

#include "FuseContinuations.h"
#include <set>
#include "common/LLVMWarningsPush.hpp"
#include "llvm/Transforms/Utils/FunctionComparator.h"
#include "llvm/IR/InstIterator.h"
#include "common/LLVMWarningsPop.hpp"
#include "GenISAIntrinsics/GenIntrinsicInst.h"

using namespace llvm;
using namespace IGC;

namespace ContinuationFusing {

/// Fuses continuations that are structurally identical and share the same
/// root function, so that fewer distinct continuations remain.
///
/// For every group of equivalent continuations in \p ContMap, the first one
/// encountered (in ContMap iteration order) is kept.  Every
/// ContinuationHLIntrinsic in \p M whose continuation ID names one of the
/// duplicates is redirected to the ID and function of the kept continuation.
/// The duplicates are then removed from \p ContMap and erased from their
/// parent.
///
/// \param M        Module that is scanned for ContinuationHLIntrinsic calls.
/// \param ContMap  Continuation function -> info map.  Entries belonging to
///                 fused-away continuations are removed from it.
void fuseContinuations(Module& M, MapVector<Function*, FuncInfo>& ContMap)
{
    // Based off of data structures from MergeFunctions.cpp
    class FunctionNode
    {
      // Members are declared in the same order the constructor initializes
      // them, so the initializer list below reads in execution order.
      Function* F = nullptr;
      Function* RootFn = nullptr;
      FunctionComparator::FunctionHash Hash;
    public:
      // Note the hash is recalculated potentially multiple times, but it is cheap.
      FunctionNode(Function *F, Function *RootFn)
        : F(F), RootFn(RootFn), Hash(FunctionComparator::functionHash(*F))  {}

      Function *getFunc() const { return F; }
      Function *getRootFunc() const { return RootFn; }
      FunctionComparator::FunctionHash getHash() const { return Hash; }
    };

    // The function comparison operator is provided here so that FunctionNodes do
    // not need to become larger with another pointer.
    class FunctionNodeCmp
    {
        GlobalNumberState* GlobalNumbers;
    public:
        FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}

        // Lexicographic order on (root function, hash, function body).
        //
        // The root function has to be the *primary* key.  Ordering by a
        // per-function number when the roots differ, but by hash/content when
        // they match, mixes two unrelated criteria and is not a strict weak
        // ordering: std::set may then fail to find an existing equivalent
        // node, or end up with an inconsistent tree.
        //
        // Two nodes compare equivalent only when they share the root
        // function, have equal hashes and FunctionComparator reports them
        // as equal.
        bool operator()(const FunctionNode& LHS, const FunctionNode& RHS) const
        {
            if (LHS.getRootFunc() != RHS.getRootFunc())
            {
                // Distinct globals receive distinct numbers, so this is a
                // total order over the root functions.
                const uint64_t L = GlobalNumbers->getNumber(LHS.getRootFunc());
                const uint64_t R = GlobalNumbers->getNumber(RHS.getRootFunc());
                return L < R;
            }
            // Same root: order by hashes, then full function comparison.
            if (LHS.getHash() != RHS.getHash())
                return LHS.getHash() < RHS.getHash();
            FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
            return FCmp.compare() < 0;
        }
    };

    // Continuation ID -> all ContinuationHLIntrinsic calls in the module that
    // carry that ID.  Collected lazily, and only once, when the first
    // duplicate is found.
    DenseMap<uint32_t, SmallVector<ContinuationHLIntrinsic*, 4>> CIs;
    bool CIsCollected = false;
    auto fill = [&]()
    {
        if (CIsCollected)
            return;
        CIsCollected = true;

        for (auto& F : M)
        {
            for (auto& I : instructions(F))
            {
                if (auto * CI = dyn_cast<ContinuationHLIntrinsic>(&I))
                    CIs[CI->getContinuationID()].push_back(CI);
            }
        }
    };

    using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>;

    GlobalNumberState GlobalNumbers;
    FnTreeType FnTree{ FunctionNodeCmp(&GlobalNumbers) };

    // Duplicates, in ContMap order.  They are erased only after the loop:
    // CIs may hold pointers to intrinsics that live inside a duplicate, and
    // those pointers must stay valid while intrinsics are still being
    // redirected.
    SmallVector<Function*, 4> DeadFns;

    for (auto& [Fn, FnInfo] : ContMap)
    {
        auto [I, Ok] = FnTree.insert(FunctionNode(Fn, FnInfo.RootFn));
        if (Ok)
            continue; // First continuation of its kind: keep it.

        // An equivalent continuation is already in the tree; point every
        // intrinsic that names this one at the kept continuation instead.
        fill();
        auto* DupFn = I->getFunc();
        auto Entry = ContMap.find(DupFn);
        IGC_ASSERT(Entry != ContMap.end());
        const uint32_t NewID = Entry->second.Idx;
        for (auto *CI : CIs[FnInfo.Idx])
        {
            CI->setContinuationID(NewID);
            CI->setContinuationFn(DupFn);
        }

        DeadFns.push_back(Fn);
    }

    if (DeadFns.empty())
        return;

    // Drop the duplicates from the map by their real key.  The keys must not
    // be overwritten in place: MapVector::remove_if erases from its internal
    // index map using the key currently stored in the element, so a modified
    // key would leave a stale index entry behind.
    const std::set<const Function*> DeadSet(DeadFns.begin(), DeadFns.end());
    ContMap.remove_if([&DeadSet](const auto& P) { return DeadSet.count(P.first) != 0; });

    for (Function* Fn : DeadFns)
    {
        Fn->removeDeadConstantUsers();
        IGC_ASSERT_MESSAGE(Fn->use_empty(), "other uses?");

        Fn->eraseFromParent();
    }
}

} // namespace ContinuationFusing