• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Subclass of ir.Value that supports Python operators."""
2
3# mypy: allow-untyped-defs
4from __future__ import annotations
5
6import onnxscript
7from onnxscript import ir
8
9
class SymbolicTensor(ir.Value):
    """A subclass of ir.Value that supports Python operators.

    Each overloaded operator emits the corresponding ONNX op through the
    opset given at construction time and returns whatever that opset call
    returns. Conversion of non-tensor operands (e.g. Python scalars) is left
    entirely to the opset.

    Reflected variants (``__radd__``, ``__rsub__``, ...) are provided so that
    expressions with a non-SymbolicTensor left operand (``1 - t``, ``2 / t``)
    also build graph nodes, with operand order preserved.
    """

    # Floating-point dtypes for which Mod must be emitted with fmod=1: the
    # ONNX Mod spec does not allow fmod=0 on floating-point inputs.
    _FMOD_DTYPES = frozenset(
        {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }
    )

    def __init__(
        self,
        opset: onnxscript.values.Opset,
        name: str | None = None,
        shape: ir.Shape | None = None,
        type: ir.TypeProtocol | None = None,
        doc_string: str | None = None,
        const_value: ir.TensorProtocol | None = None,
    ):
        """Create a symbolic value bound to ``opset``.

        Args:
            opset: Opset whose op methods (``Add``, ``Mul``, ...) are used to
                build nodes when Python operators are applied to this value.
            name: Optional value name, forwarded to ``ir.Value``.
            shape: Optional shape, forwarded to ``ir.Value``.
            type: Optional type, forwarded to ``ir.Value``.
            doc_string: Optional doc string, forwarded to ``ir.Value``.
            const_value: Optional constant tensor, forwarded to ``ir.Value``.
        """
        super().__init__(
            name=name,
            shape=shape,
            type=type,
            doc_string=doc_string,
            const_value=const_value,
        )
        self._opset = opset

    @property
    def rank(self) -> int | None:
        """Number of dimensions, or None when the shape is unknown."""
        if self.shape is None:
            return None
        return len(self.shape)

    # TODO: Implement indexing

    def _mod(self, dividend, divisor):
        """Emit an ONNX Mod node, choosing ``fmod`` from this tensor's dtype.

        Floating-point dtypes get ``fmod=1`` (C ``fmod`` semantics); all other
        dtypes use the op default. NOTE(review): when the dtype is unknown
        (presumably ``None``), the default (integer) mode is used. Confirm
        callers set a type on float values before using ``%``.
        """
        if self.dtype in self._FMOD_DTYPES:
            return self._opset.Mod(dividend, divisor, fmod=1)
        return self._opset.Mod(dividend, divisor)

    def __mod__(self, other):
        return self._mod(self, other)

    def __rmod__(self, other):
        return self._mod(other, self)

    def __ne__(self, other):
        # Element-wise inequality, expressed as Not(Equal(...)). __eq__ is not
        # overridden here, so `==` keeps whatever semantics ir.Value defines.
        return self._opset.Not(self._opset.Equal(self, other))

    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    # ONNX And/Or/Xor are logical ops; they are only defined for bool tensors.
    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __ror__(self, other):
        return self._opset.Or(other, self)

    def __xor__(self, other):
        return self._opset.Xor(self, other)

    def __rxor__(self, other):
        return self._opset.Xor(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __rmatmul__(self, other):
        return self._opset.MatMul(other, self)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __rpow__(self, other):
        return self._opset.Pow(other, self)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __rtruediv__(self, other):
        return self._opset.Div(other, self)

    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)