• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import sympy
3from sympy.multipledispatch import dispatch
4
5
# Only SingletonInt is part of this module's public API; the ``_eval_is_ge``
# overloads at the bottom of the file are registered through
# ``sympy.multipledispatch.dispatch`` and are intentionally not exported.
__all__ = ["SingletonInt"]
7
8
class SingletonInt(sympy.AtomicExpr):
    """A sympy atom standing for ``coeff`` times the singleton int ``val``.

    The semantics of this class should match that of NestedIntSymNodeImpl in
    c10/core/NestedIntSymNodeImpl.h.

    Supported operations:

    * ``Eq`` (via ``_eval_Eq``): true only against another SingletonInt with
      the same ``val`` and the same ``coeff``.
    * Multiplication by anything that is not a SingletonInt, which scales
      ``coeff`` and returns a new SingletonInt.
    * ``>=`` against sympy Integers and other SingletonInts through the
      ``_eval_is_ge`` dispatch overloads in this module.

    ``+``, ``-``, ``/``, ``//`` and ``%`` with a SingletonInt on the left-hand
    side raise ``NotImplementedError`` instead of building an expression
    tree. NOTE(review): the reflected forms (``__radd__`` etc.) are not
    overridden here and fall back to the inherited ``sympy.AtomicExpr``
    behaviour.

    Args:
        val: identifies the singleton; two SingletonInts are only ordered
            relative to each other when their ``val`` matches.
        coeff: multiplicative coefficient (keyword-only, default 1).
    """

    # This is probably not super important unless we are in multiple dispatch
    # situations with other more exotic Expr types.
    _op_priority = 99999

    def __new__(cls, *args, coeff=None, **kwargs):
        # ``coeff`` is intercepted here so it is not forwarded to sympy's
        # constructor; ``__init__`` is what stores it. Only the positional
        # ``val`` ends up in ``self.args``.
        return super().__new__(cls, *args, **kwargs)

    # The semantics of this class should match that of NestedIntSymNodeImpl in
    # c10/core/NestedIntSymNodeImpl.h
    def __init__(self, val, *, coeff=1):
        self._val = val
        self._coeff = coeff
        super().__init__()

    def _hashable_content(self):
        """Return the tuple sympy derives structural ``==`` and ``hash`` from.

        sympy's default is ``self.args``, which here only holds ``val``. That
        made e.g. ``2 * j`` and ``3 * j`` compare equal (and collide as
        dict/cache keys) even though ``_eval_Eq`` tells them apart, so the
        coefficient is included. It is sympified so that ``j * 2`` and
        ``j * sympy.Integer(2)`` still compare equal.
        """
        return (self._val, sympy.sympify(self._coeff))

    def __getstate__(self):
        # ``_val``/``_coeff`` are instance attributes outside ``self.args``.
        # sympy's pickling rebuilds objects from ``self.args`` plus this
        # state, so without it pickle/copy/deepcopy would drop them (and the
        # copy could no longer be hashed). sympy's ``Basic.__setstate__``
        # restores each entry with ``setattr``.
        return {"_val": self._val, "_coeff": self._coeff}

    # See NOTE [ Inequalities with nested int ]
    def _eval_Eq(self, other):
        """Symbolic ``Eq``: ``sympy.true`` iff ``other`` is an identical SingletonInt.

        Identical means same ``val`` and same ``coeff``; anything else
        (including plain numbers) yields ``sympy.false``.
        """
        if (
            isinstance(other, SingletonInt)
            and other._val == self._val
            and self._coeff == other._coeff
        ):
            return sympy.true
        else:
            return sympy.false

    # This is necessary so that calling expr.free_symbols on exprs that contain
    # this Singleton does not error
    @property
    def free_symbols(self):
        return set()

    def __mul__(self, other):
        """Return a new SingletonInt with the same ``val`` and ``coeff * other``.

        Raises:
            ValueError: if ``other`` is itself a SingletonInt.
        """
        if isinstance(other, SingletonInt):
            raise ValueError(
                "SingletonInt cannot be multiplied by another SingletonInt"
            )
        return SingletonInt(self._val, coeff=self._coeff * other)

    def __rmul__(self, other):
        # ``other * self`` produces exactly what ``self * other`` does: the
        # new coefficient is ``self._coeff * other`` in both directions.
        return self.__mul__(other)

    # Make sure we promptly raise an error instead of falling back to building
    # an expression tree. There are probably more ops, how can we be exhaustive?
    def __add__(self, other):
        raise NotImplementedError("NYI")

    def __sub__(self, other):
        raise NotImplementedError("NYI")

    def __truediv__(self, other):
        raise NotImplementedError("NYI")

    def __floordiv__(self, other):
        raise NotImplementedError("NYI")

    def __mod__(self, other):
        raise NotImplementedError("NYI")
73
74# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
    """Decide ``a >= b`` for a concrete integer ``a`` and a SingletonInt ``b``.

    The overloads in this module treat every SingletonInt as being at least
    2, so any integer below 2 is definitely smaller. For ``a >= 2`` the
    answer depends on the runtime value and the relation is undecidable.

    Returns:
        ``sympy.false`` when ``a < 2``.

    Raises:
        ValueError: when ``a >= 2`` (relation is indeterminate).
    """
    if not (a < 2):
        raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
    return sympy.false
80
81
@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    """Decide ``a >= b`` for a SingletonInt ``a`` and a concrete integer ``b``.

    Mirrors the ``(Integer, SingletonInt)`` overload: a SingletonInt is
    assumed to be at least 2, so it dominates every integer up to and
    including 2. Larger integers cannot be compared symbolically.

    Returns:
        ``sympy.true`` when ``b <= 2``.

    Raises:
        ValueError: when ``b > 2`` (relation is indeterminate).
    """
    if not (b <= 2):
        raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
    return sympy.true
87
88
@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
def _eval_is_ge(a, b):  # noqa: F811
    """Decide ``a >= b`` for two SingletonInts.

    Two multiples of the same singleton (equal ``_val``) are ordered by their
    coefficients. Multiples of different singletons are not comparable.

    Returns:
        ``sympy.true`` if ``a._coeff >= b._coeff``, otherwise ``sympy.false``.

    Raises:
        ValueError: when ``a._val != b._val`` (relation is indeterminate).
    """
    if a._val != b._val:
        raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
    return sympy.true if a._coeff >= b._coeff else sympy.false
97