1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
|
"""Utilities for checking that internal IR is valid and consistent."""
from __future__ import annotations
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR
from mypyc.ir.ops import (
Assign,
AssignMulti,
BaseAssign,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
SetAttr,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
)
from mypyc.ir.pprint import format_func
from mypyc.ir.rtypes import (
RArray,
RInstance,
RPrimitive,
RType,
RUnion,
bytes_rprimitive,
dict_rprimitive,
int_rprimitive,
is_float_rprimitive,
is_object_rprimitive,
list_rprimitive,
range_rprimitive,
set_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
class FnError:
    """A single problem found while checking a function's IR.

    ``source`` is the op or basic block the problem is attached to and
    ``desc`` is the human-readable description of the problem.
    """

    def __init__(self, source: Op | BasicBlock, desc: str) -> None:
        self.source = source
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        """Return True when other is an FnError with equal source and desc."""
        return (
            isinstance(other, FnError) and self.source == other.source and self.desc == other.desc
        )

    def __hash__(self) -> int:
        """Hash the same fields that __eq__ compares.

        Defining __eq__ alone sets __hash__ to None, which makes instances
        unusable in sets and as dict keys. If the source itself is not
        hashable this still raises TypeError, exactly as before.
        """
        return hash((self.source, self.desc))

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"
def check_func_ir(fn: FuncIR) -> list[FnError]:
    """Validate the IR of a function and return every error that was found.

    Structural checks run first: each block must be terminated, no ControlOp
    may occur before the final position of a block, no op may appear twice
    among the non-final ops of the function, and op sources must pass
    check_op_sources_valid. If any of these report an error, those errors
    are returned right away. Otherwise every op is visited with an OpChecker
    and the errors it collected are returned.
    """
    problems = []
    seen_ops = set()

    for bb in fn.blocks:
        if not bb.terminated:
            # Attach the error to the last op when there is one, else to the block.
            culprit = bb.ops[-1] if bb.ops else bb
            problems.append(FnError(source=culprit, desc="Block not terminated"))

        # Only the ops before the final one are inspected here; the final op
        # of each block takes no part in the duplicate check.
        for inner in bb.ops[:-1]:
            if isinstance(inner, ControlOp):
                problems.append(
                    FnError(source=inner, desc="Block has operations after control op")
                )
            if inner in seen_ops:
                problems.append(FnError(source=inner, desc="Func has a duplicate op"))
            seen_ops.add(inner)

    problems.extend(check_op_sources_valid(fn))
    if problems:
        return problems

    checker = OpChecker(fn)
    for bb in fn.blocks:
        for op in bb.ops:
            op.accept(checker)
    return checker.errors
class IrCheckException(Exception):
    """Raised by assert_func_ir_valid when check_func_ir reports errors."""
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir finds any error in fn.

    The exception message is a fixed header followed by the lines returned
    by format_func for fn, which is given the (source, desc) pair of each
    error.
    """
    found = check_func_ir(fn)
    if not found:
        return

    annotations = [(err.source, err.desc) for err in found]
    listing = "\n".join(format_func(fn, annotations))
    raise IrCheckException("Internal error: Generated invalid IR: \n" + listing)
def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
    """Check that every source of every op refers to something known to fn.

    A source that is an Op must be one of the ops contained in fn's blocks.
    A source that is a Register must be an argument register of fn, the
    destination of a BaseAssign in fn, or the register operand of a
    LoadAddress in fn. Integer sources are always accepted. One FnError is
    produced per offending source, attached to the op that uses it.
    """
    known_ops: set[Op] = set()
    known_regs: set[Register] = set()

    # First pass: collect everything that may legitimately be referenced.
    for bb in fn.blocks:
        known_ops.update(bb.ops)
        for candidate in bb.ops:
            if isinstance(candidate, BaseAssign):
                known_regs.add(candidate.dest)
            elif isinstance(candidate, LoadAddress) and isinstance(candidate.src, Register):
                known_regs.add(candidate.src)
    known_regs.update(fn.arg_regs)

    # Second pass: verify each source against the collected sets.
    problems = []
    for bb in fn.blocks:
        for user in bb.ops:
            for src in user.sources():
                if isinstance(src, Integer):
                    continue
                if isinstance(src, Op):
                    if src not in known_ops:
                        problems.append(
                            FnError(
                                source=user,
                                desc=f"Invalid op reference to op of type {type(src).__name__}",
                            )
                        )
                elif isinstance(src, Register) and src not in known_regs:
                    problems.append(
                        FnError(source=user, desc=f"Invalid op reference to register {src.name!r}")
                    )
    return problems
# Names of the primitive types that can_coerce_to treats as mutually
# exclusive: when both sides of a coercion are named here, the names must
# be identical for the coercion to be accepted.
disjoint_types = {
    rprimitive.name
    for rprimitive in (
        int_rprimitive,
        bytes_rprimitive,
        str_rprimitive,
        dict_rprimitive,
        list_rprimitive,
        set_rprimitive,
        tuple_rprimitive,
        range_rprimitive,
    )
}
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check whether a value of type src may be assigned to type dest.

    The check is deliberately permissive: it may accept coercions that are
    not actually valid, and any dest that is neither an RUnion nor an
    RPrimitive is accepted unconditionally.
    """
    if isinstance(dest, RUnion):
        # Enough for src to fit one member of the destination union.
        return any(can_coerce_to(src, member) for member in dest.items)

    if not isinstance(dest, RPrimitive):
        return True

    if isinstance(src, RPrimitive):
        # Only when both sides are disjoint types are their names compared;
        # in every other case the primitives merely need the same size.
        both_disjoint = src.name in disjoint_types and dest.name in disjoint_types
        if both_disjoint:
            return src.name == dest.name
        return src.size == dest.size

    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)

    if isinstance(src, RUnion):
        # IR doesn't have the ability to narrow unions based on
        # control flow, so cannot be a strict all() here.
        return any(can_coerce_to(member, dest) for member in src.items)

    return False
