import py
from rpython.rlib.jit import JitDriver, hint, set_param
from rpython.rlib.jit import unroll_safe, dont_look_inside, promote
from rpython.rlib.objectmodel import we_are_translated
from rpython.rlib.debug import fatalerror
from rpython.jit.metainterp.test.support import LLJitMixin
from rpython.jit.codewriter.policy import StopAtXPolicy
from rpython.rtyper.annlowlevel import hlstr
from rpython.jit.metainterp.warmspot import get_stats
from rpython.jit.backend.llsupport import codemap

class RecursiveTests:

    def test_simple_recursion(self):
        # Mutual recursion between the jit-driven loop f() and the plain
        # helper main(): the result computed through meta_interp() must be
        # the same as the one computed by running main() directly.
        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
        def f(n):
            # 'm' is fixed two below the starting 'n', so the loop body runs
            # twice before 'm == n' becomes true
            m = n - 2
            while True:
                myjitdriver.jit_merge_point(n=n, m=m)
                n -= 1
                if m == n:
                    # leave the loop by recursing back into f() via main()
                    return main(n) * 2
                myjitdriver.can_enter_jit(n=n, m=m)
        def main(n):
            # recursion bottoms out at n <= 0
            if n > 0:
                return f(n+1)
            else:
                return 1
        res = self.meta_interp(main, [20], enable_opts='')
        assert res == main(20)
        # expects zero 'call_i' operations in the recorded history
        self.check_history(call_i=0)

    def test_simple_recursion_with_exc(self):
        # Same shape as test_simple_recursion, but an exception is raised
        # deep in the recursion and caught by an outer f() frame.
        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
        class Error(Exception):
            pass

        def f(n):
            m = n - 2
            while True:
                myjitdriver.jit_merge_point(n=n, m=m)
                n -= 1
                if n == 10:
                    # raised from the innermost frame reached by the recursion
                    raise Error
                if m == n:
                    try:
                        return main(n) * 2
                    except Error:
                        # an Error coming out of the recursive call replaces
                        # the result of this frame by the constant 2
                        return 2
                myjitdriver.can_enter_jit(n=n, m=m)
        def main(n):
            if n > 0:
                return f(n+1)
            else:
                return 1
        res = self.meta_interp(main, [20], enable_opts='')
        # the jitted run must agree with the plain Python run
        assert res == main(20)

    def test_recursion_three_times(self):
        # Each activation of f() recurses (through main()) three times from
        # inside its loop, accumulating the results in 'total'.
        myjitdriver = JitDriver(greens=[], reds=['n', 'm', 'total'])
        def f(n):
            # 'm' is three below the starting 'n': the loop body runs three
            # times, calling main() once per iteration
            m = n - 3
            total = 0
            while True:
                myjitdriver.jit_merge_point(n=n, m=m, total=total)
                n -= 1
                total += main(n)
                if m == n:
                    return total + 5
                myjitdriver.can_enter_jit(n=n, m=m, total=total)
        def main(n):
            if n > 0:
                return f(n)
            else:
                return 1
        # debugging aid: print a table of f(1)..f(10) computed without the JIT
        print
        for i in range(1, 11):
            print '%3d %9d' % (i, f(i))
        res = self.meta_interp(main, [10], enable_opts='')
        assert res == main(10)
        # upper bound on the enter count reported by the test support code
        self.check_enter_count_at_most(11)

    def test_bug_1(self):
        # Regression test: the portal f() is re-entered 20 times from inside
        # a helper that is passed to StopAtXPolicy, while f(1) is on its
        # last loop iteration.
        myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
        def opaque(n, i):
            # only fires on the last iteration (i == 19) of the outer f(1)
            if n == 1 and i == 19:
                for j in range(20):
                    res = f(0)      # recurse repeatedly, 20 times
                    assert res == 0
        def f(n):
            # 'stack' holds just the argument; popping it at the end returns
            # 'n' unchanged
            stack = [n]
            i = 0
            while i < 20:
                myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
                myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
                opaque(n, i)
                i += 1
            return stack.pop()
        # NOTE(review): StopAtXPolicy(opaque) presumably keeps the JIT from
        # tracing into opaque() -- confirm against the policy's definition
        res = self.meta_interp(f, [1], enable_opts='', repeat=2,
                               policy=StopAtXPolicy(opaque))
        assert res == 1

    def get_interpreter(self, codes):
        # Build and return a tiny jit-driven bytecode interpreter working on
        # 'codes', a list of code strings.  The returned function has the
        # signature interpret(codenum, n, i): it runs codes[codenum] starting
        # at position 'i' with accumulator 'n' and returns an integer.
        #
        # One-character opcodes:
        ADD = "0"           # n += 1
        JUMP_BACK = "1"     # return 42 if n > 20, else jump two positions back
        CALL = "2"          # n = interpret(1, n, 1)
        EXIT = "3"          # return n

        def getloc(i, code):
            # printable location used by the jitdriver
            return 'code="%s", i=%d' % (code, i)

        jitdriver = JitDriver(greens = ['i', 'code'], reds = ['n'],
                              get_printable_location = getloc)

        def interpret(codenum, n, i):
            code = codes[codenum]
            while i < len(code):
                jitdriver.jit_merge_point(n=n, i=i, code=code)
                op = code[i]
                if op == ADD:
                    n += 1
                    i += 1
                elif op == CALL:
                    # always calls into codes[1], starting at position 1
                    # (i.e. skipping its first opcode)
                    n = interpret(1, n, 1)
                    i += 1
                elif op == JUMP_BACK:
                    if n > 20:
                        return 42
                    # backward jump: the only place where a loop can close
                    i -= 2
                    jitdriver.can_enter_jit(n=n, i=i, code=code)
                elif op == EXIT:
                    return n
                else:
                    raise NotImplementedError
            # running off the end of the code also returns the accumulator
            return n

        return interpret

    def test_inline(self):
        # Run the program "021" (ADD, CALL, JUMP_BACK) whose CALL enters
        # the sub-code "00", once with default settings and once with
        # inline=True, and compare the operations left in the trace.
        main_code = "021"
        callee_code = "00"
        interp = self.get_interpreter([main_code, callee_code])

        # default settings: one call_may_force_i and a single int_add
        plain_res = self.meta_interp(interp, [0, 0, 0], enable_opts='')
        assert plain_res == 42
        self.check_resops(call_may_force_i=1, int_add=1, call=0)

        # inline=True: no call operation at all and a second int_add
        inlined_res = self.meta_interp(interp, [0, 0, 0], enable_opts='',
                                       inline=True)
        assert inlined_res == 42
        self.check_resops(call=0, int_add=2, call_may_force_i=0,
                          guard_no_exception=0)

    def test_inline_jitdriver_check(self):
        # The sub-code starts with a JUMP_BACK, but CALL always enters it
        # at position 1, so that JUMP_BACK is never executed.
        main_code = "021"
        callee_code = "100"
        interp = self.get_interpreter([main_code, callee_code])

        result = self.meta_interp(interp, [0, 0, 0], enable_opts='',
                                  inline=True)
        assert result == 42
        # the call is fully inlined, because we jump to subcode[1], thus
        # skipping completely the JUMP_BACK in subcode[0]
        self.check_resops(call=0, call_may_force=0, call_assembler=0)

    def test_guard_failure_in_inlined_function(self):
        # The callee "---i---" sometimes returns early from its "i" opcode
        # (when n % 5 == 1), so a trace that inlined the callee does not
        # always follow the same path through it.
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)
        def f(code, n):
            # opcodes: "-" decrement n, "c" call the sub-code, "i" early
            # return when n % 5 == 1, "l" loop back to pc 0 while n > 0
            pc = 0
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    n = f("---i---", n)
                elif op == "i":
                    if n % 5 == 1:
                        return n
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        # debugging aid: show the expected value computed without the JIT
        print main(100)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == 0

    def test_guard_failure_and_then_exception_in_inlined_function(self):
        # The callee "---ir---" behaves the same on every call while
        # n >= 200; once n < 200 its "i" opcode sets 'flag' and the following
        # "r" opcode raises, which the calling frame catches.
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n', 'flag'],
                                get_printable_location=p)
        def f(code, n):
            # opcodes: "-" decrement n, "c" call the sub-code (returning n
            # if it raises), "i" set flag when n < 200, "r" raise if flag is
            # set, "l" loop back to pc 0 while n > 0
            pc = 0
            flag = False
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    try:
                        n = f("---ir---", n)
                    except Exception:
                        # 'n' still has its value from before the call
                        return n
                elif op == "i":
                    if n < 200:
                        flag = True
                elif op == "r":
                    if flag:
                        raise Exception
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        # debugging aid: show the expected value computed without the JIT
        print main(1000)
        res = self.meta_interp(main, [1000], enable_opts='', inline=True)
        assert res == main(1000)

    def test_exception_in_inlined_function(self):
        # The callee "---i---" raises Exc from its "i" opcode when
        # n % 5 == 1; the caller swallows the exception and carries on with
        # its own, unchanged 'n'.
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)

        class Exc(Exception):
            pass

        def f(code, n):
            # opcodes: "-" decrement n, "c" call the sub-code (ignoring
            # Exc), "i" raise Exc when n % 5 == 1, "l" loop back to pc 0
            # while n > 0
            pc = 0
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    try:
                        n = f("---i---", n)
                    except Exc:
                        # the assignment to 'n' did not happen in this case
                        pass
                elif op == "i":
                    if n % 5 == 1:
                        raise Exc
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == main(100)

    def test_recurse_during_blackholing(self):
        # this passes, if the blackholing shortcut for calls is turned off
        # it fails, it is very delicate in terms of parameters,
        # bridge/loop creation order
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)

        def f(code, n):
            # opcodes: "-" decrement n, "c" conditionally call the sub-code
            # "--", "l" loop back to pc 0 while n > 0
            pc = 0
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    # the recursive call only happens for small n, and only
                    # on some of the iterations
                    if n < 70 and n % 3 == 1:
                        n = f("--", n)
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            # low thresholds, set from inside the interpreted program
            set_param(None, 'threshold', 3)
            set_param(None, 'trace_eagerness', 5)
            return f("c-l", n)
        # reference value computed by running main() directly
        expected = main(100)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == expected

    def check_max_trace_length(self, length):
        # Assert that no loop recorded in get_stats(), and no list of
        # '_debug_suboperations' attached to a guard's descr, is longer
        # than 'length' plus a small slack.
        # The slack is there because we only check once per metainterp
        # bytecode.
        limit = length + 5
        for loop in get_stats().loops:
            operations = loop.operations
            assert len(operations) <= limit
            for op in operations:
                if not op.is_guard():
                    continue
                descr = op.getdescr()
                if hasattr(descr, '_debug_suboperations'):
                    assert len(descr._debug_suboperations) <= limit

    def test_inline_trace_limit(self):
        # With inline=True, the loop body calls a function recursing 'n'
        # levels deep; a small trace_limit is passed and the test checks the
        # resulting trace lengths and abort count.
        myjitdriver = JitDriver(greens=[], reds=['n'])
        def recursive(n):
            # returns n for n >= 0, using n nested calls
            if n > 0:
                return recursive(n - 1) + 1
            return 0
        def loop(n):
            set_param(myjitdriver, "threshold", 10)
            # NOTE(review): 'pc' is assigned but never read in this function
            pc = 0
            while n:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                # recursive(n) == n, so the net effect per iteration is n -= 1
                n = recursive(n)
                n -= 1
            return n
        TRACE_LIMIT = 66
        res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
        assert res == 0
        self.check_max_trace_length(TRACE_LIMIT)
        self.check_enter_count_at_most(10) # maybe
        self.check_aborted_count(6)

    def test_trace_limit_bridge(self):
        # Like test_inline_trace_limit, but the loop body has two
        # data-dependent branches ('n % 5 == 0' and 'n < 50'), and the
        # deeply recursive call only happens on one of them.
        def recursive(n):
            # returns n for n >= 0, using n nested calls
            if n > 0:
                return recursive(n - 1) + 1
            return 0
        myjitdriver = JitDriver(greens=[], reds=['n'])
        def loop(n):
            set_param(None, "threshold", 4)
            set_param(None, "trace_eagerness", 2)
            while n:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                if n % 5 == 0:
                    n -= 1
                if n < 50:
                    # only the second half of the run does the recursion
                    n = recursive(n)
                n -= 1
            return n
        TRACE_LIMIT = 20
        # NOTE(review): 'res' is not checked by this test
        res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
        self.check_max_trace_length(TRACE_LIMIT)
        self.check_aborted_count(8)
        self.check_enter_count_at_most(30)

    def test_trace_limit_with_exception_bug(self):
        # Regression test: the loop body calls an @unroll_safe function
        # that counts 'n' down to 0 and then always raises ValueError; see
        # the comment in the 'except' clause for the bug being checked.
        myjitdriver = JitDriver(greens=[], reds=['n'])
        @unroll_safe
        def do_stuff(n):
            # loops n times, then unconditionally raises
            while n > 0:
                n -= 1
            raise ValueError
        def loop(n):
            # NOTE(review): 'pc' is assigned but never read in this function
            pc = 0
            while n > 80:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                try:
                    do_stuff(n)
                except ValueError:
                    # the trace limit is checked when we arrive here, and we
                    # have the exception still in last_exc_value_box at this
                    # point -- so when we abort because of a trace too long,
                    # the exception is passed to the blackhole interp and
                    # incorrectly re-raised from here
                    pass
                n -= 1
            return n
        TRACE_LIMIT = 66
        res = self.meta_interp(loop, [100], trace_limit=TRACE_LIMIT)
        # the loop stops as soon as n is no longer > 80
        assert res == 80

    def test_max_failure_args(self):
        # The loop allocates, on every iteration, an object with ten integer
        # fields; meta_interp() is run with failargs_limit=10 and the test
        # expects a fixed number of aborted traces.
        FAILARGS_LIMIT = 10
        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])

        class A(object):
            # plain record with ten fields, i0..i9
            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
                self.i0 = i0
                self.i1 = i1
                self.i2 = i2
                self.i3 = i3
                self.i4 = i4
                self.i5 = i5
                self.i6 = i6
                self.i7 = i7
                self.i8 = i8
                self.i9 = i9

        def loop(n):
            i = 0
            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
            while i < n:
                jitdriver.can_enter_jit(o=o, i=i, n=n)
                jitdriver.jit_merge_point(o=o, i=i, n=n)
                # a fresh instance replaces 'o' on every iteration
                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                      i + 6, i + 7, i + 8, i + 9)
                i += 1
            return o

        # NOTE(review): 'res' is not checked by this test
        res = self.meta_interp(loop, [20], failargs_limit=FAILARGS_LIMIT,
                               listops=True)
        self.check_aborted_count(4)

    def test_max_failure_args_exc(self):
        # Variant of test_max_failure_args in which loop() always finishes
        # by raising ValueError, caught by the entry point main().
        FAILARGS_LIMIT = 10
        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])

        class A(object):
            # plain record with ten fields, i0..i9
            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
                self.i0 = i0
                self.i1 = i1
                self.i2 = i2
                self.i3 = i3
                self.i4 = i4
                self.i5 = i5
                self.i6 = i6
                self.i7 = i7
                self.i8 = i8
                self.i9 = i9

        def loop(n):
            i = 0
            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
            while i < n:
                jitdriver.can_enter_jit(o=o, i=i, n=n)
                jitdriver.jit_merge_point(o=o, i=i, n=n)
                # a fresh instance replaces 'o' on every iteration
                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                      i + 6, i + 7, i + 8, i + 9)
                i += 1
            # never returns normally
            raise ValueError

        def main(n):
            # returns 0 when loop() raised ValueError, 1 otherwise
            try:
                loop(n)
                return 1
            except ValueError:
                return 0

        res = self.meta_interp(main, [20], failargs_limit=FAILARGS_LIMIT,
                               listops=True)
        # i.e. the ValueError was raised and caught
        assert not res
        self.check_aborted_count(4)

    def test_set_param_inlining(self):
        # The 'inlining' parameter is switched on or off with set_param()
        # from inside the interpreted program; the test compares the
        # resulting operations for both settings.
        myjitdriver = JitDriver(greens=[], reds=['n', 'recurse'])
        def loop(n, recurse=False):
            while n:
                myjitdriver.jit_merge_point(n=n, recurse=recurse)
                n -= 1
                if not recurse:
                    # the outer activation runs a 10-iteration inner
                    # activation on each of its own iterations
                    loop(10, True)
                    myjitdriver.can_enter_jit(n=n, recurse=recurse)
            return n
        TRACE_LIMIT = 66

        def main(inline):
            set_param(None, "threshold", 10)
            set_param(None, 'function_threshold', 60)
            if inline:
                set_param(None, 'inlining', True)
            else:
                set_param(None, 'inlining', False)
            return loop(100)

        # inlining off: one call_may_force_i is expected
        res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
        self.check_resops(call=0, call_may_force_i=1)

        # inlining on: no call_may_force is expected
        res = self.meta_interp(main, [1], enable_opts='', trace_limit=TRACE_LIMIT)
        self.check_resops(call=0, call_may_force=0)

    def test_trace_from_start(self):
        # Run a small interpreter whose main code '+-cl--' calls the
        # sub-code '---' from inside its loop, and check that the jitted run
        # computes the same sum as the plain Python run.
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "'%s' at %d: %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)

        def f(code, n):
            # opcodes: "+" add 7 to n, "-" decrement n, "c" call the
            # sub-code, "l" loop back to pc 1 while n > 0
            pc = 0
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "+":
                    n += 7
                elif op == "-":
                    n -= 1
                elif op == "c":
                    n = f('---', n)
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=1)
                        pc = 1
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def g(m):
            # never taken for the small 'm' used here
            # NOTE(review): presumably present so that f() is also seen
            # being called with other constant arguments -- confirm
            if m > 1000000:
                f('', 0)
            result = 0
            for i in range(m):
                result += f('+-cl--', i)
            # return the sum: without this, g() returned None and the
            # comparison below compared None with None
            return result
        res = self.meta_interp(g, [50], backendopt=True)
        assert res == g(50)
        py.test.skip("tracing from start is by now only longer enabled "
                     "if a trace gets too big")
        # not reached: py.test.skip() above raises; the checks below are
        # kept for reference only
        self.check_tree_loop_count(3)
        self.check_history(int_add=1)

    def test_dont_inline_huge_stuff(self):
        # The callee is a 20-opcode code string and the trace limit is set
        # low on purpose (see the comment in g()); the test expects one
        # aborted trace and two call_assembler_i operations.
        def p(pc, code):
            # printable location; 'code' is converted with hlstr() first
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p,
                                is_recursive=True)

        def f(code, n):
            # opcodes: "-" decrement n, "c" call the sub-code (its result
            # is discarded), "l" loop back to pc 0 while n > 0
            pc = 0
            while pc < len(code):

                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    f('--------------------', n)
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def g(m):
            set_param(None, 'inlining', True)
            # carefully chosen threshold to make sure that the inner function
            # cannot be inlined, but the inner function on its own is small
            # enough
            set_param(None, 'trace_limit', 40)
            # never taken for the small 'm' used here
            if m > 1000000:
                f('', 0)
            # 'result' is accumulated but not returned; this test only
            # checks the counters below, not a return value
            result = 0
            for i in range(m):
                result += f('-c-----------l-', i+100)
        self.meta_interp(g, [10], backendopt=True)
        self.check_aborted_count(1)
        self.check_resops(call=0, call_assembler_i=2)
        self.check_jitcell_token_count(2)

    def test_directly_call_assembler(self):
        # portal(2) calls portal(1) on each of its iterations; the test
        # expects this recursive portal call to show up as exactly one
        # call_assembler_n in the history.
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno = codeno, i = i)
                driver.jit_merge_point(codeno = codeno, i = i)
                if codeno == 2:
                    # recursive portal call with a different green key
                    portal(1)
                i += 1

        self.meta_interp(portal, [2], inline=True)
        self.check_history(call_assembler_n=1)

    def test_recursion_cant_call_assembler_directly(self):
        # The portal recurses into itself with the same 'codeno'.  The test
        # wraps compile.compile_tmp_callback to record the loop tokens it
        # returns, and checks that exactly one was made and (on backends
        # exposing '_redirected_call_assembler') later redirected.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno, j):
            i = 1
            while 1:
                driver.jit_merge_point(codeno=codeno, i=i, j=j)
                # true for i == 2 and i == 3
                if (i >> 1) == 1:
                    if j == 0:
                        return
                    # 'j' bounds the recursion depth
                    portal(2, j - 1)
                elif i == 5:
                    return
                i += 1
                driver.can_enter_jit(codeno=codeno, i=i, j=j)

        # sanity check: the function terminates when run directly
        portal(2, 5)

        from rpython.jit.metainterp import compile, pyjitpl
        # reset, so that the value read back below comes from this run
        pyjitpl._warmrunnerdesc = None
        trace = []
        def my_ctc(*args):
            # recording wrapper around the original compile_tmp_callback
            looptoken = original_ctc(*args)
            trace.append(looptoken)
            return looptoken
        original_ctc = compile.compile_tmp_callback
        try:
            compile.compile_tmp_callback = my_ctc
            self.meta_interp(portal, [2, 5], inline=True)
            self.check_resops(call_may_force=0, call_assembler_n=2)
        finally:
            # always undo the monkey-patching
            compile.compile_tmp_callback = original_ctc
        # check that we made a temporary callback
        assert len(trace) == 1
        # and that we later redirected it to something else
        try:
            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
        except AttributeError:
            pass    # not the llgraph backend
        else:
            print redirected
            assert redirected.keys() == trace

    def test_recursion_cant_call_assembler_directly_with_virtualizable(self):
        # exactly the same logic as the previous test, but with 'frame.j'
        # instead of just 'j'
        class Frame(object):
            # virtualizable frame carrying the recursion budget 'j'
            _virtualizable_ = ['j']
            def __init__(self, j):
                self.j = j

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno, frame):
            i = 1
            while 1:
                driver.jit_merge_point(codeno=codeno, i=i, frame=frame)
                # true for i == 2 and i == 3
                if (i >> 1) == 1:
                    if frame.j == 0:
                        return
                    # recurse with a fresh frame holding a smaller budget
                    portal(2, Frame(frame.j - 1))
                elif i == 5:
                    return
                i += 1
                driver.can_enter_jit(codeno=codeno, i=i, frame=frame)

        def main(codeno, j):
            # entry point taking plain integers; builds the initial frame
            portal(codeno, Frame(j))

        # sanity check: the function terminates when run directly
        main(2, 5)

        from rpython.jit.metainterp import compile, pyjitpl
        # reset, so that the value read back below comes from this run
        pyjitpl._warmrunnerdesc = None
        trace = []
        def my_ctc(*args):
            # recording wrapper around the original compile_tmp_callback
            looptoken = original_ctc(*args)
            trace.append(looptoken)
            return looptoken
        original_ctc = compile.compile_tmp_callback
        try:
            compile.compile_tmp_callback = my_ctc
            self.meta_interp(main, [2, 5], inline=True)
            self.check_resops(call_may_force=0, call_assembler_n=2)
        finally:
            # always undo the monkey-patching
            compile.compile_tmp_callback = original_ctc
        # check that we made a temporary callback
        assert len(trace) == 1
        # and that we later redirected it to something else
        try:
            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
        except AttributeError:
            pass    # not the llgraph backend
        else:
            print redirected
            assert redirected.keys() == trace

    def test_directly_call_assembler_return(self):
        # Like test_directly_call_assembler, but the recursive portal call
        # returns an integer that the caller uses; the test expects one
        # call_assembler_i in the history.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno):
            i = 0
            k = codeno
            while i < 10:
                driver.can_enter_jit(codeno = codeno, i = i, k = k)
                driver.jit_merge_point(codeno = codeno, i = i, k = k)
                if codeno == 2:
                    # the result of the inner portal is kept in 'k'
                    k = portal(1)
                i += 1
            return k

        self.meta_interp(portal, [2], inline=True)
        self.check_history(call_assembler_i=1)

    def test_directly_call_assembler_raise(self):
        # Like test_directly_call_assembler, but the inner portal(1) always
        # finishes by raising MyException, which the outer portal(2)
        # catches and reads a field from.

        class MyException(Exception):
            def __init__(self, x):
                self.x = x

        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno = codeno, i = i)
                driver.jit_merge_point(codeno = codeno, i = i)
                if codeno == 2:
                    try:
                        portal(1)
                    except MyException as me:
                        # the payload of the exception advances the counter
                        i += me.x
                i += 1
            if codeno == 1:
                # raised after the loop of the inner portal has finished
                raise MyException(1)

        self.meta_interp(portal, [2], inline=True)
        self.check_history(call_assembler_n=1)

    def test_directly_call_assembler_fail_guard(self):
        # The inner portal(1, k) changes its control flow once k > 40, and
        # then alternates between two branches depending on the parity of
        # 'i'.  The final result is compared with a hard-coded value.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno, k):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno=codeno, i=i, k=k)
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno == 2:
                    # outer portal: accumulate the inner portal's result
                    k += portal(1, k)
                elif k > 40:
                    # inner portal, large k: extra increment by 1 or 2
                    if i % 2:
                        k += 1
                    else:
                        k += 2
                k += 1
                i += 1
            return k

        res = self.meta_interp(portal, [2, 0], inline=True)
        assert res == 13542

    def test_directly_call_assembler_virtualizable(self):
        # Recursive portal call where each activation has its own
        # virtualizable Frame; the caller reads the callee's frame after
        # the call returns.
        class Thing(object):
            # simple box around an integer
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 's', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))

        def main(codeno):
            # entry point: builds the outermost frame
            frame = Frame()
            frame.thing = Thing(0)
            result = portal(codeno, frame)
            return result

        def portal(codeno, frame):
            i = 0
            s = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i, s=s)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i, s=s)
                nextval = frame.thing.val
                if codeno == 0:
                    # outer portal: run the inner portal on a fresh frame,
                    # then read that frame's final 'thing'
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                    s += subframe.thing.val
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val + s

        res = self.meta_interp(main, [0], inline=True)
        self.check_resops(call=0, cond_call=2)
        assert res == main(0)

    def test_directly_call_assembler_virtualizable_reset_token(self):
        # Currently disabled: the skip below is unconditional, so nothing
        # after it runs.  The body checks that a subframe's 'vable_token' is
        # TOKEN_NONE after the recursive portal call returned.
        py.test.skip("not applicable any more, I think")
        from rpython.rtyper.lltypesystem import lltype
        from rpython.rlib.debug import llinterpcall

        class Thing(object):
            # simple box around an integer
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))

        @dont_look_inside
        def check_frame(subframe):
            # only does something when translated: hands the frame over to
            # check_ll_frame() through llinterpcall()
            if we_are_translated():
                llinterpcall(lltype.Void, check_ll_frame, subframe)
        def check_ll_frame(ll_subframe):
            # This is called with the low-level Struct that is the frame.
            # Check that the vable_token was correctly reset to zero.
            # Note that in order for that test to catch failures, it needs
            # three levels of recursion: the vable_token of the subframe
            # at the level 2 is set to a non-zero value when doing the
            # call to the level 3 only.  This used to fail when the test
            # is run via rpython.jit.backend.x86.test.test_recursive.
            from rpython.jit.metainterp.virtualizable import TOKEN_NONE
            assert ll_subframe.vable_token == TOKEN_NONE

        def main(codeno):
            # entry point: builds the outermost frame
            frame = Frame()
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            i = 0
            while i < 5:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                # codeno 0 and 1 recurse one level deeper: three levels total
                if codeno < 2:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(codeno + 1, subframe)
                    check_frame(subframe)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True)
        assert res == main(0)

    def test_directly_call_assembler_virtualizable_force1(self):
        # While the inner portal runs, the helper change() writes to the
        # *outer* frame's virtualizable field through a global reference.
        # change() is passed to StopAtXPolicy.
        class Thing(object):
            # simple box around an integer
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass

        # prebuilt holder giving change() access to the outermost frame
        somewhere_else = SomewhereElse()

        def change(newthing):
            # mutates the frame stored by main(), not the current subframe
            somewhere_else.frame.thing = newthing

        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            print 'ENTER:', codeno, frame.thing.val
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    # outer portal: run the inner portal on a fresh frame
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                elif codeno == 1:
                    # inner portal: once the value is large enough, modify
                    # the outer frame behind the caller's back
                    if frame.thing.val > 40:
                        change(Thing(13))
                        nextval = 13
                else:
                    fatalerror("bad codeno = " + str(codeno))
                frame.thing = Thing(nextval + 1)
                i += 1
            print 'LEAVE:', codeno, frame.thing.val
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)

    def test_directly_call_assembler_virtualizable_with_array(self):
        # Recursive portal whose virtualizable frame declares both an array
        # field ('l[*]') and a plain field ('s').  The result obtained
        # through meta_interp must equal a direct run of main(0, 10, 1).
        myjitdriver = JitDriver(greens = ['codeno'], reds = ['n', 'x', 'frame'],
                                virtualizables = ['frame'])

        class Frame(object):
            _virtualizable_ = ['l[*]', 's']

            def __init__(self, l, s):
                self = hint(self, access_directly=True,
                            fresh_virtualizable=True)
                self.l = l
                self.s = s

        def main(codeno, n, a):
            frame = Frame([a, a+1, a+2, a+3], 0)
            return f(codeno, n, a, frame)

        def f(codeno, n, a, frame):
            # 'a' is accepted but not used inside this function
            x = 0
            while n > 0:
                myjitdriver.can_enter_jit(codeno=codeno, frame=frame, n=n, x=x)
                myjitdriver.jit_merge_point(codeno=codeno, frame=frame, n=n,
                                            x=x)
                frame.s = promote(frame.s)
                n -= 1
                # frame.s is 0 here (every Frame is created with s == 0 and
                # each iteration restores it at the end)
                s = frame.s
                # presumably tells the annotator that the index is
                # non-negative -- TODO confirm
                assert s >= 0
                x += frame.l[s]
                frame.s += 1
                if codeno == 0:
                    # outer level only: recurse as codeno 1 on a new frame
                    subframe = Frame([n, n+1, n+2, n+3], 0)
                    x += f(1, 10, 1, subframe)
                # second read of the array, now at index frame.s == 1
                s = frame.s
                assert s >= 0
                x += frame.l[s]
                x += len(frame.l)
                # restore frame.s for the next iteration
                frame.s -= 1
            return x

        res = self.meta_interp(main, [0, 10, 1], listops=True, inline=True)
        assert res == main(0, 10, 1)

    def test_directly_call_assembler_virtualizable_force_blackhole(self):
        # Same overall setup as a recursive portal with a virtualizable
        # Frame, except that the inner level calls change() on every
        # iteration and uses its return value; change() only touches the
        # frame when its argument exceeds 30.  The value computed through
        # meta_interp must equal a direct run of main(0).
        class Thing(object):
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass

        # holder that lets change() reach a frame without receiving it as an
        # argument
        somewhere_else = SomewhereElse()

        def change(newthing, arg):
            # Returns arg unchanged while arg <= 30.  Above 30 it replaces
            # the 'thing' of the frame registered by main() and returns 13.
            print arg
            if arg > 30:
                somewhere_else.frame.thing = newthing
                arg = 13
            return arg

        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            # the return value of portal() is ignored; the frame is re-read
            return frame.thing.val

        def portal(codeno, frame):
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    # outer level: run the inner portal on a fresh subframe
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                else:
                    # inner level: call the function given to StopAtXPolicy
                    nextval = change(Thing(13), frame.thing.val)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)

    def test_assembler_call_red_args(self):
        # Recursive portal call whose green argument (codeno) is computed at
        # run time by residual(), the function handed to StopAtXPolicy.
        # Checks the result against a direct run and expects exactly four
        # call_assembler_i operations.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def residual(k):
            # selects the codeno of the recursive call: 1 while k <= 150,
            # 0 afterwards
            if k > 150:
                return 0
            return 1

        def portal(codeno, k):
            i = 0
            while i < 15:
                driver.can_enter_jit(codeno=codeno, i=i, k=k)
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno == 2:
                    # only the top-level portal (codeno 2) recurses
                    k += portal(residual(k), k)
                if codeno == 0:
                    k += 2
                elif codeno == 1:
                    k += 1
                i += 1
            return k

        res = self.meta_interp(portal, [2, 0], inline=True,
                               policy=StopAtXPolicy(residual))
        assert res == portal(2, 0)
        self.check_resops(call_assembler_i=4)

    def test_inline_without_hitting_the_loop(self):
        # The recursive call portal(20) returns right after its merge point
        # without ever reaching can_enter_jit.  The test expects the result
        # 70 and no call_assembler operation at all.
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno):
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno < 10:
                    # portal(20) takes the final 'else' branch and returns 1
                    i += portal(20)
                    codeno += 1
                elif codeno == 10:
                    # i has grown by 10 since the previous pass through here
                    if i > 63:
                        return i
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i)
                else:
                    return 1

        assert portal(0) == 70
        res = self.meta_interp(portal, [0], inline=True)
        assert res == 70
        self.check_resops(call_assembler=0)

    def test_inline_with_hitting_the_loop_sometimes(self):
        # Variant in which some recursive calls do loop (codeno + 5 <= 10)
        # while others return immediately.  'k' is the recursion depth.
        # The test expects the result 2095 and twelve call_assembler_i
        # operations.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno, k):
            # recursion is cut off once the depth exceeds 2
            if k > 2:
                return 1
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    i += portal(codeno + 5, k+1)
                    codeno += 1
                elif codeno == 10:
                    # exit threshold depends on the depth: 2000 for k == 1,
                    # 63 for k == 2 (index 0 is not reached from this test,
                    # which starts at k == 1)
                    if i > [-1, 2000, 63][k]:
                        return i
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    return 1

        assert portal(0, 1) == 2095
        res = self.meta_interp(portal, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler_i=12)

    def test_inline_with_hitting_the_loop_sometimes_exc(self):
        # Exception-based variant: portal() never returns a value, every
        # result is delivered by raising GotValue.  The test expects the
        # result 2095 and twelve call_assembler_n operations.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        class GotValue(Exception):
            # carries the "return value" of portal() in .result
            def __init__(self, result):
                self.result = result

        def portal(codeno, k):
            # recursion is cut off once the depth exceeds 2
            if k > 2:
                raise GotValue(1)
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    try:
                        portal(codeno + 5, k+1)
                    except GotValue as e:
                        i += e.result
                    codeno += 1
                elif codeno == 10:
                    # exit threshold depends on the depth: 2000 for k == 1,
                    # 63 for k == 2
                    if i > [-1, 2000, 63][k]:
                        raise GotValue(i)
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    raise GotValue(1)

        def main(codeno, k):
            # converts the final GotValue back into a plain return value
            try:
                portal(codeno, k)
            except GotValue as e:
                return e.result

        assert main(0, 1) == 2095
        res = self.meta_interp(main, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler_n=12)

    def test_inline_recursion_limit(self):
        # Runs the same recursive portal with threshold 10 and threshold 9
        # while 'max_unroll_recursion' is set to 10.  The first run is
        # expected to contain two call_assembler_i operations, the second
        # run none.
        driver = JitDriver(greens = ["threshold", "loop"], reds=["i"])
        @dont_look_inside
        def f():
            # sets the parameter from inside the interpreted program, on
            # every call of portal()
            set_param(driver, "max_unroll_recursion", 10)
        def portal(threshold, loop, i):
            f()
            if i > threshold:
                return i
            while True:
                driver.jit_merge_point(threshold=threshold, loop=loop, i=i)
                if loop:
                    portal(threshold, False, 0)
                else:
                    # non-looping variant: recurse with i + 1 (the recursion
                    # ends once i > threshold), then return
                    portal(threshold, False, i + 1)
                    return i
                if i > 10:
                    return 1
                i += 1
                driver.can_enter_jit(threshold=threshold, loop=loop, i=i)

        res1 = portal(10, True, 0)
        res2 = self.meta_interp(portal, [10, True, 0], inline=True)
        assert res1 == res2
        self.check_resops(call_assembler_i=2)

        # presumably one level less of recursion keeps everything within
        # the unroll limit -- TODO confirm against the JIT's counting
        res1 = portal(9, True, 0)
        res2 = self.meta_interp(portal, [9, True, 0], inline=True)
        assert res1 == res2
        self.check_resops(call_assembler=0)

    def test_handle_jitexception_in_portal(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        #
        # can_enter_jit is not called from the portal directly but two
        # helper calls deep (intermediate -> do_can_enter_jit), and only
        # when i == 9.  The test is repeated with three trace limits.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            # NOTE: the parameter 'str' shadows the builtin inside these
            # helpers; it is also the name of a red variable of the driver
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    # last iteration of the outer portal: the inner portal
                    # (codeno 96) appends 'a'..'j' before 'J' is appended
                    str = portal(96, str)
                str += chr(codeno+i)
            return str
        class Value:
            initial = -1
        value = Value()
        def main():
            # overrides the class-level default of -1 before running
            value.initial = 0
            return (portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, ''))
        assert main() == 'ABCDEFGHIabcdefghijJ' * 5
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            res = self.meta_interp(main, [], inline=True, trace_limit=tlimit)
            # the result is a low-level string; rebuild it from .chars
            assert ''.join(res.chars) == 'ABCDEFGHIabcdefghijJ' * 5

    def test_handle_jitexception_in_portal_returns_void(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        #
        # Here portal() has no return statement, so there is no result to
        # compare: the test only checks that main() runs through meta_interp
        # with each of the three trace limits.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            # can_enter_jit is only reached when i == 9
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            # NOTE: the parameter 'str' shadows the builtin inside these
            # helpers; it is also the name of a red variable of the driver
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    # recursive call whose (absent) result is discarded
                    portal(96, str)
                str += chr(codeno+i)
        class Value:
            initial = -1
        value = Value()
        def main():
            # overrides the class-level default of -1 before running
            value.initial = 0
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
        main()
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            self.meta_interp(main, [], inline=True, trace_limit=tlimit)

    def test_no_duplicates_bug(self):
        # Regression test without assertions on the result: it only checks
        # that meta_interp completes.  The recursive call passes the same
        # value 'i' both as the green 'codeno' and as the red 'i'.
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno: str(codeno))
        def portal(codeno, i):
            while i > 0:
                driver.can_enter_jit(codeno=codeno, i=i)
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno > 0:
                    # recursive calls (codeno == i > 0) leave right after
                    # the merge point
                    break
                portal(i, i)
                i -= 1
        self.meta_interp(portal, [0, 10], inline=True)

    def test_trace_from_start_always(self):
        # Runs a recursive portal twice and checks how much gets compiled:
        # with the default parameters exactly one jitcell token and one
        # trace are expected; after set_param(driver, 'function_threshold',
        # 0) none are expected.
        driver = JitDriver(greens = ['c'], reds = ['i', 'v'])

        def portal(c, i, v):
            while i > 0:
                driver.jit_merge_point(c=c, i=i, v=v)
                portal(c, i - 1, v)
                # both runs below pass v=False, so can_enter_jit is never
                # executed at run time
                if v:
                    driver.can_enter_jit(c=c, i=i, v=v)
                break

        def main(c, i, _set_param, v):
            # the parameter is changed from inside the interpreted program,
            # and only when requested
            if _set_param:
                set_param(driver, 'function_threshold', 0)
            portal(c, i, v)

        self.meta_interp(main, [10, 10, False, False], inline=True)
        self.check_jitcell_token_count(1)
        self.check_trace_count(1)
        self.meta_interp(main, [3, 10, True, False], inline=True)
        self.check_jitcell_token_count(0)
        self.check_trace_count(0)

    def test_trace_from_start_does_not_prevent_inlining(self):
        # The recursive call portal(1, 8, 0) returns right after its merge
        # point.  The test expects no call and no call_may_force operation
        # in the resulting traces.
        driver = JitDriver(greens = ['c', 'bc'], reds = ['i'])

        def portal(bc, c, i):
            while True:
                driver.jit_merge_point(c=c, bc=bc, i=i)
                if bc == 0:
                    portal(1, 8, 0)
                    c += 1
                else:
                    # the recursive call (bc == 1) ends here
                    return
                if c == 10: # bc == 0
                    c = 0
                    if i >= 100:
                        return
                    driver.can_enter_jit(c=c, bc=bc, i=i)
                i += 1

        self.meta_interp(portal, [0, 0, 0], inline=True)
        self.check_resops(call_may_force=0, call=0)

    def test_dont_repeatedly_trace_from_the_same_guard(self):
        # Level 0 first loops ten times (i runs from -10 up to 0); every
        # deeper level starts at i == 0 and therefore recurses on its first
        # iteration, until level 25 returns 42.  At most two traces are
        # expected.
        driver = JitDriver(greens = [], reds = ['level', 'i'])

        def portal(level):
            if level == 0:
                i = -10
            else:
                i = 0
            #
            while True:
                driver.jit_merge_point(level=level, i=i)
                if level == 25:
                    return 42
                i += 1
                if i <= 0:      # <- guard
                    continue    # first make a loop
                else:
                    # then we fail the guard above, doing a recursive call,
                    # which will itself fail the same guard above, and so on
                    return portal(level + 1)

        self.meta_interp(portal, [0])
        self.check_trace_count_at_most(2)   # and not, e.g., 24

    def test_get_unique_id(self):
        # Temporarily wraps codemap.CodemapStorage.register_codemap to
        # record the (start, size) of every registered codemap, runs a
        # recursive driver that defines get_unique_id, and passes the
        # recorded list to the check_get_unique_id() hook.
        lst = []

        def reg_codemap(self, codemap_entry):
            # explicit unpacking rather than a tuple parameter, which is
            # Python-2-only syntax (removed by PEP 3113); the third item is
            # not needed for the record
            start, size, _ = codemap_entry
            lst.append((start, size))
            old_reg_codemap(self, codemap_entry)

        old_reg_codemap = codemap.CodemapStorage.register_codemap
        try:
            codemap.CodemapStorage.register_codemap = reg_codemap
            def get_unique_id(pc, code):
                # yields 4 for code == 1 and 2 for code == 0
                return (code + 1) * 2

            driver = JitDriver(greens=["pc", "code"], reds='auto',
                               get_unique_id=get_unique_id, is_recursive=True)

            def f(pc, code):
                i = 0
                while i < 10:
                    driver.jit_merge_point(pc=pc, code=code)
                    pc += 1
                    if pc == 3:
                        # only the outer level (code == 1) recurses
                        if code == 1:
                            f(0, 0)
                        pc = 0
                    i += 1

            self.meta_interp(f, [0, 1], inline=True)
            self.check_get_unique_id(lst) # overloaded on assembler backends
        finally:
            # always undo the monkeypatch, even if the test fails
            codemap.CodemapStorage.register_codemap = old_reg_codemap

    def check_get_unique_id(self, lst):
        """No-op hook: receives the (start, size) pairs recorded by
        test_get_unique_id; subclasses may override it to verify them."""
        return None

class TestLLtype(RecursiveTests, LLJitMixin):
    """Concrete test class: the RecursiveTests cases combined with
    LLJitMixin; adds nothing of its own."""
