File: inferenceresult.jl

package info (click to toggle)
julia 1.5.3%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 91,132 kB
  • sloc: lisp: 278,486; ansic: 60,186; cpp: 29,801; sh: 2,403; makefile: 1,998; pascal: 1,313; objc: 647; javascript: 516; asm: 226; python: 161; xml: 34
file content (173 lines) | stat: -rw-r--r-- 7,457 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# This file is a part of Julia. License is MIT: https://julialang.org/license

const EMPTY_VECTOR = Vector{Any}()

mutable struct InferenceResult
    linfo::MethodInstance
    argtypes::Vector{Any}
    overridden_by_const::BitVector
    result # ::Type, or InferenceState if WIP
    src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
    function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
        argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
        return new(linfo, argtypes, overridden_by_const, Any, nothing)
    end
end

function is_argtype_match(@nospecialize(given_argtype),
                          @nospecialize(cache_argtype),
                          overridden_by_const::Bool)
    if isa(given_argtype, Const) || isa(given_argtype, PartialStruct)
        return is_lattice_equal(given_argtype, cache_argtype)
    end
    return !overridden_by_const
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
# so that we can construct cache-correct `InferenceResult`s in the first place.
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector)
    @assert isa(linfo.def, Method) # ensure the next line works
    nargs::Int = linfo.def.nargs
    @assert length(given_argtypes) >= (nargs - 1)
    given_argtypes = anymap(widenconditional, given_argtypes)
    if linfo.def.isva
        isva_given_argtypes = Vector{Any}(undef, nargs)
        for i = 1:(nargs - 1)
            isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
        end
        if length(given_argtypes) >= nargs || !isvarargtype(given_argtypes[end])
            isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[nargs:end])
        else
            isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[end:end])
        end
        given_argtypes = isva_given_argtypes
    end
    cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing)
    if nargs === length(given_argtypes)
        for i in 1:nargs
            given_argtype = given_argtypes[i]
            cache_argtype = cache_argtypes[i]
            if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
                # prefer the argtype we were given over the one computed from `linfo`
                cache_argtypes[i] = given_argtype
                overridden_by_const[i] = true
            end
        end
    end
    return cache_argtypes, overridden_by_const
end

function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
    toplevel = !isa(linfo.def, Method)
    linfo_argtypes = Any[unwrap_unionall(linfo.specTypes).parameters...]
    nargs::Int = toplevel ? 0 : linfo.def.nargs
    cache_argtypes = Vector{Any}(undef, nargs)
    # First, if we're dealing with a varargs method, then we set the last element of `args`
    # to the appropriate `Tuple` type or `PartialStruct` instance.
    if !toplevel && linfo.def.isva
        if linfo.specTypes == Tuple
            if nargs > 1
                linfo_argtypes = svec(Any[Any for i = 1:(nargs - 1)]..., Tuple.parameters[1])
            end
            vargtype = Tuple
        else
            linfo_argtypes_length = length(linfo_argtypes)
            if nargs > linfo_argtypes_length
                va = linfo_argtypes[linfo_argtypes_length]
                if isvarargtype(va)
                    new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes)
                    vargtype_elements = Any[new_va]
                    vargtype = Tuple{new_va}
                else
                    vargtype_elements = Any[]
                    vargtype = Tuple{}
                end
            else
                vargtype_elements = Any[]
                for p in linfo_argtypes[nargs:linfo_argtypes_length]
                    p = isvarargtype(p) ? unconstrain_vararg_length(p) : p
                    push!(vargtype_elements, rewrap(p, linfo.specTypes))
                end
                for i in 1:length(vargtype_elements)
                    atyp = vargtype_elements[i]
                    if isa(atyp, DataType) && isdefined(atyp, :instance)
                        # replace singleton types with their equivalent Const object
                        vargtype_elements[i] = Const(atyp.instance)
                    elseif isconstType(atyp)
                        vargtype_elements[i] = Const(atyp.parameters[1])
                    end
                end
                vargtype = tuple_tfunc(vargtype_elements)
            end
        end
        cache_argtypes[nargs] = vargtype
        nargs -= 1
    end
    # Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
    # type info as we go (where possible). Note that if we're dealing with a varargs method,
    # we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
    # we don't overwrite the result of that work here).
    linfo_argtypes_length = length(linfo_argtypes)
    if linfo_argtypes_length > 0
        n = linfo_argtypes_length > nargs ? nargs : linfo_argtypes_length
        tail_index = n
        local lastatype
        for i = 1:n
            atyp = linfo_argtypes[i]
            if i == n && isvarargtype(atyp)
                atyp = unwrapva(atyp)
                tail_index -= 1
            end
            while isa(atyp, TypeVar)
                atyp = atyp.ub
            end
            if isa(atyp, DataType) && isdefined(atyp, :instance)
                # replace singleton types with their equivalent Const object
                atyp = Const(atyp.instance)
            elseif isconstType(atyp)
                atyp = Const(atyp.parameters[1])
            else
                atyp = rewrap(atyp, linfo.specTypes)
            end
            i == n && (lastatype = atyp)
            cache_argtypes[i] = atyp
        end
        for i = (tail_index + 1):nargs
            cache_argtypes[i] = lastatype
        end
    else
        @assert nargs == 0 "invalid specialization of method" # wrong number of arguments
    end
    return cache_argtypes, falses(length(cache_argtypes))
end

function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
    method = linfo.def::Method
    nargs::Int = method.nargs
    method.isva && (nargs -= 1)
    length(given_argtypes) >= nargs || return nothing
    for cached_result in cache
        cached_result.linfo === linfo || continue
        cache_match = true
        cache_argtypes = cached_result.argtypes
        cache_overridden_by_const = cached_result.overridden_by_const
        for i in 1:nargs
            if !is_argtype_match(given_argtypes[i],
                                 cache_argtypes[i],
                                 cache_overridden_by_const[i])
                cache_match = false
                break
            end
        end
        if method.isva && cache_match
            cache_match = is_argtype_match(tuple_tfunc(given_argtypes[(nargs + 1):end]),
                                           cache_argtypes[end],
                                           cache_overridden_by_const[end])
        end
        cache_match || continue
        return cached_result
    end
    return nothing
end