From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com>
Origin: https://github.com/numba/numba/pull/8545
Date: Wed, 30 Nov 2022 12:37:48 -0600
Subject: Address reviews in interpreter.py

---
 numba/core/interpreter.py | 84 +++++++++++++++++++++++++++--------------------
 1 file changed, 48 insertions(+), 36 deletions(-)

diff --git a/numba/core/interpreter.py b/numba/core/interpreter.py
index b54da77..d4a8fe5 100644
--- a/numba/core/interpreter.py
+++ b/numba/core/interpreter.py
@@ -10,11 +10,10 @@ from numba.core.errors import NotDefinedError, UnsupportedError, error_extras
 from numba.core.ir_utils import get_definition, guard
 from numba.core.utils import (PYVERSION, BINOPS_TO_OPERATORS,
                               INPLACE_BINOPS_TO_OPERATORS,)
-from numba.core.byteflow import Flow, AdaptDFA, AdaptCFA
+from numba.core.byteflow import Flow, AdaptDFA, AdaptCFA, BlockKind
 from numba.core.unsafe import eh
 from numba.cpython.unsafe.tuple import unpack_single_tuple
 
-
 class _UNKNOWN_VALUE(object):
     """Represents an unknown value, this is for ease of debugging purposes only.
     """
@@ -1227,6 +1226,9 @@ def peep_hole_fuse_dict_add_updates(func_ir):
 def peep_hole_split_at_pop_block(func_ir):
     """
     Split blocks that contain ir.PopBlock.
+
+    This rewrite restores the IR structure to pre 3.11 so that withlifting
+    can work correctly.
     """
     new_block_map = {}
     sorted_blocks = sorted(func_ir.blocks.items())
@@ -1379,17 +1381,14 @@ class Interpreter(object):
 
         self.scopes.append(ir.Scope(parent=self.current_scope, loc=self.loc))
 
-        # Gather exception info block info
-
-        # for block in self.cfa.iterliveblocks():
-        #     dfainfo = self.dfa.infos[block.offset]
-        #     print('---', block.offset, '[[[[', dfainfo.blockstack)
-
         # Interpret loop
         for inst, kws in self._iter_inst():
             self._dispatch(inst, kws)
         if PYVERSION == (3, 11):
+            # Insert end of try markers
             self._end_try_blocks()
+        elif PYVERSION > (3, 11):
+            raise NotImplementedError(PYVERSION)
         self._legalize_exception_vars()
         # Prepare FunctionIR
         func_ir = ir.FunctionIR(self.blocks, self.is_generator, self.func_id,
@@ -1423,39 +1422,50 @@ class Interpreter(object):
         return func_ir
 
     def _end_try_blocks(self):
+        """Closing all try blocks by inserting the required marker at the
+        exception handler
+
+        This is only needed for py3.11 because of the changes in exception
+        handling. This merely maps the new py3.11 semantic back to the old way.
+
+        What the code does:
+
+        - For each block, compute the difference of blockstack to its incoming
+          blocks' blockstack.
+        - If the incoming blockstack has an extra TRY, the current block must
+          be the EXCEPT block and we need to insert a marker.
+
+        See also: _insert_try_block_end
+        """
         assert PYVERSION == (3, 11)
         graph = self.cfa.graph
         for offset, block in self.blocks.items():
-            cur_bs = inc_bs = self.dfa.infos[offset].blockstack
+            # Get current blockstack
+            cur_bs = self.dfa.infos[offset].blockstack
+            # Check blockstack of the incoming blocks
             for inc, _ in graph.predecessors(offset):
                 inc_bs = self.dfa.infos[inc].blockstack
 
-                # find first diff
+                # find first diff in the blockstack
                 for i, (x, y) in enumerate(zip(cur_bs, inc_bs)):
                     if x != y:
-                        # print(f"mismatch {x} != {y}")
                         break
                 else:
                     i = min(len(cur_bs), len(inc_bs))
 
-                remain = list(inc_bs[i:])
-
-                # print("==", offset, "|", remain)
-
                 def do_change(remain):
-                    if remain:
-                        while remain:
-                            ent = remain.pop()
-                            from .byteflow import BlockKind
-                            if ent['kind'] == BlockKind('TRY'):
-                                self.current_block = block
-                                oldbody = list(block.body)
-                                block.body.clear()
-                                self._insert_try_block_end()
-                                block.body.extend(oldbody)
-                                return True
-
-                if do_change(remain):
+                    while remain:
+                        ent = remain.pop()
+                        if ent['kind'] == BlockKind('TRY'):
+                            # Extend block with marker for end of try
+                            self.current_block = block
+                            oldbody = list(block.body)
+                            block.body.clear()
+                            self._insert_try_block_end()
+                            block.body.extend(oldbody)
+                            return True
+
+                if do_change(list(inc_bs[i:])):
                     break
 
     def _legalize_exception_vars(self):
@@ -1530,6 +1540,7 @@ class Interpreter(object):
         self.assigner = Assigner()
         # Check out-of-scope syntactic-block
         if PYVERSION == (3, 11):
+            # This is recreating pre-3.11 code structure
             while self.syntax_blocks:
                 if offset >= self.syntax_blocks[-1].exit:
                     synblk = self.syntax_blocks.pop()
@@ -1699,6 +1710,7 @@ class Interpreter(object):
                 val = self.get(varname)
             except ir.NotDefinedError:
                 # Hack to make sure exception variables are defined
+                assert PYVERSION == (3, 11), "unexpected missing definition"
                 val = ir.Const(value=None, loc=self.loc)
             stmt = ir.Assign(value=val, target=target,
                              loc=self.loc)
@@ -2268,7 +2280,7 @@ class Interpreter(object):
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
-    def op_BEFORE_WITH(self, inst, contextmanager, exitfn=None, end=None):
+    def op_BEFORE_WITH(self, inst, contextmanager, exitfn, end):
         assert self.blocks[inst.offset] is self.current_block
         # Handle with
         wth = ir.With(inst.offset, exit=end)
@@ -2278,7 +2290,7 @@ class Interpreter(object):
                                                begin=inst.offset,
                                                end=end, loc=self.loc,))
 
-        # Store exit f
+        # Store exit function
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
@@ -2880,16 +2892,16 @@ class Interpreter(object):
 
         op = BINOPS_TO_OPERATORS["is"]
 
-        lhs = self.store(value=ir.Const(None, loc=self.loc),
+        constnone = self.store(value=ir.Const(None, loc=self.loc),
                          name="${inst.offset}constnone")
-        rhs = self.get(pred)
-        isnone = ir.Expr.binop(op, lhs=lhs, rhs=rhs, loc=self.loc)
+        pred = self.get(pred)
+        isnone = ir.Expr.binop(op, lhs=pred, rhs=constnone, loc=self.loc)
 
         pname = "$%spred" % (inst.offset)
         predicate = self.store(value=isnone, name=pname)
-        bra = ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr,
-                        loc=self.loc)
-        self.current_block.append(bra)
+        branch = ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr,
+                           loc=self.loc)
+        self.current_block.append(branch)
 
     def op_POP_JUMP_FORWARD_IF_NONE(self, inst, pred):
         self._jump_if_none(inst, pred, True)
