1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
from typing import * # noqa: F403
import torch
from torch.fx.experimental._constant_symnode import ConstantIntNode
# Public API of this module: only the node class is exported; the
# underscore-prefixed comparison helpers are private.
__all__ = ["NestedIntNode"]
# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp.
# NOTE(review): the comparison rules in `_eq`/`_ge` presumably need to stay in
# sync with that C++ file — confirm when changing either side.
def _eq(lhs: Any, rhs: Any) -> bool:
    """Return True iff ``lhs`` and ``rhs`` are the same nested int.

    Equality requires both operands to be ``NestedIntNode`` instances with
    equal ``t_id`` and equal ``coeff``. Any other pairing (for example a
    nested int against a constant node) is reported as unequal rather than
    raising.
    """
    if not isinstance(lhs, NestedIntNode):
        return False
    if not isinstance(rhs, NestedIntNode):
        return False
    # Only look at the coefficient once the identifiers are known to match.
    if lhs.t_id == rhs.t_id:
        return lhs.coeff == rhs.coeff
    return False
def _ge(lhs: Any, rhs: Any) -> bool:
    """Decide ``lhs >= rhs`` when at least one side is a ``NestedIntNode``.

    Decision rules:

    * both nested: defined only when the ``t_id`` values match, in which case
      the coefficients are compared;
    * nested on the left, another node on the right: true when the right side
      is a constant no greater than 2;
    * another node on the left, nested on the right: false when the left side
      is a constant smaller than 2.

    Raises:
        ValueError: ``"ge: relation is indeterminate"`` when none of the rules
            above settles the comparison, or ``"inputs unsupported"`` when
            neither operand is a ``NestedIntNode``.
    """
    lhs_nested = isinstance(lhs, NestedIntNode)
    rhs_nested = isinstance(rhs, NestedIntNode)

    if not (lhs_nested or rhs_nested):
        raise ValueError("inputs unsupported")

    if lhs_nested and rhs_nested:
        if lhs.t_id == rhs.t_id:
            return lhs.coeff >= rhs.coeff
    elif lhs_nested:
        if rhs.is_constant() and rhs.constant_int() <= 2:
            return True
    elif lhs.is_constant() and lhs.constant_int() < 2:
        # Reached only when the right-hand side alone is nested.
        return False

    # Every undecided case funnels into the same error.
    raise ValueError("ge: relation is indeterminate")
class NestedIntNode:
def __init__(self, t_id: int, coeff: int):
self.t_id = t_id
self.coeff = coeff
def nested_int_coeff(self) -> int:
return self.coeff
def maybe_as_int(self) -> Optional[int]:
return None
def is_int(self) -> bool:
return True
def is_float(self) -> bool:
return False
def is_bool(self) -> bool:
return False
def is_nested_int(self) -> bool:
return True
def clone(self) -> "NestedIntNode":
return self
def _str(self) -> Any:
if self.coeff == 1:
return f"j{self.t_id}"
return f"{self.coeff}*j{self.t_id}"
def str(self) -> Any:
return self._str()
def __str__(self) -> Any:
return self._str()
def __repr__(self) -> Any:
return self._str()
def _graph_repr(self) -> Any:
return self._str()
def mul(self, other: Any) -> "NestedIntNode":
if other.is_constant():
other = other.constant_int()
else:
raise ValueError(f"unsupported: {type(other)}")
return NestedIntNode(self.t_id, self.coeff * other)
def eq(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(_eq(self, other))
def ne(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(not _eq(self, other))
def gt(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(not _ge(other, self))
def lt(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(not _ge(self, other))
def le(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(_ge(other, self))
def ge(self, other: Any) -> Any:
return torch._C._get_constant_bool_symnode(_ge(self, other))
def is_symbolic(self) -> bool:
return False
def nested_int(self) -> int:
return self.t_id
def is_constant(self) -> bool:
return False
def wrap_int(self, num: int) -> ConstantIntNode:
assert type(num) is int
return ConstantIntNode(num)
|