"""
Type Inference

Infers the types of the variables appearing in Rtl and XForm patterns by
unifying type variables and accumulating type constraints in a TypeEnv.
"""
from .typevar import TypeVar
from .ast import Def, Var
from copy import copy
from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa
from typing import cast
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
from .typevar import TypeSet # noqa
if TYPE_CHECKING:
T = MTypeVar('T')
TypeMap = Dict[TypeVar, TypeVar]
VarTyping = Dict[Var, TypeVar]
except ImportError:
TYPE_CHECKING = False
pass
class TypeConstraint(object):
    """
    Common base class of every type constraint that can be emitted at runtime.

    Concrete subclasses provide __init__, _args(), is_trivial() and eval().
    Translation, equality, hashing and printing are all derived from the
    tuple returned by _args().
    """
    def __init__(self, tv, tc):
        # type: (TypeVar, Union[TypeVar, TypeSet]) -> None
        """
        Placeholder "constructor" that exists only for the sake of linters.
        """
        assert False, "Abstract"

    def translate(self, m):
        # type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
        """
        Return a new constraint of the same class whose TypeVar arguments have
        been rewritten through m. A TypeEnv is indexed directly; any other
        mapping is applied with subst(). Non-TypeVar arguments are passed
        through untouched.
        """
        def translate_one(a):
            # type: (Any) -> Any
            if not isinstance(a, TypeVar):
                return a
            if isinstance(m, TypeEnv):
                return m[a]
            return subst(a, m)

        new_args = [translate_one(a) for a in self._args()]
        res = self.__class__(*new_args)  # type: TypeConstraint
        return res

    def __eq__(self, other):
        # type: (object) -> bool
        """
        Two constraints are equal when other is an instance of this
        constraint's class and both carry equal argument tuples.
        """
        if isinstance(other, self.__class__):
            # help MyPy figure out other
            assert isinstance(other, TypeConstraint)
            return self._args() == other._args()
        return False

    def is_concrete(self):
        # type: () -> bool
        """
        Return true iff every typevar of the constraint has a singleton type.
        """
        return all(tv.singleton_type() is not None for tv in self.tvs())

    def __hash__(self):
        # type: () -> int
        """
        Hash on the argument tuple, consistently with __eq__.
        """
        return hash(self._args())

    def _args(self):
        # type: () -> Tuple[Any,...]
        """
        Return a tuple holding exactly the arguments that were handed to
        __init__ when this object was created.
        """
        assert False, "Abstract"

    def tvs(self):
        # type: () -> Iterable[TypeVar]
        """
        Return the list of arguments of this constraint that are TypeVars.
        """
        return [a for a in self._args() if isinstance(a, TypeVar)]

    def is_trivial(self):
        # type: () -> bool
        """
        Return true if this constraint can be decided statically.
        """
        assert False, "Abstract"

    def eval(self):
        # type: () -> bool
        """
        Evaluate the constraint. Only meaningful once the constraint has been
        translated to concrete types.
        """
        assert False, "Abstract"

    def __repr__(self):
        # type: () -> str
        """
        Render as ClassName(arg, arg, ...), using str() on each argument.
        """
        rendered_args = ', '.join(str(a) for a in self._args())
        return self.__class__.__name__ + '(' + rendered_args + ')'
