File: 0025-Support-with-lifting.patch

package info (click to toggle)
numba 0.56.4%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 23,672 kB
  • sloc: python: 183,651; ansic: 15,370; cpp: 2,259; javascript: 424; sh: 308; makefile: 174
file content (337 lines) | stat: -rw-r--r-- 13,004 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com>
Origin: https://github.com/numba/numba/pull/8545
Date: Wed, 28 Sep 2022 17:18:36 -0500
Subject: Support with-lifting

test_withlifting tests report errors=31, skipped=2, expected failures=2 out of 68 tests.
A lot of the tests are failing due to cloudpickle incompatibility.
---
 numba/core/bytecode.py    | 21 +++++++++++++---
 numba/core/byteflow.py    | 58 ++++++++++++++++++++++++++++++++++----------
 numba/core/interpreter.py | 62 +++++++++++++++++++++++++++++++++++++++++------
 numba/core/transforms.py  |  4 +--
 4 files changed, 117 insertions(+), 28 deletions(-)

diff --git a/numba/core/bytecode.py b/numba/core/bytecode.py
index 1ccafce..f2882eb 100644
--- a/numba/core/bytecode.py
+++ b/numba/core/bytecode.py
@@ -226,7 +226,17 @@ class ByteCode(object):
         self.co_consts = code.co_consts
         self.co_cellvars = code.co_cellvars
         self.co_freevars = code.co_freevars
-        self.exception_entries = dis.Bytecode(code).exception_entries
+
+        def fixup_eh(ent):
+            from dis import _ExceptionTableEntry
+            out = _ExceptionTableEntry(
+                start=ent.start + 2, end=ent.end + 2, target=ent.target + 2,
+                depth=ent.depth, lasti=ent.lasti,
+            )
+            return out
+
+        entries = dis.Bytecode(code).exception_entries
+        self.exception_entries = tuple(map(fixup_eh, entries))
         self.table = table
         self.labels = sorted(labels)
 
@@ -306,13 +316,16 @@ class ByteCode(object):
         return self._compute_used_globals(self.func_id.func, self.table,
                                           self.co_consts, self.co_names)
 
-    def get_exception_entry(self, offset):
+    def find_exception_entry(self, offset):
         """
         Returns the exception entry for the given instruction offset
         """
+        candidates = []
         for ent in self.exception_entries:
-            if offset in range(ent.start, ent.end):
-                return ent
+            if ent.start <= offset <= ent.end:
+                candidates.append((ent.depth, ent))
+        ent = max(candidates)[1]
+        return ent
 
 class FunctionIdentity(serialize.ReduceMixin):
     """
diff --git a/numba/core/byteflow.py b/numba/core/byteflow.py
index f5092f5..8bd9dd1 100644
--- a/numba/core/byteflow.py
+++ b/numba/core/byteflow.py
@@ -10,11 +10,11 @@ from functools import total_ordering
 from numba.core.utils import UniqueDict, PYVERSION
 from numba.core.controlflow import NEW_BLOCKERS, CFGraph
 from numba.core.ir import Loc
-from numba.core.errors import UnsupportedError
+from numba.core.errors import UnsupportedError, CompilerError
 
 
 _logger = logging.getLogger(__name__)
-
+# logging.basicConfig(level=logging.DEBUG)
 
 _EXCEPT_STACK_OFFSET = 6
 _FINALLY_POP = _EXCEPT_STACK_OFFSET if PYVERSION >= (3, 8) else 1
@@ -254,7 +254,7 @@ class Flow(object):
         than a POP_TOP, if it is something else it'll be some sort of store
         which is not supported (this corresponds to `with CTXMGR as VAR(S)`)."""
         current_inst = state.get_inst()
