
=================
Notes on Inlining
=================

There are occasions where it is useful to be able to inline a function at its
call site, at the Numba IR level of representation. Decorators such as
:func:`numba.jit`, :func:`numba.extending.overload` and
:func:`register_jitable` support the keyword argument ``inline`` to facilitate
this behaviour.

When attempting to inline at this level, it is important to understand what
purpose this serves and what effect this will have. In contrast to the inlining
performed by LLVM, which is aimed at improving performance, the main reason to
inline at the Numba IR level is to allow type inference to cross function
boundaries.

As an example, consider the following snippet:

.. code:: python

    from numba import njit


    @njit
    def bar(a):
        a.append(10)


    @njit
    def foo():
        z = []
        bar(z)


    foo()

This will fail to compile and run, because the type of ``z`` cannot be inferred
within ``foo``: it is only refined within ``bar``. If we now add
``inline='always'`` to the decorator for ``bar``, the snippet will compile and
run. This is because inlining places the call to ``a.append(10)`` in the body
of ``foo``, so ``z`` is refined to hold integers and type inference succeeds.
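
As a minimal sketch of that fix, assuming a working Numba installation: the only
change to ``bar`` is the ``inline='always'`` argument. Returning ``z[0]`` from
``foo`` is an addition made here purely so the result can be inspected.

.. code:: python

    from numba import njit


    # inline='always' asks Numba to inline `bar` into its callers at the IR level.
    @njit(inline='always')
    def bar(a):
        a.append(10)


    @njit
    def foo():
        z = []       # the element type of `z` is not yet known here
        bar(z)       # once inlined, this append refines the type of `z`
        return z[0]


    result = foo()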

So, to recap, inlining at the Numba IR level is unlikely to have a performance
benefit, whereas inlining at the LLVM level stands a better chance of providing
one.

The ``inline`` keyword argument can be one of three values:

* The string ``'never'``. This is the default and results in the function not
  being inlined under any circumstances.
* The string ``'always'``. This results in the function being inlined at all
  call sites.
* A Python function that takes three arguments. The first argument is always the
  ``ir.Expr`` node that is the ``call`` requesting the inline; it is present
  to allow the function to make decisions that are aware of the call context.
  The second and third arguments are:

  * In the case of an untyped inline, i.e. that which occurs when using the
    :func:`numba.jit` family of decorators, both arguments are
    ``numba.core.ir.FunctionIR`` instances. The second argument is the IR of
    the caller and the third argument is the IR of the callee.

  * In the case of a typed inline, i.e. that which occurs when using
    :func:`numba.extending.overload`, both arguments are instances of a
    ``namedtuple`` with fields (corresponding to their standard use in the
    compiler internals):

    * ``func_ir`` - the function's Numba IR.
    * ``typemap`` - the function's type map.
    * ``calltypes`` - the call types of any calls in the function.
    * ``signature`` - the function's signature.

    The second argument holds the information from the caller, the third holds
    the information from the callee.

  In all cases the function should return ``True`` to inline and ``False`` to
  not inline. This permits custom inlining rules (a typical use might be a cost
  model).

Recursive functions with ``inline='always'`` will result in a non-terminating
compilation. To avoid this, supply a function that limits the recursion depth
(see below).
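
The three-argument callable described above can be sketched as follows for the
untyped case. This is an illustration rather than a recommended rule: the name
``small_callee_only``, the ``MAX_STATEMENTS`` threshold and the ``fake_ir``
helper are all hypothetical. The only assumption made about the real arguments
is that a ``FunctionIR`` exposes ``blocks``, a mapping whose values each carry a
``body`` list of statements. The stand-in objects exist only so that the rule
can be exercised without compiling anything.

.. code:: python

    from types import SimpleNamespace

    MAX_STATEMENTS = 10  # hypothetical threshold, chosen only for illustration


    def small_callee_only(expr, caller_info, callee_info):
        """Return True (inline) only if the callee IR is small."""
        # Count the statements across all blocks of the callee's IR.
        n_statements = sum(len(blk.body) for blk in callee_info.blocks.values())
        return n_statements <= MAX_STATEMENTS


    def fake_ir(*block_sizes):
        """Build a stand-in exposing only the attributes used by the rule."""
        blocks = {i: SimpleNamespace(body=[None] * n)
                  for i, n in enumerate(block_sizes)}
        return SimpleNamespace(blocks=blocks)


    # A callee with 4 + 2 = 6 statements is within the threshold.
    small = small_callee_only(None, fake_ir(3), fake_ir(4, 2))
    # A callee with 8 + 8 = 16 statements exceeds the threshold.
    large = small_callee_only(None, fake_ir(3), fake_ir(8, 8))

In real use such a function would be passed to the decorator, for example as
``@njit(inline=small_callee_only)``.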

.. note:: No guarantee is made about the order in which functions are assessed
          for inlining or about the order in which they are inlined.


Example using :func:`numba.jit`
===============================

An example of using all three options to ``inline`` in the :func:`numba.njit`
decorator:

.. literalinclude:: inline_example.py

which produces the following when executed (with a print of the IR after the
legalization pass, enabled via the environment variable
``NUMBA_DEBUG_PRINT_AFTER="ir_legalization"``):

.. code-block:: none
    :emphasize-lines: 2, 3, 9, 16, 17, 21, 22, 26, 35

    label 0:
        $0.1 = global(never_inline: CPUDispatcher(<function never_inline at 0x7f890ccf9048>)) ['$0.1']
        $0.2 = call $0.1(func=$0.1, args=[], kws=(), vararg=None) ['$0.1', '$0.2']
        del $0.1                                 []
        a = $0.2                                 ['$0.2', 'a']
        del $0.2                                 []
        $0.3 = global(always_inline: CPUDispatcher(<function always_inline at 0x7f890ccf9598>)) ['$0.3']
        del $0.3                                 []
        $const0.1.0 = const(int, 200)            ['$const0.1.0']
        $0.2.1 = $const0.1.0                     ['$0.2.1', '$const0.1.0']
        del $const0.1.0                          []
        $0.4 = $0.2.1                            ['$0.2.1', '$0.4']
        del $0.2.1                               []
        b = $0.4                                 ['$0.4', 'b']
        del $0.4                                 []
        $0.5 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.5']
        $0.6 = call $0.5(func=$0.5, args=[], kws=(), vararg=None) ['$0.5', '$0.6']
        del $0.5                                 []
        d = $0.6                                 ['$0.6', 'd']
        del $0.6                                 []
        $const0.7 = const(int, 13)               ['$const0.7']
        magic_const = $const0.7                  ['$const0.7', 'magic_const']
        del $const0.7                            []
        $0.8 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.8']
        del $0.8                                 []
        $const0.1.2 = const(int, 300)            ['$const0.1.2']
        $0.2.3 = $const0.1.2                     ['$0.2.3', '$const0.1.2']
        del $const0.1.2                          []
        $0.9 = $0.2.3                            ['$0.2.3', '$0.9']
        del $0.2.3                               []
        e = $0.9                                 ['$0.9', 'e']
        del $0.9                                 []
        $0.10 = global(maybe_inline2: CPUDispatcher(<function maybe_inline2 at 0x7f890ccf9b70>)) ['$0.10']
        del $0.10                                []
        $const0.1.4 = const(int, 37)             ['$const0.1.4']
        $0.2.5 = $const0.1.4                     ['$0.2.5', '$const0.1.4']
        del $const0.1.4                          []
        $0.11 = $0.2.5                           ['$0.11', '$0.2.5']
        del $0.2.5                               []
        c = $0.11                                ['$0.11', 'c']
        del $0.11                                []
        $0.14 = a + b                            ['$0.14', 'a', 'b']
        del b                                    []
        del a                                    []
        $0.16 = $0.14 + c                        ['$0.14', '$0.16', 'c']
        del c                                    []
        del $0.14                                []
        $0.18 = $0.16 + d                        ['$0.16', '$0.18', 'd']
        del d                                    []
        del $0.16                                []
        $0.20 = $0.18 + e                        ['$0.18', '$0.20', 'e']
        del e                                    []
        del $0.18                                []
        $0.22 = $0.20 + magic_const              ['$0.20', '$0.22', 'magic_const']
        del magic_const                          []
        del $0.20                                []
        $0.23 = cast(value=$0.22)                ['$0.22', '$0.23']
        del $0.22                                []
        return $0.23                             ['$0.23']


Things to note in the above:

1. The call to the function ``never_inline`` remains as a call.
2. The ``always_inline`` function has been inlined; note its
   ``const(int, 200)`` in the caller body.
3. There is a call to ``maybe_inline1`` before the ``const(int, 13)``
   declaration; the cost model prevented this from being inlined.
4. After the ``const(int, 13)``, the subsequent call to ``maybe_inline1`` has
   been inlined, as shown by the ``const(int, 300)`` in the caller body.
5. The function ``maybe_inline2`` has been inlined, as demonstrated by
   ``const(int, 37)`` in the caller body.
6. Dead code elimination has not been performed and, as a result, there are
   superfluous statements present in the IR.


Example using :func:`numba.extending.overload`
==============================================

An example of using inlining with the :func:`numba.extending.overload`
decorator. Note that when a function is supplied as the argument to
``inline``, considerably more information is available via its arguments for
use in decision making. Note also that different ``@overload``\ s can have
different inlining behaviours, with multiple ways to achieve this:

.. literalinclude:: inline_overload_example.py

which produces the following when executed (with a print of the IR after the
legalization pass, enabled via the environment variable
``NUMBA_DEBUG_PRINT_AFTER="ir_legalization"``):

.. code-block:: none
    :emphasize-lines: 2, 3, 4, 5, 6, 15, 16, 17, 18, 19, 20, 21, 22, 28, 29, 30

    label 0:
        $const0.2 = const(tuple, (1, 2, 3))      ['$const0.2']
        x.0 = $const0.2                          ['$const0.2', 'x.0']
        del $const0.2                            []
        $const0.2.2 = const(int, 0)              ['$const0.2.2']
        $0.3.3 = getitem(value=x.0, index=$const0.2.2) ['$0.3.3', '$const0.2.2', 'x.0']
        del x.0                                  []
        del $const0.2.2                          []
        $0.4.4 = $0.3.3                          ['$0.3.3', '$0.4.4']
        del $0.3.3                               []
        $0.3 = $0.4.4                            ['$0.3', '$0.4.4']
        del $0.4.4                               []
        a = $0.3                                 ['$0.3', 'a']
        del $0.3                                 []
        $const0.5 = const(int, 100)              ['$const0.5']
        x.5 = $const0.5                          ['$const0.5', 'x.5']
        del $const0.5                            []
        $const0.2.7 = const(int, 1)              ['$const0.2.7']
        $0.3.8 = x.5 + $const0.2.7               ['$0.3.8', '$const0.2.7', 'x.5']
        del x.5                                  []
        del $const0.2.7                          []
        $0.4.9 = $0.3.8                          ['$0.3.8', '$0.4.9']
        del $0.3.8                               []
        $0.6 = $0.4.9                            ['$0.4.9', '$0.6']
        del $0.4.9                               []
        b = $0.6                                 ['$0.6', 'b']
        del $0.6                                 []
        $0.7 = global(bar: <function bar at 0x7f6c3710d268>) ['$0.7']
        $const0.8 = const(complex, 300j)         ['$const0.8']
        $0.9 = call $0.7($const0.8, func=$0.7, args=[Var($const0.8, inline_overload_example.py (56))], kws=(), vararg=None) ['$0.7', '$0.9', '$const0.8']
        del $const0.8                            []
        del $0.7                                 []
        c = $0.9                                 ['$0.9', 'c']
        del $0.9                                 []
        $0.12 = a + b                            ['$0.12', 'a', 'b']
        del b                                    []
        del a                                    []
        $0.14 = $0.12 + c                        ['$0.12', '$0.14', 'c']
        del c                                    []
        del $0.12                                []
        $0.15 = cast(value=$0.14)                ['$0.14', '$0.15']
        del $0.14                                []
        return $0.15                             ['$0.15']

Things to note in the above:

1. The first highlighted section is the always inlined overload for the
   ``UniTuple`` argument type.
2. The second highlighted section is the overload for the ``Number`` argument
   type. It has been inlined because the cost model function accepted it, the
   argument being an ``Integer`` type instance.
3. The third highlighted section is the overload for the ``Number`` argument
   type. It has not been inlined because the cost model function rejected it,
   the argument being a ``Complex`` type instance.
4. Dead code elimination has not been performed and, as a result, there are
   superfluous statements present in the IR.
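
The behaviour described in these notes can be reproduced with a sketch along the
following lines. It is a reconstruction consistent with the IR shown above, not
necessarily the contents of ``inline_overload_example.py``. The names
``ol_bar_tuple``, ``ol_bar_scalar`` and ``cost_model`` are illustrative, and a
working Numba installation is assumed.

.. code:: python

    from numba import njit, types
    from numba.extending import overload


    def bar(x):
        """A pure Python stub that exists only to be overloaded."""
        pass


    @overload(bar, inline='always')
    def ol_bar_tuple(x):
        # Always inlined; the type guard restricts this to UniTuple arguments.
        if isinstance(x, types.UniTuple):
            def impl(x):
                return x[0]
            return impl


    def cost_model(expr, caller, callee):
        # Typed inline: `caller.typemap` maps variable names to their types.
        # Inline only if the first call argument is typed as an Integer.
        return isinstance(caller.typemap[expr.args[0].name], types.Integer)


    @overload(bar, inline=cost_model)
    def ol_bar_scalar(x):
        # Inlined only when `cost_model` accepts; applies to Number arguments.
        if isinstance(x, types.Number):
            def impl(x):
                return x + 1
            return impl


    @njit
    def foo():
        a = bar((1, 2, 3))  # UniTuple overload, always inlined
        b = bar(100)        # Number overload, Integer argument
        c = bar(300j)       # Number overload, Complex argument
        return a + b + c


    result = foo()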

Using a function to limit the inlining depth of a recursive function
====================================================================

When inlining a recursive function, a cost model can be used to ensure that
compilation terminates.

.. code:: python

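    from numba import njit


    class CostModel(object):
        """Permit at most ``max_inlines`` inlines, then refuse all others."""

        def __init__(self, max_inlines):
            self._count = 0
            self._max_inlines = max_inlines

        def __call__(self, expr, caller, callee):
            # Inline only while the number of requests seen is under the limit.
            ret = self._count < self._max_inlines
            self._count += 1
            return ret


    # Permit at most three inlines of the recursive call; later calls remain
    # as calls.
    @njit(inline=CostModel(3))
    def factorial(n):
        if n <= 0:
            return 1
        return n * factorial(n - 1)


    factorial(5)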
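
To see why such a cost model bounds the inlining, the counting rule can be
exercised on its own, without compiling anything. In this sketch the class name
``InlineBudget`` is hypothetical, and ``None`` stands in for the three arguments
that Numba would supply, which this rule never inspects.

.. code:: python

    class InlineBudget:
        """Answer True for the first `max_inlines` requests, then False."""

        def __init__(self, max_inlines):
            self._count = 0
            self._max_inlines = max_inlines

        def __call__(self, expr, caller, callee):
            allowed = self._count < self._max_inlines
            self._count += 1
            return allowed


    budget = InlineBudget(3)
    # Five consecutive requests against a budget of three.
    decisions = [budget(None, None, None) for _ in range(5)]
    print(decisions)

Because the count is stored on the instance, it is shared by every request made
to that instance. Once the budget is exhausted, every further request is refused.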