1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
|
"""Helpers for implementing generic IR to IR transforms."""
from __future__ import annotations
from typing import Final, Optional
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
DecRef,
Extend,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
)
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
class IRTransform(OpVisitor[Optional[Value]]):
    """Identity transform over IR; subclass it to rewrite ops.

    Each visit_* method below re-emits the visited op unchanged through the
    builder, so a subclass that overrides nothing simply copies the IR.
    Override the visit_* methods for the ops you want to change.

    An override can drop an op by not emitting it and returning None; the
    op is then recorded as deleted in op_map. In that case the transform
    must make sure that no op in the result still uses the dropped op as
    an operand.

    Emitted ops may keep referring to the original BasicBlocks and ops:
    transform_blocks() re-points those references to the transformed
    versions once every block has been visited.
    """

    def __init__(self, builder: LowLevelIRBuilder) -> None:
        self.builder = builder
        # Original op/register -> replacement. Subclasses may add their own
        # entries here; a None value marks the op/register as deleted.
        self.op_map: dict[Value, Value | None] = {}

    def transform_blocks(self, blocks: list[BasicBlock]) -> None:
        """Run the transform over the basic blocks of a single function.

        Every input block gets a fresh replacement block, which receives
        whatever the visit_* methods emit. The replacement blocks end up in
        self.builder.blocks. Replacement blocks that became a lone
        Unreachable op (when the original block was not) are discarded.
        Finally, operands, jump targets and error handlers are re-pointed
        at the transformed ops and blocks.
        """
        old_to_new: dict[BasicBlock, BasicBlock] = {}
        replacements = self.op_map
        dropped: set[BasicBlock] = set()
        for old_block in blocks:
            fresh = BasicBlock()
            old_to_new[old_block] = fresh
            self.builder.activate_block(fresh)
            # Carry over the original handler; the patching pass below
            # re-points it at the transformed handler block.
            fresh.error_handler = old_block.error_handler
            for old_op in old_block.ops:
                result = old_op.accept(self)
                if result is not old_op:
                    # Remember what the op turned into (None if the visit
                    # method returned None).
                    replacements[old_op] = result
            # The transform may have reduced a real block to nothing but an
            # Unreachable op; such blocks are removed below.
            if not is_empty_block(old_block) and is_empty_block(fresh):
                dropped.add(fresh)
        self.builder.blocks = [b for b in self.builder.blocks if b not in dropped]

        # Second pass: fix up op/block references so they point at the
        # transformed versions.
        patcher = PatchVisitor(replacements, old_to_new)
        for b in self.builder.blocks:
            for o in b.ops:
                o.accept(patcher)
            handler = b.error_handler
            if handler is not None:
                b.error_handler = old_to_new.get(handler, handler)

    def add(self, op: Op) -> Value:
        """Emit op through self.builder and return what builder.add() returns."""
        return self.builder.add(op)

    # Control-flow ops: re-emitted unchanged; these visitors return None.

    def visit_goto(self, op: Goto) -> None:
        self.add(op)

    def visit_branch(self, op: Branch) -> None:
        self.add(op)

    def visit_return(self, op: Return) -> None:
        self.add(op)

    def visit_unreachable(self, op: Unreachable) -> None:
        self.add(op)

    def visit_assign(self, op: Assign) -> Value | None:
        # Special case: when the assigned source was deleted by the
        # transform (mapped to None), drop the assignment too. This allows
        # removing register initialization assignments. Unmapped sources
        # fall back to themselves, which are never None.
        if self.op_map.get(op.src, op.src) is None:
            return None
        return self.add(op)

    # Value-producing and other ops: identity (re-emit the op itself).

    def visit_assign_multi(self, op: AssignMulti) -> Value | None:
        return self.add(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> Value | None:
        return self.add(op)

    def visit_load_literal(self, op: LoadLiteral) -> Value | None:
        return self.add(op)

    def visit_get_attr(self, op: GetAttr) -> Value | None:
        return self.add(op)

    def visit_set_attr(self, op: SetAttr) -> Value | None:
        return self.add(op)

    def visit_load_static(self, op: LoadStatic) -> Value | None:
        return self.add(op)

    def visit_init_static(self, op: InitStatic) -> Value | None:
        return self.add(op)

    def visit_tuple_get(self, op: TupleGet) -> Value | None:
        return self.add(op)

    def visit_tuple_set(self, op: TupleSet) -> Value | None:
        return self.add(op)

    def visit_inc_ref(self, op: IncRef) -> Value | None:
        return self.add(op)

    def visit_dec_ref(self, op: DecRef) -> Value | None:
        return self.add(op)

    def visit_call(self, op: Call) -> Value | None:
        return self.add(op)

    def visit_method_call(self, op: MethodCall) -> Value | None:
        return self.add(op)

    def visit_cast(self, op: Cast) -> Value | None:
        return self.add(op)

    def visit_box(self, op: Box) -> Value | None:
        return self.add(op)

    def visit_unbox(self, op: Unbox) -> Value | None:
        return self.add(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None:
        return self.add(op)

    def visit_call_c(self, op: CallC) -> Value | None:
        return self.add(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> Value | None:
        return self.add(op)

    def visit_truncate(self, op: Truncate) -> Value | None:
        return self.add(op)

    def visit_extend(self, op: Extend) -> Value | None:
        return self.add(op)

    def visit_load_global(self, op: LoadGlobal) -> Value | None:
        return self.add(op)

    def visit_int_op(self, op: IntOp) -> Value | None:
        return self.add(op)

    def visit_comparison_op(self, op: ComparisonOp) -> Value | None:
        return self.add(op)

    def visit_float_op(self, op: FloatOp) -> Value | None:
        return self.add(op)

    def visit_float_neg(self, op: FloatNeg) -> Value | None:
        return self.add(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None:
        return self.add(op)

    def visit_load_mem(self, op: LoadMem) -> Value | None:
        return self.add(op)

    def visit_set_mem(self, op: SetMem) -> Value | None:
        return self.add(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None:
        return self.add(op)

    def visit_set_element(self, op: SetElement) -> Value | None:
        return self.add(op)

    def visit_load_address(self, op: LoadAddress) -> Value | None:
        return self.add(op)

    def visit_keep_alive(self, op: KeepAlive) -> Value | None:
        return self.add(op)

    def visit_unborrow(self, op: Unborrow) -> Value | None:
        return self.add(op)
class PatchVisitor(OpVisitor[None]):
    """Re-point op operands and block references in place after a transform.

    op_map maps original ops/registers to their replacements (None marks a
    deleted one). block_map maps original blocks to their transformed
    counterparts. Values and blocks missing from a map are left unchanged.
    Each visit_* method rewrites the operand and jump-target attributes of
    the visited op.
    """

    def __init__(
        self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock]
    ) -> None:
        self.op_map: Final = op_map
        self.block_map: Final = block_map

    def fix_op(self, op: Value) -> Value:
        """Return the replacement for op, or op itself if it is unmapped.

        Fails an assertion if op was deleted (mapped to None), because a
        surviving op must not use a removed value.
        """
        new = self.op_map.get(op, op)
        assert new is not None, "use of removed op"
        return new

    def fix_block(self, block: BasicBlock) -> BasicBlock:
        """Return the transformed counterpart of block, or block itself if unmapped."""
        return self.block_map.get(block, block)

    def visit_goto(self, op: Goto) -> None:
        op.label = self.fix_block(op.label)

    def visit_branch(self, op: Branch) -> None:
        op.value = self.fix_op(op.value)
        op.true = self.fix_block(op.true)
        op.false = self.fix_block(op.false)

    def visit_return(self, op: Return) -> None:
        op.value = self.fix_op(op.value)

    def visit_unreachable(self, op: Unreachable) -> None:
        pass

    def visit_assign(self, op: Assign) -> None:
        op.src = self.fix_op(op.src)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        op.src = [self.fix_op(s) for s in op.src]

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        pass

    def visit_load_literal(self, op: LoadLiteral) -> None:
        pass

    def visit_get_attr(self, op: GetAttr) -> None:
        op.obj = self.fix_op(op.obj)

    def visit_set_attr(self, op: SetAttr) -> None:
        op.obj = self.fix_op(op.obj)
        op.src = self.fix_op(op.src)

    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        op.value = self.fix_op(op.value)

    def visit_tuple_get(self, op: TupleGet) -> None:
        op.src = self.fix_op(op.src)

    def visit_tuple_set(self, op: TupleSet) -> None:
        op.items = [self.fix_op(item) for item in op.items]

    def visit_inc_ref(self, op: IncRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_dec_ref(self, op: DecRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_call(self, op: Call) -> None:
        op.args = [self.fix_op(arg) for arg in op.args]

    def visit_method_call(self, op: MethodCall) -> None:
        op.obj = self.fix_op(op.obj)
        op.args = [self.fix_op(arg) for arg in op.args]

    def visit_cast(self, op: Cast) -> None:
        op.src = self.fix_op(op.src)

    def visit_box(self, op: Box) -> None:
        op.src = self.fix_op(op.src)

    def visit_unbox(self, op: Unbox) -> None:
        op.src = self.fix_op(op.src)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        # op.value may be a non-Value payload; only IR values are remapped.
        if isinstance(op.value, Value):
            op.value = self.fix_op(op.value)

    def visit_call_c(self, op: CallC) -> None:
        op.args = [self.fix_op(arg) for arg in op.args]

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        op.args = [self.fix_op(arg) for arg in op.args]

    def visit_truncate(self, op: Truncate) -> None:
        op.src = self.fix_op(op.src)

    def visit_extend(self, op: Extend) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        op.src = self.fix_op(op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        op.src = self.fix_op(op.src)

    def visit_set_mem(self, op: SetMem) -> None:
        op.dest = self.fix_op(op.dest)
        op.src = self.fix_op(op.src)

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        op.src = self.fix_op(op.src)

    def visit_set_element(self, op: SetElement) -> None:
        op.src = self.fix_op(op.src)
        # The value being stored into the field is an operand as well;
        # without remapping it, the op would keep using a pre-transform value.
        op.item = self.fix_op(op.item)

    def visit_load_address(self, op: LoadAddress) -> None:
        # Only LoadStatic sources are remapped, and the replacement must
        # itself be a LoadStatic so that op.src keeps a valid type.
        if isinstance(op.src, LoadStatic):
            new = self.fix_op(op.src)
            assert isinstance(new, LoadStatic), new
            op.src = new

    def visit_keep_alive(self, op: KeepAlive) -> None:
        op.src = [self.fix_op(s) for s in op.src]

    def visit_unborrow(self, op: Unborrow) -> None:
        op.src = self.fix_op(op.src)
def is_empty_block(block: BasicBlock) -> bool:
    """Return True if block holds nothing except a single Unreachable op.

    transform_blocks() uses this to detect blocks that a transform emptied.
    """
    ops = block.ops
    if len(ops) != 1:
        return False
    return isinstance(ops[0], Unreachable)
|