-        if current_inst.opname == "SETUP_WITH":
+        if current_inst.opname in {"SETUP_WITH", "BEFORE_WITH"}:
             next_op = self._bytecode[current_inst.next].opname
             if next_op != "POP_TOP":
                 msg = ("The 'with (context manager) as "
@@ -263,6 +263,10 @@ class Flow(object):
                 raise UnsupportedError(msg)
 
 
+def _is_null_temp_reg(reg):
+    return reg.startswith("$null$")
+
+
 class TraceRunner(object):
     """Trace runner contains the states for the trace and the opcode dispatch.
     """
@@ -275,9 +279,18 @@ class TraceRunner(object):
         return Loc(self.debug_filename, lineno)
 
     def dispatch(self, state):
+        if PYVERSION == (3, 11) and state._blockstack:
+            state: State
+            while state._blockstack:
+                topblk = state._blockstack[-1]
+                if topblk['end'] <= state.pc_initial:
+                    state._blockstack.pop()
+                else:
+                    break
         inst = state.get_inst()
-        _logger.debug("dispatch pc=%s, inst=%s", state._pc, inst)
-        _logger.debug("stack %s", state._stack)
+        if inst.opname != "CACHE":
+            _logger.debug("dispatch pc=%s, inst=%s", state._pc, inst)
+            _logger.debug("stack %s", state._stack)
         fn = getattr(self, "op_{}".format(inst.opname), None)
         if fn is not None:
             fn(state, inst)
@@ -298,6 +311,7 @@ class TraceRunner(object):
         state.append(inst)
 
     def op_PUSH_NULL(self, state, inst):
+        state.push(state.make_null())
         state.append(inst)
 
     def op_RETURN_GENERATOR(self, state, inst):
@@ -355,9 +369,9 @@ class TraceRunner(object):
             res = state.make_temp()
             idx = inst.arg >> 1
             state.append(inst, idx=idx, res=res)
-            ## ignoring the NULL
-            # if inst.arg & 1:
-                # state.push(state.make_temp())
+            # ignoring the NULL
+            if inst.arg & 1:
+                state.push(state.make_null())
             state.push(res)
     else:
         def op_LOAD_GLOBAL(self, state, inst):
@@ -769,13 +783,20 @@ class TraceRunner(object):
 
         yielded = state.make_temp()
         exitfn = state.make_temp(prefix='setup_with_exitfn')
-        state.append(inst, contextmanager=cm, exitfn=exitfn)
 
         state.push(exitfn)
         state.push(yielded)
 
-        end = state._bytecode.get_exception_entry(inst.next).target
-
+        # Gather all exception entries for this WITH. There maybe multiple
+        # entries; esp. for nested WITHs.
+        bc = state._bytecode
+        ehhead = bc.find_exception_entry(inst.next)
+        ehrelated = [ehhead]
+        for eh in bc.exception_entries:
+            if eh.target == ehhead.target:
+                ehrelated.append(eh)
+        end = max(eh.end for eh in ehrelated)
+        state.append(inst, contextmanager=cm, exitfn=exitfn, end=end)
 
         state.push_block(
             state.make_block(
@@ -889,7 +910,10 @@ class TraceRunner(object):
         narg = inst.arg
         args = list(reversed([state.pop() for _ in range(narg)]))
         func = state.pop()
-
+        extra = state.pop()  # XXX need to see if it's NULL
+        if not _is_null_temp_reg(extra):
+            func = extra
+            args = [extra, *args]
         res = state.make_temp()
 
         kw_names = state.pop_kw_names()
@@ -1290,7 +1314,12 @@ class TraceRunner(object):
     # of LOAD_METHOD and CALL_METHOD.
 
     def op_LOAD_METHOD(self, state, inst):
-        self.op_LOAD_ATTR(state, inst)
+        item = state.pop()
+        extra = state.make_null()
+        state.push(extra)
+        res = state.make_temp()
+        state.append(inst, item=item, res=res)
+        state.push(res)
 
     def op_CALL_METHOD(self, state, inst):
         self.op_CALL_FUNCTION(state, inst)
@@ -1433,6 +1462,9 @@ class State(object):
         self._temp_registers.append(name)
         return name
 
+    def make_null(self):
+        return self.make_temp(prefix="null$")
+
     def append(self, inst, **kwargs):
         """Append new inst"""
         self._insts.append((inst.offset, kwargs))
diff --git a/numba/core/interpreter.py b/numba/core/interpreter.py
index 2db6c84..ec590b7 100644
--- a/numba/core/interpreter.py
+++ b/numba/core/interpreter.py
@@ -1224,6 +1224,44 @@ def peep_hole_fuse_dict_add_updates(func_ir):
     return func_ir
 
 
+def peep_hole_split_at_pop_block(func_ir):
+    """
+    Split blocks that contain ir.PopBlock
+    """
+    newblocks = {}
+    for label, blk in func_ir.blocks.items():
+        for i, inst in enumerate(blk.body):
+            if isinstance(inst, ir.PopBlock):
+                head = blk.body[:i]
+                mid = blk.body[i:i + 1]
+                tail = blk.body[i + 1:]
+                if head:
+                    blk.body.clear()
+                    blk.body.extend(head)
+
+                    midblk = ir.Block(blk.scope, loc=blk.loc)
+                    midblk.body.extend(mid)
+
+                    midlabel = label + i * 2
+                    newblocks[midlabel] = midblk
+
+                    blk.body.append(ir.Jump(midlabel, loc=blk.loc))
+                else:
+                    blk.body.clear()
+                    blk.body.extend(mid)
+                    midblk = blk
+
+                tailblk = ir.Block(blk.scope, loc=blk.loc)
+                tailblk.body.extend(tail)
+                taillabel = label + (i + 1) * 2
+                newblocks[taillabel] = tailblk
+
+                midblk.append(ir.Jump(taillabel, loc=blk.loc))
+
+    func_ir.blocks.update(newblocks)
+    return func_ir
+
+
 def _build_new_build_map(func_ir, name, old_body, old_lineno, new_items):
     """
     Create a new build_map with a new set of key/value items
@@ -1344,6 +1382,8 @@ class Interpreter(object):
         # post process the IR to rewrite opcodes/byte sequences that are too
         # involved to risk handling as part of direct interpretation
         peepholes = []
+        if PYVERSION == (3, 11):
+            peepholes.append(peep_hole_split_at_pop_block)
         if PYVERSION in [(3, 9), (3, 10)]:
             peepholes.append(peep_hole_list_to_tuple)
         peepholes.append(peep_hole_delete_with_exit)
@@ -1434,7 +1474,9 @@ class Interpreter(object):
         # Check out-of-scope syntactic-block
         while self.syntax_blocks:
             if offset >= self.syntax_blocks[-1].exit:
-                self.syntax_blocks.pop()
+                synblk = self.syntax_blocks.pop()
+                if isinstance(synblk, ir.With):
+                    self.current_block.append(ir.PopBlock(self.loc))
             else:
                 break
 
@@ -1637,6 +1679,13 @@ class Interpreter(object):
 
     def _dispatch(self, inst, kws):
         assert self.current_block is not None
+        if self.syntax_blocks:
+            top = self.syntax_blocks[-1]
+            if isinstance(top, ir.With) :
+                if inst.offset >= top.exit:
+                    self.current_block.append(ir.PopBlock(loc=self.loc))
+                    self.syntax_blocks.pop()
+
         fname = "op_%s" % inst.opname.replace('+', '_')
         try:
             fn = getattr(self, fname)
@@ -2136,23 +2185,20 @@ class Interpreter(object):
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
-    def op_BEFORE_WITH(self, inst, contextmanager, exitfn=None):
+    def op_BEFORE_WITH(self, inst, contextmanager, exitfn=None, end=None):
         assert self.blocks[inst.offset] is self.current_block
-        # use EH entry to determine the end of the with
-        exitpt = self.bytecode.get_exception_entry(inst.next).target
         # Handle with
-        wth = ir.With(inst.offset, exit=exitpt)
+        wth = ir.With(inst.offset, exit=end)
         self.syntax_blocks.append(wth)
         ctxmgr = self.get(contextmanager)
         self.current_block.append(ir.EnterWith(contextmanager=ctxmgr,
                                                begin=inst.offset,
-                                               end=exitpt, loc=self.loc,))
+                                               end=end, loc=self.loc,))
 
-        # Store exit fn
+        # Store exit f
         exit_fn_obj = ir.Const(None, loc=self.loc)
         self.store(value=exit_fn_obj, name=exitfn)
 
-
     def op_SETUP_EXCEPT(self, inst):
         # Removed since python3.8
         self._insert_try_block_begin()
diff --git a/numba/core/transforms.py b/numba/core/transforms.py
index 1169c7d..e5398cd 100644
--- a/numba/core/transforms.py
+++ b/numba/core/transforms.py
@@ -588,7 +588,6 @@ def find_setupwiths(func_ir):
     # rewrite the CFG in case there are multiple POP_BLOCK statements for one
     # with
     func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
-
     # here we need to turn the withs back into a list of tuples so that the
     # rest of the code can cope
     with_ranges_tuple = [(s, list(p)[0])
@@ -619,7 +618,7 @@ def find_setupwiths(func_ir):
     # now we need to rewrite the tuple such that we have SETUP_WITH matching the
     # successor of the block that contains the POP_BLOCK.
     with_ranges_tuple = [(s, func_ir.blocks[p].terminator.get_targets()[0])
-             for (s, p) in with_ranges_tuple]
+                         for (s, p) in with_ranges_tuple]
 
     # finally we check for nested with statements and reject them
     with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)
@@ -754,7 +753,6 @@ def _eliminate_nested_withs(with_ranges):
 def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
     """Modify the FunctionIR to merge the exit blocks of with constructs.
     """
-    out = []
     for k in withs:
         vs : set = withs[k]
         if len(vs) > 1: