File: smtlib.py

package info (click to toggle)
firefox-esr 68.10.0esr-1~deb9u1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,143,932 kB
  • sloc: cpp: 5,227,879; javascript: 4,315,531; ansic: 2,467,042; python: 794,975; java: 349,993; asm: 232,034; xml: 228,320; sh: 82,008; lisp: 41,202; makefile: 22,347; perl: 15,555; objc: 5,277; cs: 4,725; yacc: 1,778; ada: 1,681; pascal: 1,673; lex: 1,417; exp: 527; php: 436; ruby: 225; awk: 162; sed: 53; csh: 44
file content (241 lines) | stat: -rw-r--r-- 8,522 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only
primitive instructions.
"""
from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\
    bvult, bvzeroext, bvsplit, bvconcat, bvsignext
from cdsl.ast import Var
from cdsl.types import BVType
from .elaborate import elaborate
from z3 import BitVec, ZeroExt, SignExt, And, Extract, Concat, Not, Solver,\
    unsat, BoolRef, BitVecVal, If
from z3.z3core import Z3_mk_eq

try:
    from typing import TYPE_CHECKING, Tuple, Dict, List # noqa
    from cdsl.xform import Rtl, XForm # noqa
    from cdsl.ast import VarAtomMap, Atom # noqa
    from cdsl.ti import VarTyping # noqa
    if TYPE_CHECKING:
        from z3 import ExprRef, BitVecRef # noqa
        Z3VarMap = Dict[Var, BitVecRef]
except ImportError:
    TYPE_CHECKING = False


# Use this for constructing a == b instead of == since MyPy doesn't
# accept overloading of __eq__ that doesn't return bool
def mk_eq(e1, e2):
    # type: (ExprRef, ExprRef) -> ExprRef
    """
    Build the z3 expression `e1 == e2` without going through `__eq__`.

    The equality node is created directly with the low-level Z3_mk_eq call
    in e1's context and then wrapped in a BoolRef bound to that same context.
    """
    eq_ast = Z3_mk_eq(e1.ctx_ref(), e1.as_ast(), e2.as_ast())
    return BoolRef(eq_ast, e1.ctx)


def to_smt(r):
    # type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap]
    """
    Encode a concrete primitive Rtl r as a z3 query.

    Returns a tuple (query, var_m) where:
        - query is a list of z3 expressions, one for every instruction in r
          other than prim_to_bv/prim_from_bv
        - var_m is a map from Vars v with non-BVType to their corresponding z3
          bitvector variable.

    Asserts that r is concrete and made up exclusively of instructions from
    the PRIMITIVES group.
    """
    # On z3 bitvectors the Python `<` operator is the *signed* comparison.
    # bvult is unsigned, so it has to be encoded with ULT. Imported here
    # because the module-level z3 import list does not include it.
    from z3 import ULT

    assert r.is_concrete()
    # Should contain only primitives
    primitives = set(PRIMITIVES.instructions)
    assert set(d.expr.inst for d in r.rtl).issubset(primitives)

    q = []  # type: List[ExprRef]
    m = {}  # type: Z3VarMap

    # Build declarations for any bitvector Vars
    var_to_bv = {}  # type: Z3VarMap
    for v in r.vars():
        typ = v.get_typevar().singleton_type()
        if not isinstance(typ, BVType):
            continue

        var_to_bv[v] = BitVec(v.name, typ.bits)

    # Encode each instruction as an equality assertion
    for d in r.rtl:
        inst = d.expr.inst

        exp = None  # type: ExprRef
        # For prim_to_bv/prim_from_bv just update var_m. No assertion needed
        if inst == prim_to_bv:
            # The (non-bitvector) argument is represented by the bitvector
            # that the instruction defines.
            assert isinstance(d.expr.args[0], Var)
            m[d.expr.args[0]] = var_to_bv[d.defs[0]]
            continue

        if inst == prim_from_bv:
            # The (non-bitvector) result is represented by the bitvector
            # that the instruction consumes.
            assert isinstance(d.expr.args[0], Var)
            m[d.defs[0]] = var_to_bv[d.expr.args[0]]
            continue

        if inst in [bvadd, bvult]:  # Binary instructions
            assert len(d.expr.args) == 2 and len(d.defs) == 1
            lhs = d.expr.args[0]
            rhs = d.expr.args[1]
            df = d.defs[0]
            assert isinstance(lhs, Var) and isinstance(rhs, Var)

            if inst == bvadd:  # Normal binary - output type same as args
                exp = (var_to_bv[lhs] + var_to_bv[rhs])
            else:
                assert inst == bvult
                # Unsigned comparison (see the note on ULT above)
                exp = ULT(var_to_bv[lhs], var_to_bv[rhs])
                # Comparison binary - need to convert bool to BitVec 1
                exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1))

            exp = mk_eq(var_to_bv[df], exp)
        elif inst == bvzeroext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()

            # ZeroExt takes the number of bits to add, not the final width
            exp = mk_eq(var_to_bv[df], ZeroExt(toW-fromW, var_to_bv[arg]))
        elif inst == bvsignext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()

            # SignExt takes the number of bits to add, not the final width
            exp = mk_eq(var_to_bv[df], SignExt(toW-fromW, var_to_bv[arg]))
        elif inst == bvsplit:
            arg = d.expr.args[0]
            assert isinstance(arg, Var)
            arg_typ = arg.get_typevar().singleton_type()
            width = arg_typ.width()
            assert (width % 2 == 0)

            lo = d.defs[0]
            hi = d.defs[1]

            # lo receives bits [width/2 - 1 : 0], hi bits [width - 1 : width/2]
            exp = And(mk_eq(var_to_bv[lo],
                      Extract(width//2-1, 0, var_to_bv[arg])),
                      mk_eq(var_to_bv[hi],
                      Extract(width-1, width//2, var_to_bv[arg])))
        elif inst == bvconcat:
            assert isinstance(d.expr.args[0], Var) and \
                isinstance(d.expr.args[1], Var)
            lo = d.expr.args[0]
            hi = d.expr.args[1]
            df = d.defs[0]

            # Z3 Concat expects hi bits first, then lo bits
            exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo]))
        else:
            assert False, "Unknown primitive instruction {}".format(inst)

        q.append(exp)

    return (q, m)


def equivalent(r1, r2, inp_m, out_m):
    # type: (Rtl, Rtl, VarAtomMap, VarAtomMap) -> List[ExprRef]
    """
    Given:
        - concrete source Rtl r1
        - concrete dest Rtl r2
        - VarAtomMap inp_m mapping r1's non-bitvector inputs to r2
        - VarAtomMap out_m mapping r1's non-bitvector outputs to r2

    Build a query checking whether r1 and r2 are semantically equivalent.
    If the returned query is unsatisfiable, then r1 and r2 are equivalent.
    Otherwise, the satisfying example for the query gives us values
    for which the two Rtls disagree.

    The returned list is the encoding of r1, the encoding of r2, one equality
    per mapped input pair, and finally the negation of the conjunction of the
    equalities of the mapped output pairs.
    """
    # Sanity - inp_m is a bijection from the set of inputs of r1 to the set of
    # inputs of r2
    assert set(r1.free_vars()) == set(inp_m.keys())
    assert set(r2.free_vars()) == set(inp_m.values())

    # Note that the same rule is not expected to hold for out_m due to
    # temporaries/intermediates. out_m specifies which values are enough for
    # equivalence.

    # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts
    src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()}  # type: VarAtomMap # noqa
    dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()}  # type: VarAtomMap # noqa
    r1 = r1.copy(src_m)
    r2 = r2.copy(dst_m)

    def _translate(m, k_m, v_m):
        # type: (VarAtomMap, VarAtomMap, VarAtomMap) -> VarAtomMap
        """Obtain a new map from m, by mapping m's keys with k_m and m's values
        with v_m"""
        res = {}  # type: VarAtomMap
        # Iterate the map that was passed in. (Reading an outer-scope name
        # here would fail: nothing else is bound yet when this helper runs.)
        for (k, v) in m.items():
            new_k = k_m[k]
            new_v = v_m[v]
            assert isinstance(new_k, Var)
            res[new_k] = new_v

        return res

    # Convert inp_m, out_m in terms of variables with the .a/.b suffixes
    inp_m = _translate(inp_m, src_m, dst_m)
    out_m = _translate(out_m, src_m, dst_m)

    # Encode r1 and r2 as SMT queries
    (q1, m1) = to_smt(r1)
    (q2, m2) = to_smt(r2)

    # Build an expression for the equality of real Cranelift inputs of
    # r1 and r2
    args_eq_exp = []  # type: List[ExprRef]

    for (v1, v2) in inp_m.items():
        assert isinstance(v2, Var)
        args_eq_exp.append(mk_eq(m1[v1], m2[v2]))

    # Build an expression for the equality of real Cranelift outputs of
    # r1 and r2
    results_eq_exp = []  # type: List[ExprRef]
    for (v1, v2) in out_m.items():
        assert isinstance(v2, Var)
        results_eq_exp.append(mk_eq(m1[v1], m2[v2]))

    # Put the whole query together: both encodings hold, the inputs agree,
    # and at least one output differs.
    return q1 + q2 + args_eq_exp + [Not(And(*results_eq_exp))]


def xform_correct(x, typing):
    # type: (XForm, VarTyping) -> bool
    """
    Check whether the XForm x preserves semantics under the concrete variable
    typing `typing`.

    The source pattern is instantiated with the concrete types, the transform
    is applied to it, both sides are elaborated to primitive Rtls, and the
    resulting equivalence query is handed to a z3 Solver. Returns True exactly
    when the solver reports the query as unsat.
    """
    assert x.ti.permits(typing)

    # Instantiate x.src with concretely typed copies of its variables, then
    # derive the matching destination Rtl.
    typed_vars = {}  # type: VarAtomMap
    for pattern_var in x.src.vars():
        typed_vars[pattern_var] = Var(pattern_var.name, typing[pattern_var])
    concrete_src = x.src.copy(typed_vars)
    concrete_dst = x.apply(concrete_src)
    dst_subst = x.dst.substitution(concrete_dst, {})

    # Pair up the inputs/outputs of the source with their destination atoms
    inputs = {}  # type: VarAtomMap
    outputs = {}  # type: VarAtomMap

    for pattern_var in x.src.vars():
        typed_var = typed_vars[pattern_var]
        assert isinstance(typed_var, Var)
        if pattern_var.is_input():
            bucket = inputs
        elif pattern_var.is_output():
            bucket = outputs
        else:
            # Neither an input nor an output: not part of the comparison
            continue
        bucket[typed_var] = dst_subst[pattern_var]

    # Lower both sides to primitive Rtls and build the equivalence query
    query = equivalent(elaborate(concrete_src), elaborate(concrete_dst),
                       inputs, outputs)

    solver = Solver()
    solver.add(*query)
    verdict = solver.check()
    return verdict == unsat