from numba import njit
import numba
from numba.core import ir


@njit(inline='never')
def never_inline():
    return 100


@njit(inline='always')
def always_inline():
    return 200


def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return True & before_expr
    return False


@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300


@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37


@njit
def foo():
    a = never_inline()  # will never inline
    b = always_inline()  # will always inline

    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const

foo()