# mypy: allow-untyped-defs
import sympy
from sympy.multipledispatch import dispatch


__all__ = ["SingletonInt"]


class SingletonInt(sympy.AtomicExpr):
    """Atomic sympy expression standing for a nested int scaled by ``coeff``.

    ``val`` identifies the nested int; ``coeff`` is the integer multiplier
    accumulated through ``__mul__`` / ``__rmul__``. Only ``val`` is forwarded
    to sympy as ``args``; ``coeff`` lives in a plain attribute, so it is
    folded into ``_hashable_content`` and ``__getstate__`` explicitly.
    """

    # This is probably not super important unless we are in multiple dispatch
    # situations with other more exotic Expr types.
    _op_priority = 99999

    def __new__(cls, *args, coeff=None, **kwargs):
        # ``coeff`` is accepted (and dropped) here only so it is not forwarded
        # to sympy's constructor; ``__init__`` is what stores it.
        instance = super().__new__(cls, *args, **kwargs)
        return instance

    # The semantics of this class should match that of NestedIntSymNodeImpl in
    # c10/core/NestedIntSymNodeImpl.h
    def __init__(self, val, *, coeff=1):
        self._val = val
        self._coeff = coeff
        super().__init__()

    def _hashable_content(self):
        """Return the tuple sympy uses for structural ``==`` and ``hash``.

        sympy's default is ``self.args``, which here contains only ``val``;
        that made nested ints with different coefficients compare equal and
        hash alike, contradicting ``_eval_Eq``. Include ``coeff`` as well.
        """
        return (self._val, self._coeff)

    def __getstate__(self):
        """Return the attributes that ``__init__`` would set.

        ``copy``/``pickle`` rebuild the object from ``args`` through
        ``__new__`` without running ``__init__``; returning this state lets
        ``_val`` and ``_coeff`` be restored on the new instance.
        """
        return {"_val": self._val, "_coeff": self._coeff}

    # See NOTE [ Inequalities with nested int ]
    def _eval_Eq(self, other):
        """Return ``sympy.true`` iff ``other`` is a SingletonInt with the same
        ``val`` and ``coeff``; ``sympy.false`` otherwise."""
        if (
            isinstance(other, SingletonInt)
            and other._val == self._val
            and self._coeff == other._coeff
        ):
            return sympy.true
        else:
            return sympy.false

    # This is necessary so that calling expr.free_symbols on exprs that contain
    # this Singleton does not error
    @property
    def free_symbols(self):
        return set()

    def __mul__(self, other):
        """Scale the coefficient: ``self * other``.

        Raises:
            ValueError: if ``other`` is itself a SingletonInt.
        """
        if isinstance(other, SingletonInt):
            raise ValueError(
                "SingletonInt cannot be multiplied by another SingletonInt"
            )
        return SingletonInt(self._val, coeff=self._coeff * other)

    def __rmul__(self, other):
        """Scale the coefficient: ``other * self``.

        Raises:
            ValueError: if ``other`` is itself a SingletonInt.
        """
        if isinstance(other, SingletonInt):
            raise ValueError(
                "SingletonInt cannot be multiplied by another SingletonInt"
            )
        return SingletonInt(self._val, coeff=self._coeff * other)

    # Make sure we promptly raise an error instead of falling back to building
    # an expression tree. There are probably more ops, how can we be exhaustive?
    # NOTE(review): only the forward operators are overridden; reflected ones
    # (e.g. ``__radd__``) are still inherited from sympy.AtomicExpr.
    def __add__(self, other):
        raise NotImplementedError("NYI")

    def __sub__(self, other):
        raise NotImplementedError("NYI")

    def __truediv__(self, other):
        raise NotImplementedError("NYI")

    def __floordiv__(self, other):
        raise NotImplementedError("NYI")

    def __mod__(self, other):
        raise NotImplementedError("NYI")


# Overloads of ``_eval_is_ge`` registered through sympy.multipledispatch
# (presumably consulted by sympy's ``>=`` evaluation via its shared dispatch
# namespace -- confirm). Together the two Integer rules treat a nested int as
# being >= 2, ignoring its coeff.
# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
    # Integer >= nested int: definitely false below 2, otherwise unknown.
    if a < 2:
        return sympy.false
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")


@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    # nested int >= Integer: definitely true up to 2, otherwise unknown.
    if b <= 2:
        return sympy.true
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")


@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    # Same nested int: compare coefficients; different nested ints: unknown.
    if a._val == b._val:
        if a._coeff >= b._coeff:
            return sympy.true
        else:
            return sympy.false
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")