class TypesEqual(TypeConstraint):
    """
    Constraint stating that two derived type vars must end up with the same
    type at runtime.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        # Store the pair ordered by repr so that TypesEqual(a, b) and
        # TypesEqual(b, a) carry the same argument tuple.
        ordered = sorted([tv1, tv2], key=repr)
        self.tv1 = ordered[0]
        self.tv2 = ordered[1]

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        same_tv = self.tv1 == self.tv2
        return same_tv or self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        left = self.tv1.singleton_type()
        right = self.tv2.singleton_type()
        return left == right
class InTypeset(TypeConstraint):
    """
    Constraint stating that a type var must be a member of a given typeset.
    """
    def __init__(self, tv, ts):
        # type: (TypeVar, TypeSet) -> None
        assert not tv.is_derived and tv.name.startswith("typeof_")
        self.ts = ts
        self.tv = tv

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv, self.ts)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Work on a copy so the typevar's own typeset is left untouched.
        candidate = self.tv.get_typeset().copy()

        # Trivially True: everything tv may be is already inside ts
        if candidate.issubset(self.ts):
            return True

        # Trivially false: nothing tv may be is inside ts
        candidate &= self.ts
        if candidate.size() == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        concrete_ts = self.tv.get_typeset()
        return concrete_ts.issubset(self.ts)
class WiderOrEq(TypeConstraint):
    """
    Constraint stating that type var tv1 must, at runtime, be wider than or
    equal to type var tv2. This requires that:
        1) They have the same number of lanes
        2) In a lane tv1 has at least as many bits as tv2.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1, self.tv2 = tv1, tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        def set_wider_or_equal(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)

        def set_narrower(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)

        # Trivially true
        if self.tv1 == self.tv2:
            return True

        ts1 = self.tv1.get_typeset()
        ts2 = self.tv2.get_typeset()

        # The bit-width sets of both typesets, paired up per kind.
        width_pairs = [
            (ts1.ints, ts2.ints),
            (ts1.floats, ts2.floats),
            (ts1.bools, ts2.bools),
        ]

        # Trivially True
        if all(set_wider_or_equal(a, b) for (a, b) in width_pairs):
            return True

        # Trivially False
        if all(set_narrower(a, b) for (a, b) in width_pairs):
            return True

        # Trivially False
        common_lanes = ts1.lanes.intersection(ts2.lanes)
        if len(common_lanes) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        wide = self.tv1.singleton_type()
        narrow = self.tv2.singleton_type()
        return wide.wider_or_equal(narrow)
