# mypy: allow-untyped-defs
import sympy
from sympy.multipledispatch import dispatch


__all__ = ["SingletonInt"]


class SingletonInt(sympy.AtomicExpr):
    # This is probably not super important unless we are in multiple dispatch
    # situations with other more exotic Expr types.
    _op_priority = 99999

    def __new__(cls, *args, coeff=None, **kwargs):
        instance = super().__new__(cls, *args, **kwargs)
        return instance

    # The semantics of this class should match that of NestedIntSymNodeImpl in
    # c10/core/NestedIntSymNodeImpl.h
    def __init__(self, val, *, coeff=1):
        self._val = val
        self._coeff = coeff
        super().__init__()

    # See NOTE [ Inequalities with nested int ]
    def _eval_Eq(self, other):
        if (
            isinstance(other, SingletonInt)
            and other._val == self._val
            and self._coeff == other._coeff
        ):
            return sympy.true
        else:
            return sympy.false

    # This is necessary so that calling expr.free_symbols on exprs that contain
    # this Singleton does not error
    @property
    def free_symbols(self):
        return set()

    def __mul__(self, other):
        if isinstance(other, SingletonInt):
            raise ValueError(
                "SingletonInt cannot be multiplied by another SingletonInt"
            )
        return SingletonInt(self._val, coeff=self._coeff * other)

    def __rmul__(self, other):
        if isinstance(other, SingletonInt):
            raise ValueError(
                "SingletonInt cannot be multiplied by another SingletonInt"
            )
        return SingletonInt(self._val, coeff=self._coeff * other)

    # Make sure we promptly raise an error instead of falling back to building
    # an expression tree. There are probably more ops, how can we be exhaustive?
    def __add__(self, other):
        raise NotImplementedError("NYI")

    def __sub__(self, other):
        raise NotImplementedError("NYI")

    def __truediv__(self, other):
        raise NotImplementedError("NYI")

    def __floordiv__(self, other):
        raise NotImplementedError("NYI")

    def __mod__(self, other):
        raise NotImplementedError("NYI")


# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
    if a < 2:
        return sympy.false
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")


@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    if b <= 2:
        return sympy.true
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")


@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    if a._val == b._val:
        if a._coeff >= b._coeff:
            return sympy.true
        else:
            return sympy.false
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")