From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com>
Origin: https://github.com/numba/numba/pull/8545
Date: Wed, 28 Sep 2022 17:18:36 -0500
Subject: Support with-lifting

test_withlifting tests report errors=31, skipped=2, expected failures=2 out of 68 tests.
A lot of the tests are failing due to cloudpickle incompatibility.
---
 numba/core/bytecode.py    | 21 +++++++++++++---
 numba/core/byteflow.py    | 58 ++++++++++++++++++++++++++++++++++----------
 numba/core/interpreter.py | 62 +++++++++++++++++++++++++++++++++++++++++------
 numba/core/transforms.py  |  4 +--
 4 files changed, 117 insertions(+), 28 deletions(-)

diff --git a/numba/core/bytecode.py b/numba/core/bytecode.py
index 1ccafce..f2882eb 100644
--- a/numba/core/bytecode.py
+++ b/numba/core/bytecode.py
@@ -226,7 +226,17 @@ class ByteCode(object):
         self.co_consts = code.co_consts
         self.co_cellvars = code.co_cellvars
         self.co_freevars = code.co_freevars
-        self.exception_entries = dis.Bytecode(code).exception_entries
+
+        def fixup_eh(ent):
+            from dis import _ExceptionTableEntry
+            out = _ExceptionTableEntry(
+                start=ent.start + 2, end=ent.end + 2, target=ent.target + 2,
+                depth=ent.depth, lasti=ent.lasti,
+            )
+            return out
+
+        entries = dis.Bytecode(code).exception_entries
+        self.exception_entries = tuple(map(fixup_eh, entries))
         self.table = table
         self.labels = sorted(labels)
 
@@ -306,13 +316,16 @@ class ByteCode(object):
         return self._compute_used_globals(self.func_id.func, self.table,
                                           self.co_consts, self.co_names)
 
-    def get_exception_entry(self, offset):
+    def find_exception_entry(self, offset):
         """
         Returns the exception entry for the given instruction offset
         """
+        candidates = []
         for ent in self.exception_entries:
-            if offset in range(ent.start, ent.end):
-                return ent
+            if ent.start <= offset <= ent.end:
+                candidates.append((ent.depth, ent))
+        ent = max(candidates)[1]
+        return ent
 
 class FunctionIdentity(serialize.ReduceMixin):
     """
diff --git a/numba/core/byteflow.py b/numba/core/byteflow.py
index f5092f5..8bd9dd1 100644
--- a/numba/core/byteflow.py
+++ b/numba/core/byteflow.py
@@ -10,11 +10,11 @@ from functools import total_ordering
 from numba.core.utils import UniqueDict, PYVERSION
 from numba.core.controlflow import NEW_BLOCKERS, CFGraph
 from numba.core.ir import Loc
-from numba.core.errors import UnsupportedError
+from numba.core.errors import UnsupportedError, CompilerError
 
 
 _logger = logging.getLogger(__name__)
-
+# logging.basicConfig(level=logging.DEBUG)
 
 _EXCEPT_STACK_OFFSET = 6
 _FINALLY_POP = _EXCEPT_STACK_OFFSET if PYVERSION >= (3, 8) else 1
@@ -254,7 +254,7 @@ class Flow(object):
         than a POP_TOP, if it is something else it'll be some sort of store
         which is not supported (this corresponds to `with CTXMGR as VAR(S)`)."""
         current_inst = state.get_inst()