class SameWidth(TypeConstraint):
    """
    Constraint stating that two types have the same width. E.g. i32x2 has
    the same width as i64x1, i16x4, f32x2, f64, b1x64 etc.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1, self.tv2 = tv1, tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Trivially true
        if self.tv1 == self.tv2:
            return True

        widths1 = self.tv1.get_typeset().widths()
        widths2 = self.tv2.get_typeset().widths()

        # Trivially False: the two typesets share no width at all
        if len(widths1.intersection(widths2)) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        width1 = self.tv1.singleton_type().width()
        width2 = self.tv2.singleton_type().width()
        return width1 == width2
class TypeEnv(object):
    """
    Class encapsulating the necessary book keeping for type inference.

    :attribute type_map: dict holding the equivalence relations between tvs
                         (each key maps to the tv it was made equivalent to)
    :attribute constraints: a list of accumulated TypeConstraint objects.
                            InTypeset constraints are never stored here; see
                            add_constraint().
    :attribute ranks: dictionary recording the (optional) ranks for tvs.
                      'rank' is a partial ordering on TVs based on their
                      origin. See comments in rank() and register().
    :attribute vars: a set containing all known Vars
    :attribute idx: counter used to get fresh ids
    """

    # Rank values, highest first. register() assigns INPUT/INTERMEDIATE/
    # OUTPUT/TEMP to the TVs of Vars; rank() falls back to SINGLETON or
    # INTERNAL for TVs that were never registered.
    RANK_SINGLETON = 5
    RANK_INPUT = 4
    RANK_INTERMEDIATE = 3
    RANK_OUTPUT = 2
    RANK_TEMP = 1
    RANK_INTERNAL = 0

    def __init__(self, arg=None):
        # type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
        """
        Create a type environment, empty by default. When arg is given it is
        unpacked into (type_map, constraints) and both objects are adopted
        as-is (no copy is made).
        """
        self.ranks = {}  # type: Dict[TypeVar, int]
        self.vars = set()  # type: Set[Var]

        if arg is None:
            self.type_map = {}  # type: TypeMap
            self.constraints = []  # type: List[TypeConstraint]
        else:
            self.type_map, self.constraints = arg

        self.idx = 0

    def __getitem__(self, arg):
        # type: (Union[TypeVar, Var]) -> TypeVar
        """
        Lookup the canonical representative for a Var/TypeVar.

        A Var must have been registered first; its TypeVar is then looked up.
        """
        if (isinstance(arg, Var)):
            assert arg in self.vars
            tv = arg.get_typevar()
        else:
            assert (isinstance(arg, TypeVar))
            tv = arg

        # Follow the chain of recorded equivalences to its end.
        while tv in self.type_map:
            tv = self.type_map[tv]

        # For a derived TV, canonicalize the base recursively and re-apply the
        # derivation function on top of the canonical base.
        if tv.is_derived:
            tv = TypeVar.derived(self[tv.base], tv.derived_func)
        return tv

    def equivalent(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        """
        Record that the free tv1 is part of the same equivalence class as
        tv2. The canonical representative of the merged class is tv2's
        canonical representative.

        tv1 must be non-derived and must currently be its own representative.
        """
        assert not tv1.is_derived
        assert self[tv1] == tv1
        # Make sure we don't create cycles
        if tv2.is_derived:
            assert self[tv2.base] != tv1

        self.type_map[tv1] = tv2

    def add_constraint(self, constr):
        # type: (TypeConstraint) -> None
        """
        Add a new constraint. Constraints already present are ignored.
        """
        if (constr in self.constraints):
            return

        # InTypeset constraints can be expressed by constraining the typeset of
        # a variable. No need to add them to self.constraints
        if (isinstance(constr, InTypeset)):
            self[constr.tv].constrain_types_by_ts(constr.ts)
            return

        self.constraints.append(constr)

    def get_uid(self):
        # type: () -> str
        """
        Return the current value of the idx counter as a string, then
        increment the counter.
        """
        r = str(self.idx)
        self.idx += 1
        return r

    def __repr__(self):
        # type: () -> str
        """
        The repr of a TypeEnv is its dot graph; see dot().
        """
        return self.dot()

    def rank(self, tv):
        # type: (TypeVar) -> int
        """
        Get the rank of tv in the partial order. TVs directly associated with a
        Var get their rank from the Var (see register()). Internally generated
        non-derived TVs implicitly get the lowest rank (0). Derived variables
        get their rank from their free typevar. Singletons have the highest
        rank. TVs associated with vars in a source pattern have a higher rank
        than TVs associated with temporary vars.
        """
        # The fallback is decided on tv itself, before a derived tv is
        # replaced by its free typevar below.
        default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
            else TypeEnv.RANK_SINGLETON

        if tv.is_derived:
            tv = tv.free_typevar()

        return self.ranks.get(tv, default_rank)

    def register(self, v):
        # type: (Var) -> None
        """
        Register a new Var v. This computes a rank for the associated TypeVar
        for v, which is used to impose a partial order on type variables.
        """
        self.vars.add(v)

        if v.is_input():
            r = TypeEnv.RANK_INPUT
        elif v.is_intermediate():
            r = TypeEnv.RANK_INTERMEDIATE
        elif v.is_output():
            r = TypeEnv.RANK_OUTPUT
        else:
            assert(v.is_temp())
            r = TypeEnv.RANK_TEMP

        self.ranks[v.get_typevar()] = r

    def free_typevars(self):
        # type: () -> List[TypeVar]
        """
        Get the free typevars in the current type env, sorted by name.
        """
        tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
        tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
        # Filter out None here due to singleton type vars
        return sorted(filter(lambda x: x is not None, tvs),
                      key=lambda x: x.name)

    def normalize(self):
        # type: () -> None
        """
        Normalize by:
            - collapsing any roots that don't correspond to a concrete TV AND
              have a single TV derived from them or equivalent to them

        E.g. if we have a root of the tree that looks like:

          typeof_a   typeof_b
                 \\  /
              typeof_x
                  |
            half_width(1)
                  |
                  1

        we want to collapse the linear path between 1 and typeof_x. The
        resulting graph is:

          typeof_a   typeof_b
                 \\  /
              typeof_x
        """
        source_tvs = set([v.get_typevar() for v in self.vars])

        # children[t] collects the TVs hanging directly below t: TVs derived
        # from t that appear as mapping targets, and TVs mapped onto t.
        children = {}  # type: Dict[TypeVar, Set[TypeVar]]
        for v in self.type_map.values():
            if not v.is_derived:
                continue

            t = v.free_typevar()
            s = children.get(t, set())
            s.add(v)
            children[t] = s

        for (a, b) in self.type_map.items():
            s = children.get(b, set())
            s.add(a)
            children[b] = s

        # Walk down from each root for as long as the current node is not the
        # TV of a registered Var and has exactly one child, dropping the
        # equivalence edges met on the way.
        for r in self.free_typevars():
            while (r not in source_tvs and r in children and
                   len(children[r]) == 1):
                child = list(children[r])[0]
                if child in self.type_map:
                    assert self.type_map[child] == r
                    del self.type_map[child]

                r = child

    def extract(self):
        # type: () -> TypeEnv
        """
        Extract a clean type environment from self, that only mentions
        TVs associated with real variables
        """
        vars_tvs = set([v.get_typevar() for v in self.vars])
        # Keep only the Var TVs that are not their own representative.
        new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}

        new_constraints = []  # type: List[TypeConstraint]
        for constr in self.constraints:
            # Rewrite the constraint in terms of canonical representatives.
            constr = constr.translate(self)

            if constr.is_trivial() or constr in new_constraints:
                continue

            # Sanity: translated constraints should refer to only real vars
            for arg in constr._args():
                if (not isinstance(arg, TypeVar)):
                    continue

                arg_free_tv = arg.free_typevar()
                assert arg_free_tv is None or arg_free_tv in vars_tvs

            new_constraints.append(constr)

        # Sanity: translated typemap should refer to only real vars
        for (k, v) in new_type_map.items():
            assert k in vars_tvs
            assert v.free_typevar() is None or v.free_typevar() in vars_tvs

        t = TypeEnv()
        t.type_map = new_type_map
        t.constraints = new_constraints
        # ranks and vars contain only TVs associated with real vars
        t.ranks = copy(self.ranks)
        t.vars = copy(self.vars)
        return t

    def concrete_typings(self):
        # type: () -> Iterable[VarTyping]
        """
        Return an iterable over all possible concrete typings permitted by this
        TypeEnv.

        This is a generator: every combination of concrete types for the free
        typevars is tried, and combinations violating a constraint are skipped.
        """
        free_tvs = self.free_typevars()
        free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
        for concrete_types in product(*free_tv_iters):
            # Build type substitutions for all free vars
            m = {tv: TypeVar.singleton(typ)
                 for (tv, typ) in zip(free_tvs, concrete_types)}

            concrete_var_map = {v: subst(self[v.get_typevar()], m)
                                for v in self.vars}

            # Check if constraints are satisfied for this typing
            failed = None
            for constr in self.constraints:
                concrete_constr = constr.translate(m)
                if not concrete_constr.eval():
                    failed = concrete_constr
                    break

            if (failed is not None):
                continue

            yield concrete_var_map

    def permits(self, concrete_typing):
        # type: (VarTyping) -> bool
        """
        Return true iff this TypeEnv permits the (possibly partial) concrete
        variable type mapping concrete_typing.
        """
        # Each variable has a concrete type, that is a subset of its inferred
        # typeset.
        for (v, typ) in concrete_typing.items():
            assert typ.singleton_type() is not None
            if not typ.get_typeset().issubset(self[v].get_typeset()):
                return False

        m = {self[v]: typ for (v, typ) in concrete_typing.items()}

        # Constraints involving vars in concrete_typing are satisfied
        for constr in self.constraints:
            try:
                # If the constraint includes only vars in concrete_typing, we
                # can translate it using m. Otherwise we encounter a KeyError
                # and ignore it
                # NOTE(review): with a plain dict, translate() goes through
                # subst(), which returns unmapped TypeVars unchanged instead
                # of raising - confirm where the KeyError is expected to
                # come from.
                constr = constr.translate(m)
                if not constr.eval():
                    return False
            except KeyError:
                pass

        return True

    def dot(self):
        # type: () -> str
        """
        Return a representation of self as a graph in dot format.
            Nodes correspond to TypeVariables.
            Dotted edges correspond to equivalences between TVS
            Solid edges correspond to derivation relations between TVs.
            Dashed edges correspond to equivalence constraints.
        """
        def label(s):
            # type: (TypeVar) -> str
            # Quote the TV's string form so it can be used as a dot node id.
            return "\"" + str(s) + "\""

        # Add all registered TVs (as some of them may be singleton nodes not
        # appearing in the graph
        nodes = set()  # type: Set[TypeVar]
        edges = set()  # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa

        def add_nodes(*args):
            # type: (*TypeVar) -> None
            # Add each TV and, for derived TVs, the whole chain of bases
            # together with a solid edge per derivation step.
            for tv in args:
                nodes.add(tv)
                while (tv.is_derived):
                    nodes.add(tv.base)
                    edges.add((tv, tv.base, "solid", "forward",
                               tv.derived_func))
                    tv = tv.base

        for v in self.vars:
            add_nodes(v.get_typevar())

        for (tv1, tv2) in self.type_map.items():
            # Add all intermediate TVs appearing in edges
            add_nodes(tv1, tv2)
            edges.add((tv1, tv2, "dotted", "forward", None))

        for constr in self.constraints:
            if isinstance(constr, TypesEqual):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
            elif isinstance(constr, WiderOrEq):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
            elif isinstance(constr, SameWidth):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none",
                           "same_width"))
            else:
                assert False, "Can't display constraint {}".format(constr)

        # Roots are non-derived TVs with no outgoing equivalence; only these
        # get their typeset printed next to the node.
        root_nodes = set([x for x in nodes
                          if x not in self.type_map and not x.is_derived])

        r = "digraph {\n"
        for n in nodes:
            r += label(n)
            if n in root_nodes:
                r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
            r += ";\n"

        for (n1, n2, style, direction, elabel) in edges:
            e = label(n1) + "->" + label(n2)
            e += "[style={},dir={}".format(style, direction)

            if elabel is not None:
                e += ",label=\"{}\"".format(elabel)
            e += "];\n"

            r += e
        r += "}"

        return r
# Aliases used only in type comments: the result of a typing step is either
# a TypeEnv or a string carrying the error message.
if TYPE_CHECKING:
    TypingError = str
    TypingOrError = Union[TypeEnv, TypingError]
def get_error(typing_or_err):
    # type: (TypingOrError) -> Optional[TypingError]
    """
    Helper that keeps mypy happy when inspecting the result of typing:
    return the error message if the argument is a string, None otherwise.
    """
    if not isinstance(typing_or_err, str):
        return None

    # The cast only matters to the type checker; at runtime the string is
    # handed back unchanged.
    return cast(TypingError, typing_or_err) if TYPE_CHECKING \
        else typing_or_err
def get_type_env(typing_or_err):
    # type: (TypingOrError) -> TypeEnv
    """
    Helper that keeps mypy happy when inspecting the result of typing:
    assert that the argument is a TypeEnv and return it.
    """
    assert isinstance(typing_or_err, TypeEnv), \
        "Unexpected error: {}".format(typing_or_err)

    # The cast only matters to the type checker; at runtime the environment
    # is handed back unchanged.
    env = cast(TypeEnv, typing_or_err) if TYPE_CHECKING else typing_or_err
    return env
def subst(tv, tv_map):
    # type: (TypeVar, TypeMap) -> TypeVar
    """
    Apply the substitution tv_map to tv.

    A tv that is a key of tv_map is replaced by its image. Otherwise a
    derived tv is rebuilt on top of its substituted base, and a non-derived
    tv is returned as is.
    """
    if tv in tv_map:
        return tv_map[tv]

    if not tv.is_derived:
        return tv

    new_base = subst(tv.base, tv_map)
    return TypeVar.derived(new_base, tv.derived_func)
def normalize_tv(tv):
    # type: (TypeVar) -> TypeVar
    """
    Bring a (possibly derived) TV into normal form with these rules:
        - vector and width derived functions commute
            {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
            {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
        - half/double pairs collapse
            {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
            {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
    """
    vector_funcs = (TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR)
    width_funcs = (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH)
    cancelling_pairs = (
        (TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
        (TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
        (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
        (TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH),
    )

    if not tv.is_derived:
        return tv

    outer_func = tv.derived_func
    base = tv.base

    if base.is_derived:
        inner_func = base.derived_func

        # Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
        # DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
        if outer_func in vector_funcs and inner_func in width_funcs:
            swapped = TypeVar.derived(
                TypeVar.derived(base.base, outer_func), inner_func)
            return normalize_tv(swapped)

        # Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
        # cancel each other. Note: This doesn't hide any over/underflows,
        # since we 1) assert the safety of each TV in the chain upon its
        # creation, and 2) the base typeset is only allowed to shrink.
        if (outer_func, inner_func) in cancelling_pairs:
            return normalize_tv(base.base)

    # Nothing applies at this level: normalize below and re-derive.
    return TypeVar.derived(normalize_tv(base), outer_func)
def constrain_fixpoint(tv1, tv2):
    # type: (TypeVar, TypeVar) -> None
    """
    Constrain the typesets of tv1 and tv2 (either of which may be derived
    from the other) to agree. Because constraining one can change the other
    when they are related by derivation, the first step is repeated until
    tv1's typeset stops changing.
    """
    # Keep constraining tv2 by tv1 for as long as doing so alters tv1.
    tv1_changed = True
    while tv1_changed:
        tv1_before = tv1.get_typeset().copy()
        tv2.constrain_types(tv1)
        tv1_changed = not (tv1.get_typeset() == tv1_before)

    # Constrain the other way round once; by now tv2 must be stable.
    tv2_before = tv2.get_typeset().copy()
    tv1.constrain_types(tv2)
    assert tv2_before == tv2.get_typeset()
def unify(tv1, tv2, typ):
    # type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
    """
    Unify tv1 and tv2 within the type environment typ. Returns the updated
    environment on success, or an error string if unification produced an
    empty typeset.
    """
    # Work on the normalized canonical representatives of both TVs.
    tv1 = normalize_tv(typ[tv1])
    tv2 = normalize_tv(typ[tv2])

    # Already unified
    if tv1 == tv2:
        return typ

    # Make sure tv1 is the one with the lower (or equal) rank.
    if typ.rank(tv2) < typ.rank(tv1):
        return unify(tv2, tv1, typ)

    constrain_fixpoint(tv1, tv2)

    if any(tv.get_typeset().size() == 0 for tv in (tv1, tv2)):
        return "Error: empty type created when unifying {} and {}"\
            .format(tv1, tv2)

    if tv1.is_derived:
        func = tv1.derived_func
        if TypeVar.is_bijection(func):
            # f(base) == tv2 with f invertible: unify base with f^-1(tv2).
            inv_f = TypeVar.inverse_func(func)
            inverted = normalize_tv(TypeVar.derived(tv2, inv_f))
            return unify(tv1.base, inverted, typ)

        # Derived through a non-invertible function: all we can do is
        # remember that both TVs have to be equal.
        typ.add_constraint(TypesEqual(tv1, tv2))
        return typ

    # Free -> Derived(Free)
    typ.equivalent(tv1, tv2)
    return typ
def move_first(l, i):
    # type: (List[T], int) -> List[T]
    """
    Return a new list with the element at index i moved to the front and the
    relative order of all other elements preserved. The input list is not
    modified.

    Negative indices count from the end, as in ordinary list indexing. An
    out-of-range index raises IndexError.
    """
    # Index first so that an out-of-range i (positive or negative) fails
    # here with IndexError, before any slicing.
    head = l[i]

    # Slicing with a negative i would duplicate elements (l[:-1] + l[0:]),
    # so convert it to the equivalent non-negative position.
    if i < 0:
        i += len(l)

    return [head] + l[:i] + l[i+1:]
def ti_def(definition, typ):
    # type: (Def, TypeEnv) -> TypingOrError
    """
    Run type inference on a single Def within the type environment typ and
    return the updated environment, or an error string.

    In outline: every formal type var of the Def's instruction signature is
    replaced by a fresh copy, and each fresh formal tv is then unified with
    the tv of the corresponding actual variable.
    """
    expr = definition.expr
    inst = expr.inst

    # Map each free typevar of the instruction signature to a fresh copy of
    # itself, tagged with a fresh id from the environment.
    free_formal_tvs = inst.all_typevars()
    m = {}
    for formal in free_formal_tvs:
        m[formal] = formal.get_fresh_copy(str(typ.get_uid()))

    # Explicitly bound type vars override the fresh copies with singletons.
    for (position, bound_typ) in enumerate(expr.typevars):
        bound_tv = TypeVar.singleton(bound_typ)
        m[free_formal_tvs[position]] = bound_tv

    # Fresh versions of every typevar in the signature (free and derived
    # alike): results first, then operands.
    result_tvs = [subst(inst.outs[i].typevar, m) for i in inst.value_results]
    operand_tvs = [subst(inst.ins[i].typevar, m) for i in inst.value_opnums]
    fresh_formal_tvs = result_tvs + operand_tvs

    # The actual Vars, in the same results-then-operands order.
    actual_vars = \
        [definition.defs[i] for i in inst.value_results] +\
        [expr.args[i] for i in inst.value_opnums]  # type: List[Expr]

    # Collect the actual TypeVars
    actual_tvs = []
    for v in actual_vars:
        assert(isinstance(v, Var))
        # Tell the TypeEnv that this typevar belongs to variable v, which
        # determines its rank
        typ.register(v)
        actual_tvs.append(v.get_typevar())

    # The control typevar has to be unified before all others.
    if inst.is_polymorphic:
        ctrl_pos = fresh_formal_tvs.index(m[inst.ctrl_typevar])
        fresh_formal_tvs = move_first(fresh_formal_tvs, ctrl_pos)
        actual_tvs = move_first(actual_tvs, ctrl_pos)

    # Unify every actual typevar with its fresh formal counterpart
    for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
        outcome = unify(actual_tv, formal_tv, typ)
        err = get_error(outcome)
        if err:
            return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err

        typ = get_type_env(outcome)

    # Finally add the constraints specific to this instruction
    for constr in inst.constraints:
        typ.add_constraint(constr.translate(m))

    return typ
def ti_rtl(rtl, typ):
    # type: (Rtl, TypeEnv) -> TypingOrError
    """
    Run type inference over every Def of an Rtl, threading the type
    environment through starting from typ. Returns the final environment, or
    an error string prefixed with the index of the failing Def.
    """
    for (line_no, definition) in enumerate(rtl.rtl):
        assert (isinstance(definition, Def))
        outcome = ti_def(definition, typ)
        err = get_error(outcome)  # type: Optional[TypingError]
        if err:
            return "On line {}: ".format(line_no) + err

        typ = get_type_env(outcome)

    return typ
def ti_xform(xform, typ):
    # type: (XForm, TypeEnv) -> TypingOrError
    """
    Perform type inference on an XForm in a starting type env typ. The source
    pattern is typed first, and the destination pattern is then typed in the
    environment resulting from the source. Return an updated type environment
    or an error string saying which pattern failed.
    """
    typ_or_err = ti_rtl(xform.src, typ)
    err = get_error(typ_or_err)  # type: Optional[TypingError]
    if (err):
        return "In src pattern: " + err

    # The dst pattern is inferred on top of what was learned from src.
    typ = get_type_env(typ_or_err)

    typ_or_err = ti_rtl(xform.dst, typ)
    err = get_error(typ_or_err)
    if (err):
        return "In dst pattern: " + err

    return get_type_env(typ_or_err)
|