class OpChecker(OpVisitor[None]):
    """Op visitor that records an FnError for each invalid op it visits.

    Errors found are appended to ``self.errors``. Many visit methods are
    intentionally empty because no check is implemented for that op.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        # Function whose ops are being checked; used for block membership
        # and for the declared return type.
        self.parent_fn = parent_fn
        self.errors: list[FnError] = []

    def fail(self, source: Op, desc: str) -> None:
        """Record an error with the given description against source."""
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        """Fail for each target of op that is not a block of the parent function."""
        for target in op.targets():
            if target not in self.parent_fn.blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
        """Fail if can_coerce_to rejects coercing src to dest."""
        if not can_coerce_to(src, dest):
            self.fail(
                source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
            )

    def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
        """Fail unless t can be coerced to s and s can be coerced to t."""
        if not can_coerce_to(t, s) or not can_coerce_to(s, t):
            self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

    def expect_float(self, op: Op, v: Value) -> None:
        """Fail if the type of v is not the float primitive."""
        if not is_float_rprimitive(v.type):
            self.fail(op, f"Float expected (actual type is {v.type})")

    def expect_non_float(self, op: Op, v: Value) -> None:
        """Fail if the type of v is the float primitive."""
        if is_float_rprimitive(v.type):
            self.fail(op, "Float not expected")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    def visit_return(self, op: Return) -> None:
        # The returned value must fit the declared return type of the function.
        self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Unreachables are checked at a higher level since validation
        # requires access to the entire basic block.
        pass

    def visit_assign(self, op: Assign) -> None:
        self.check_type_coercion(op, op.src.type, op.dest.type)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        # Each source value must fit the item type of the destination array.
        for src in op.src:
            assert isinstance(op.dest.type, RArray)
            self.check_type_coercion(op, src.type, op.dest.type.item_type)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Currently it is assumed that all types have an error value.
        # Once this is fixed we can validate that the rtype here actually
        # has an error value.
        pass

    def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
        """Fail for each item of t (recursing into nested tuples) that is not a valid literal."""
        for x in t:
            if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
                self.fail(op, f"Invalid type for item of tuple literal: {type(x)})")
            if isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)

    def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
        """Fail for each item of s that is not a valid literal; tuple items are checked as tuples."""
        for x in s:
            if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
                pass
            elif isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)
            else:
                self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Check that the type of op matches the Python type of its literal value.

        An op typed ``builtins.object`` is accepted for any literal value.
        """
        expected_type = None
        if op.value is None:
            expected_type = "builtins.object"
        elif isinstance(op.value, bool):
            # This test must come before the int test: bool is a subclass of
            # int, so testing int first would classify every bool as
            # builtins.int and leave this branch unreachable.
            expected_type = "builtins.object"
        elif isinstance(op.value, int):
            expected_type = "builtins.int"
        elif isinstance(op.value, str):
            expected_type = "builtins.str"
        elif isinstance(op.value, bytes):
            expected_type = "builtins.bytes"
        elif isinstance(op.value, float):
            expected_type = "builtins.float"
        elif isinstance(op.value, complex):
            expected_type = "builtins.object"
        elif isinstance(op.value, tuple):
            expected_type = "builtins.tuple"
            self.check_tuple_items_valid_literals(op, op.value)
        elif isinstance(op.value, frozenset):
            # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
            # it's a set (when it's really a frozenset).
            expected_type = "builtins.set"
            self.check_frozenset_items_valid_literals(op, op.value)

        assert expected_type is not None, "Missed a case for LoadLiteral check"

        if op.type.name not in [expected_type, "builtins.object"]:
            self.fail(
                op,
                f"Invalid literal value for type: value has "
                f"type {expected_type}, but op has type {op.type.name}",
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        # Nothing to do.
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        # Nothing to do.
        pass

    # Static operations cannot be checked at the function level.
    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        # Nothing to do.
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        # Nothing to do.
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        # Nothing to do.
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        # Nothing to do.
        pass

    def visit_call(self, op: Call) -> None:
        # Length is checked in constructor, and return type is set
        # in a way that can't be incorrect
        for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_method_call(self, op: MethodCall) -> None:
        # Similar to above, but we must look up method first.
        method_decl = op.receiver_type.class_ir.method_decl(op.method)
        # Static methods have no receiver argument in their signature.
        if method_decl.kind == FUNC_STATICMETHOD:
            decl_index = 0
        else:
            decl_index = 1

        if len(op.args) + decl_index != len(method_decl.sig.args):
            self.fail(op, "Incorrect number of args for method call.")

        # Skip the receiver argument (self)
        for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_extend(self, op: Extend) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        self.check_compatibility(op, op.lhs.type, op.rhs.type)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        self.expect_float(op, op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass

    def visit_unborrow(self, op: Unborrow) -> None:
        pass
|