-        if current_inst.opname == "SETUP_WITH":
+        if current_inst.opname in {"SETUP_WITH", "BEFORE_WITH"}:
             next_op = self._bytecode[current_inst.next].opname
             if next_op != "POP_TOP":
                 msg = ("The 'with (context manager) as "
@@ -263,6 +263,10 @@ class Flow(object):
                 raise UnsupportedError(msg)
 
 
+def _is_null_temp_reg(reg):
+    return reg.startswith("$null$")
+
+
 class TraceRunner(object):
     """Trace runner contains the states for the trace and the opcode dispatch.
     """
@@ -275,9 +279,18 @@ class TraceRunner(object):
         return Loc(self.debug_filename, lineno)
 
     def dispatch(self, state):
+        if PYVERSION == (3, 11) and state._blockstack:
+            state: State
+            while state._blockstack:
+                topblk = state._blockstack[-1]
+                if topblk['end'] <= state.pc_initial:
+                    state._blockstack.pop()
+                else:
+                    break
         inst = state.get_inst()
-        _logger.debug("dispatch pc=%s, inst=%s", state._pc, inst)
-        _logger.debug("stack %s", state._stack)
+        if inst.opname != "CACHE":
+            _logger.debug("dispatch pc=%s, inst=%s", state._pc, inst)
+            _logger.debug("stack %s", state._stack)
         fn = getattr(self, "op_{}".format(inst.opname), None)
         if fn is not None:
             fn(state, inst)
@@ -298,6 +311,7 @@ class TraceRunner(object):
         state.append(inst)
 
     def op_PUSH_NULL(self, state, inst):
+        state.push(state.make_null())
         state.append(inst)
 
     def op_RETURN_GENERATOR(self, state, inst):
@@ -355,9 +369,9 @@ class TraceRunner(object):
             res = state.make_temp()
             idx = inst.arg >> 1
             state.append(inst, idx=idx, res=res)
-            ## ignoring the NULL
-            # if inst.arg & 1:
-                # state.push(state.make_temp())
+            # push a NULL placeholder when the low bit of inst.arg is set
+            if inst.arg & 1:
+                state.push(state.make_null())
             state.push(res)
     else:
         def op_LOAD_GLOBAL(self, state, inst):
@@ -769,13 +783,20 @@ class TraceRunner(object):
 
         yielded = state.make_temp()
         exitfn = state.make_temp(prefix='setup_with_exitfn')
-        state.append(inst, contextmanager=cm, exitfn=exitfn)
 
         state.push(exitfn)
         state.push(yielded)
 
-        end = state._bytecode.get_exception_entry(inst.next).target
-
+        # Gather all exception entries for this WITH. There may be multiple
+        # entries, especially for nested WITHs.
+        bc = state._bytecode
+        ehhead = bc.find_exception_entry(inst.next)
+        ehrelated = [ehhead]
+        for eh in bc.exception_entries:
+            if eh.target == ehhead.target:
+                ehrelated.append(eh)
+        end = max(eh.end for eh in ehrelated)
+        state.append(inst, contextmanager=cm, exitfn=exitfn, end=end)
 
         state.push_block(
             state.make_block(
@@ -889,7 +910,10 @@ class TraceRunner(object):
         narg = inst.arg
         args = list(reversed([state.pop() for _ in range(narg)]))
         func = state.pop()
-
+        extra = state.pop()  # XXX need to see if it's NULL
+        if not _is_null_temp_reg(extra):
+            func = extra
+            args = [extra, *args]
         res = state.make_temp()
 
         kw_names = state.pop_kw_names()
@@ -1290,7 +1314,12 @@ class TraceRunner(object):
     # of LOAD_METHOD and CALL_METHOD.
 
     def op_LOAD_METHOD(self, state, inst):
-        self.op_LOAD_ATTR(state, inst)
+        item = state.pop()
+        extra = state.make_null()
+        state.push(extra)
+        res = state.make_temp()
+        state.append(inst, item=item, res=res)
+        state.push(res)
 
     def op_CALL_METHOD(self, state, inst):
         self.op_CALL_FUNCTION(state, inst)
@@ -1433,6 +1462,9 @@ class State(object):
         self._temp_registers.append(name)
         return name
 
+    def make_null(self):
+        return self.make_temp(prefix="null$")
+
     def append(self, inst, **kwargs):
         """Append new inst"""
         self._insts.append((inst.offset, kwargs))
diff --git a/numba/core/interpreter.py b/numba/core/interpreter.py
index 2db6c84..ec590b7 100644
--- a/numba/core/interpreter.py
+++ b/numba/core/interpreter.py
@@ -1224,6 +1224,44 @@ def peep_hole_fuse_dict_add_updates(func_ir):
     return func_ir
 
 
+def peep_hole_split_at_pop_block(func_ir):
+    """
+    Split blocks that contain ir.PopBlock
+    """
+    newblocks = {}
+    for label, blk in func_ir.blocks.items():
+        for i, inst in enumerate(blk.body):
+            if isinstance(inst, ir.PopBlock):
+                head = blk.body[:i]
+                mid = blk.body[i:i + 1]
+                tail = blk.body[i + 1:]
+                if head:
+                    blk.body.clear()
+                    blk.body.extend(head)
+
+                    midblk = ir.Block(blk.scope, loc=blk.loc)
+                    midblk.body.extend(mid)
+
+                    midlabel = label + i * 2
+                    newblocks[midlabel] = midblk
+
+                    blk.body.append(ir.Jump(midlabel, loc=blk.loc))
+                else:
+                    blk.body.clear()
+                    blk.body.extend(mid)
+                    midblk = blk
+
+                tailblk = ir.Block(blk.scope, loc=blk.loc)
+                tailblk.body.extend(tail)
+                taillabel = label + (i + 1) * 2
+                newblocks[taillabel] = tailblk
+
+                midblk.append(ir.Jump(taillabel, loc=blk.loc))
+
+    func_ir.blocks.update(newblocks)
+    return func_ir
+
+
 def _build_new_build_map(func_ir, name, old_body, old_lineno, new_items):
     """
     Create a new build_map with a new set of key/value items
@@ -1344,6 +1382,8 @@ class Interpreter(object):
         # post process the IR to rewrite opcodes/byte sequences that are too
         # involved to risk handling as part of direct interpretation
         peepholes = []
+        if PYVERSION == (3, 11):
+            peepholes.append(peep_hole_split_at_pop_block)
         if PYVERSION in [(3, 9), (3, 10)]:
             peepholes.append(peep_hole_list_to_tuple)
         peepholes.append(peep_hole_delete_with_exit)
@@ -1434,7 +1474,9 @@ class Interpreter(object):
         # Check out-of-scope syntactic-block
         while self.syntax_blocks:
             if offset >= self.syntax_blocks[-1].exit:
-                self.syntax_blocks.pop()
+                synblk = self.syntax_blocks.pop()
+                if isinstance(synblk, ir.With):
+                    self.current_block.append(ir.PopBlock(self.loc))
             else:
                 break
 
@@ -1637,6 +1679,13 @@ class Interpreter(object):
 
     def _dispatch(self, inst, kws):
         assert self.current_block is not None
+        if self.syntax_blocks:
+            top = self.syntax_blocks[-1]
+            if isinstance(top, ir.With) :
+                if inst.offset >= top.exit:
+                    self.current_block.append(ir.PopBlock(loc=self.loc))
+                    self.syntax_blocks.pop()
+
         fname = "op_%s" % inst.opname.replace('+', '_')
         try:
             fn = getattr(self, fname)
@@ -2136,23 +2185,20 @@ class Interpreter(object):
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
-    def op_BEFORE_WITH(self, inst, contextmanager, exitfn=None):
+    def op_BEFORE_WITH(self, inst, contextmanager, exitfn=None, end=None):
         assert self.blocks[inst.offset] is self.current_block
-        # use EH entry to determine the end of the with
-        exitpt = self.bytecode.get_exception_entry(inst.next).target
         # Handle with
-        wth = ir.With(inst.offset, exit=exitpt)
+        wth = ir.With(inst.offset, exit=end)
         self.syntax_blocks.append(wth)
         ctxmgr = self.get(contextmanager)
         self.current_block.append(ir.EnterWith(contextmanager=ctxmgr,
                                                begin=inst.offset,
-                                               end=exitpt, loc=self.loc,))
+                                               end=end, loc=self.loc,))
 
-        # Store exit fn
+        # Store exit fn
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
-
     def op_SETUP_EXCEPT(self, inst):
         # Removed since python3.8
         self._insert_try_block_begin()
diff --git a/numba/core/transforms.py b/numba/core/transforms.py
index 1169c7d..e5398cd 100644
--- a/numba/core/transforms.py
+++ b/numba/core/transforms.py
@@ -588,7 +588,6 @@ def find_setupwiths(func_ir):
     # rewrite the CFG in case there are multiple POP_BLOCK statements for one
     # with
     func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
-
     # here we need to turn the withs back into a list of tuples so that the
     # rest of the code can cope
     with_ranges_tuple = [(s, list(p)[0])
@@ -619,7 +618,7 @@ def find_setupwiths(func_ir):
     # now we need to rewrite the tuple such that we have SETUP_WITH matching the
     # successor of the block that contains the POP_BLOCK.
     with_ranges_tuple = [(s, func_ir.blocks[p].terminator.get_targets()[0])
-             for (s, p) in with_ranges_tuple]
+                         for (s, p) in with_ranges_tuple]
 
     # finally we check for nested with statements and reject them
     with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)
@@ -754,7 +753,6 @@ def _eliminate_nested_withs(with_ranges):
 def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
     """Modify the FunctionIR to merge the exit blocks of with constructs.
     """
-    out = []
     for k in withs:
         vs : set = withs[k]
         if len(vs) > 1:
