"""Subclass of ir.Value that supports Python operators."""

# mypy: allow-untyped-defs
from __future__ import annotations

import onnxscript
from onnxscript import ir


class SymbolicTensor(ir.Value):
    """A subclass of ir.Value that supports Python operators.

    Each supported operator emits the corresponding ONNX op through the opset
    passed to the constructor and returns whatever that opset call returns.
    Operands are forwarded to the opset unchanged, so ``other`` may be anything
    the opset call accepts (presumably another value or a Python scalar --
    TODO confirm against the opset implementation).
    """

    # Floating-point dtypes for which ONNX ``Mod`` must be called with
    # ``fmod=1``: the ONNX spec requires fmod=1 for floating-point inputs.
    _FMOD_DTYPES = frozenset(
        {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }
    )

    def __init__(
        self,
        opset: onnxscript.values.Opset,
        name: str | None = None,
        shape: ir.Shape | None = None,
        type: ir.TypeProtocol | None = None,
        doc_string: str | None = None,
        const_value: ir.TensorProtocol | None = None,
    ):
        """Create a symbolic tensor bound to ``opset``.

        Args:
            opset: Opset used to build the nodes emitted by the Python operators.
            name: Forwarded to ``ir.Value``.
            shape: Forwarded to ``ir.Value``.
            type: Forwarded to ``ir.Value``.
            doc_string: Forwarded to ``ir.Value``.
            const_value: Forwarded to ``ir.Value``.
        """
        super().__init__(
            name=name,
            shape=shape,
            type=type,
            doc_string=doc_string,
            const_value=const_value,
        )
        self._opset = opset

    @property
    def rank(self) -> int | None:
        """Number of dimensions, or ``None`` when the shape is unknown."""
        if self.shape is None:
            return None
        return len(self.shape)

    # TODO: Implement indexing

    def _mod(self, lhs, rhs):
        """Emit ``Mod(lhs, rhs)``, choosing ``fmod`` from this tensor's dtype."""
        if self.dtype in self._FMOD_DTYPES:
            # ONNX rejects fmod=0 for floats. With fmod=1 the result follows C
            # ``fmod`` (sign of the dividend), unlike Python's float ``%``.
            return self._opset.Mod(lhs, rhs, fmod=1)
        return self._opset.Mod(lhs, rhs)

    def __mod__(self, other):
        """``self % other`` via ONNX ``Mod``."""
        return self._mod(self, other)

    def __rmod__(self, other):
        """``other % self`` via ONNX ``Mod``."""
        return self._mod(other, self)

    def __ne__(self, other):
        """``self != other`` via ``Not(Equal(self, other))``.

        This returns the opset's result rather than a Python comparison
        outcome. ``__eq__`` is not overridden in this class: defining it here
        would make instances unhashable, because Python sets ``__hash__`` to
        None for a class that defines ``__eq__`` without ``__hash__``.
        ``a == b`` therefore keeps the inherited semantics.
        """
        return self._opset.Not(self._opset.Equal(self, other))

    def __neg__(self):
        """``-self`` via ONNX ``Neg``."""
        return self._opset.Neg(self)

    def __add__(self, other):
        """``self + other`` via ONNX ``Add``."""
        return self._opset.Add(self, other)

    def __radd__(self, other):
        """``other + self`` via ONNX ``Add``."""
        return self._opset.Add(other, self)

    def __and__(self, other):
        """``self & other`` via ONNX ``And``."""
        return self._opset.And(self, other)

    def __rand__(self, other):
        """``other & self`` via ONNX ``And``."""
        return self._opset.And(other, self)

    def __or__(self, other):
        """``self | other`` via ONNX ``Or``."""
        return self._opset.Or(self, other)

    def __ror__(self, other):
        """``other | self`` via ONNX ``Or``."""
        return self._opset.Or(other, self)

    def __mul__(self, other):
        """``self * other`` via ONNX ``Mul``."""
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        """``other * self`` via ONNX ``Mul``."""
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        """``self @ other`` via ONNX ``MatMul``."""
        return self._opset.MatMul(self, other)

    def __rmatmul__(self, other):
        """``other @ self`` via ONNX ``MatMul``."""
        return self._opset.MatMul(other, self)

    def __pow__(self, other):
        """``self ** other`` via ONNX ``Pow``."""
        return self._opset.Pow(self, other)

    def __rpow__(self, other):
        """``other ** self`` via ONNX ``Pow``."""
        return self._opset.Pow(other, self)

    def __sub__(self, other):
        """``self - other`` via ONNX ``Sub``."""
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        """``other - self`` via ONNX ``Sub``."""
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        """``self / other`` via ONNX ``Div``.

        NOTE(review): for integer dtypes ONNX ``Div`` produces an integer
        result, unlike Python's true division -- confirm callers expect this.
        """
        return self._opset.Div(self, other)

    def __rtruediv__(self, other):
        """``other / self`` via ONNX ``Div`` (same integer caveat as ``/``)."""
        return self._opset.Div(other, self)

    def __lt__(self, other):
        """``self < other`` via ONNX ``Less``."""
        return self._opset.Less(self, other)

    def __le__(self, other):
        """``self <= other`` via ONNX ``LessOrEqual``."""
        return self._opset.LessOrEqual(self, other)

    def __ge__(self, other):
        """``self >= other`` via ONNX ``GreaterOrEqual``."""
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        """``self > other`` via ONNX ``Greater``."""
        return self._opset.Greater(self, other)