File: CustomLoopOpt.hpp

package info (click to toggle)
intel-graphics-compiler 1.0.12504.6-1+deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 83,912 kB
  • sloc: cpp: 910,147; lisp: 202,655; ansic: 15,197; python: 4,025; yacc: 2,241; lex: 1,570; pascal: 244; sh: 104; makefile: 25
file content (164 lines) | stat: -rw-r--r-- 5,380 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

#pragma once

#include "common/LLVMWarningsPush.hpp"
#include <llvm/Pass.h>
#include <llvm/PassRegistry.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/Analysis/LoopPass.h>
#include <llvm/Transforms/Utils/ValueMapper.h>
#include <llvm/Transforms/Scalar.h>

#include <llvmWrapper/Transforms/Utils.h>

#include "common/LLVMWarningsPop.hpp"

#include "Compiler/CodeGenPublic.h"
#include "Compiler/CodeGenContextWrapper.hpp"

namespace IGC
{
    ///////////////////////////////////////////////////////////////////////////
    /// Enforce a single latch for every loop header. This needs to be run before
    /// the LLVM loop canonicalization pass, because the LLVM loop simplification
    /// pass sometimes decides to split the loop. Splitting the loop may cause
    /// functional issues when barriers are used, and it may cause extra SIMD
    /// divergence, degrading performance.
    llvm::FunctionPass* createLoopCanonicalization();
    /**
     * Custom loop versioning.
     * Break loop into segments to expose loop invariants.
     *
     * Input loop:
     *   float t = ....;
     *   float nextT = t * CB_Load;
     *   [loop] while (t < loop_range_y)
     *   {
     *        float val0 = max(t, loop_range_x);
     *        float val1 = min(nextT, loop_range_y);
     *        float val = some_alu_func(val1 / val0);
     *        ......
     *        t = nextT;
     *        nextT *= CB_Load;
     *    }
     *
     * Transformed loop:
     *   float t = ....;
     *   float nextT = t * CB_Load;
     *   [branch] if (CB_Load > 1.0 && loop_range_x * CB_Load < loop_range_y)
     *   {
     *       [loop] while (t < loop_range_x)        // loop seg 1
     *       {
     *           float val0 = loop_range_x;
     *           float val1 = nextT;
     *           float val = some_alu_func(val1 / val0);
     *           ......
     *           t = nextT;
     *           nextT *= CB_Load;
     *       }
     *       [loop] while (t < loop_range_y/CB_Load)  // loop seg 2
     *       {
     *           float val0 = t;
     *           float val1 = nextT;
     *           float val = some_alu_func(CB_Load);    // loop invariant
     *           ......
     *           t = nextT;
     *           nextT *= CB_Load;
     *       }
     *       {                                          // loop seg 3
     *           float val0 = t;
     *           float val1 = loop_range_y;
     *           float val = some_alu_func(val1 / val0);
     *           t = nextT;
     *           nextT *= CB_Load;
     *       }
     *   } else {
     *       [loop] while (t < loop_range_y)
     *       {
     *           float val0 = max(t, loop_range_x);
     *           float val1 = min(nextT, loop_range_y);
     *           float val = some_alu_func(val1 / val0);
     *           ......
     *           t = nextT;
     *           nextT *= CB_Load;
     *        }
     *   }
     */
    class CustomLoopVersioning : public llvm::FunctionPass
    {
    public:
        // LLVM legacy pass identifier (defined out of line).
        static char ID;

        CustomLoopVersioning();
        ~CustomLoopVersioning() override { }

        /// Declares the analyses this pass requires: the IGC code-gen context
        /// and metadata wrappers, loop info, the dominator tree, and LCSSA form.
        /// Nothing is declared as preserved.
        void getAnalysisUsage(llvm::AnalysisUsage& AU) const override
        {
            AU.addRequired<CodeGenContextWrapper>();
            AU.addRequired<MetaDataUtilsWrapper>();
            AU.addRequired<llvm::LoopInfoWrapperPass>();
            AU.addRequired<llvm::DominatorTreeWrapperPass>();
            AU.addRequiredID(llvm::LCSSAID);
        }

        /// Pass entry point (defined out of line); returns true if F was modified.
        bool runOnFunction(llvm::Function& F) override;

        /// Attempts the versioning transformation on a single loop
        /// (defined out of line). NOTE(review): presumably returns true when
        /// the loop was transformed — confirm in the implementation.
        bool processLoop(llvm::Loop* loop);

        llvm::StringRef getPassName() const override
        {
            return "Custom Loop Versioning";
        }

    private:
        // Per-function state. Null until the pass runs; presumably assigned
        // in runOnFunction (implementation not visible here).
        CodeGenContext* m_cgCtx = nullptr;
        llvm::LoopInfo* m_LI = nullptr;
        llvm::DominatorTree* m_DT = nullptr;
        llvm::Function* m_function = nullptr;

        // value map from orig loop to loop seg1/seg2/seg3
        llvm::ValueToValueMapTy m_vmapToSeg1;
        llvm::ValueToValueMapTy m_vmapToSeg2;
        llvm::ValueToValueMapTy m_vmapToSeg3;

        // NOTE(review): presumably checks whether `val` is a constant-buffer
        // load and reports its buffer id / offset through the out-params —
        // confirm in the implementation.
        bool isCBLoad(llvm::Value* val, unsigned& bufId, unsigned& offset);

        // create phi nodes for after loop BB
        void addPhiNodes(
            const llvm::SmallVectorImpl<llvm::Instruction*>& liveOuts,
            llvm::Loop* loopSeg1, llvm::Loop* loopSeg2,
            llvm::BasicBlock* bbSeg3, llvm::Loop* origLoop);

        // Pattern-match `loop` against the input form shown in the class
        // comment. The reference out-params receive the matched values
        // (range x/y, the CB load, and the preheader t / nextT).
        // NOTE(review): presumably returns false when the pattern does not
        // match — confirm in the implementation.
        bool detectLoop(llvm::Loop* loop,
            llvm::Value*& var_range_x, llvm::Value*& var_range_y,
            llvm::LoadInst*& var_MediumFactor_preHdr,
            llvm::Value*& var_t_preHdr,
            llvm::Value*& var_nextT_preHdr);

        // Wire loop seg1 -> seg2 -> afterLoop (implementation not visible here).
        void linkLoops(llvm::Loop* loopSeg1, llvm::Loop* loopSeg2,
            llvm::BasicBlock* afterLoop);

        // Per-segment rewrites toward the forms shown in the class comment
        // ("loop seg 1/2/3"); all defined out of line.
        void rewriteLoopSeg1(llvm::Loop* loop,
            llvm::Value* range_x, llvm::Value* range_y);

        void hoistSeg2Invariant(llvm::Loop* loop,
            llvm::Instruction* fmul, llvm::Value* cbVal);

        void rewriteLoopSeg2(llvm::Loop* loop,
            llvm::Value* range_y, llvm::Value* cbVal);

        void rewriteLoopSeg3(llvm::BasicBlock* bb,
            llvm::Value* range_y);
    };


    /// Factory for an IGC loop pass (an llvm::LoopPass).
    /// NOTE(review): judging only by the name, it hoists constants out of
    /// loops — confirm against the implementation.
    llvm::LoopPass* createLoopHoistConstant();

} // namespace IGC