1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
"""Subclass of ir.Value that supports Python operators."""
# mypy: allow-untyped-defs
from __future__ import annotations
import onnxscript
from onnxscript import ir
class SymbolicTensor(ir.Value):
    """A subclass of ir.Value that supports Python operators.

    Each overloaded operator forwards to the correspondingly named operator
    on the opset supplied at construction time (for example ``a + b`` calls
    ``opset.Add(a, b)``) and returns whatever that opset call returns.
    """

    def __init__(
        self,
        opset: onnxscript.values.Opset,
        name: str | None = None,
        shape: ir.Shape | None = None,
        type: ir.TypeProtocol | None = None,
        doc_string: str | None = None,
        const_value: ir.TensorProtocol | None = None,
    ):
        """Create a value bound to ``opset``.

        Args:
            opset: Opset whose operators (``Add``, ``Mul``, ...) are called by
                the Python operator overloads of this class.
            name: Forwarded unchanged to ``ir.Value.__init__``.
            shape: Forwarded unchanged to ``ir.Value.__init__``.
            type: Forwarded unchanged to ``ir.Value.__init__``.
            doc_string: Forwarded unchanged to ``ir.Value.__init__``.
            const_value: Forwarded unchanged to ``ir.Value.__init__``.
        """
        super().__init__(
            name=name,
            shape=shape,
            type=type,
            doc_string=doc_string,
            const_value=const_value,
        )
        self._opset = opset

    @property
    def rank(self) -> int | None:
        """Return ``len(self.shape)``, or ``None`` when ``shape`` is ``None``."""
        if self.shape is None:
            return None
        return len(self.shape)

    # TODO: Implement indexing

    # ------------------------------------------------------------------
    # Arithmetic operators
    # ------------------------------------------------------------------

    def __neg__(self):
        """Return ``opset.Neg(self)``."""
        return self._opset.Neg(self)

    def __add__(self, other):
        """Return ``opset.Add(self, other)``."""
        return self._opset.Add(self, other)

    def __radd__(self, other):
        """Return ``opset.Add(other, self)`` (reflected ``+``)."""
        return self._opset.Add(other, self)

    def __sub__(self, other):
        """Return ``opset.Sub(self, other)``."""
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        """Return ``opset.Sub(other, self)`` (reflected ``-``)."""
        return self._opset.Sub(other, self)

    def __mul__(self, other):
        """Return ``opset.Mul(self, other)``."""
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        """Return ``opset.Mul(other, self)`` (reflected ``*``)."""
        return self._opset.Mul(other, self)

    def __truediv__(self, other):
        """Return ``opset.Div(self, other)``."""
        return self._opset.Div(self, other)

    def __rtruediv__(self, other):
        """Return ``opset.Div(other, self)`` (reflected ``/``)."""
        return self._opset.Div(other, self)

    def __matmul__(self, other):
        """Return ``opset.MatMul(self, other)``."""
        return self._opset.MatMul(self, other)

    def __pow__(self, other):
        """Return ``opset.Pow(self, other)``."""
        return self._opset.Pow(self, other)

    def __rpow__(self, other):
        """Return ``opset.Pow(other, self)`` (reflected ``**``)."""
        return self._opset.Pow(other, self)

    def __mod__(self, other):
        """Return ``opset.Mod(self, other)``, with ``fmod=1`` for float dtypes.

        The ONNX ``Mod`` operator requires ``fmod=1`` when the inputs are
        floating point; for any other dtype (including an unknown one) the
        operator's default is used.
        """
        if self.dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    # ------------------------------------------------------------------
    # Logical operators
    # ------------------------------------------------------------------

    def __and__(self, other):
        """Return ``opset.And(self, other)``."""
        return self._opset.And(self, other)

    def __rand__(self, other):
        """Return ``opset.And(other, self)`` (reflected ``&``)."""
        return self._opset.And(other, self)

    # ------------------------------------------------------------------
    # Comparison operators
    # ------------------------------------------------------------------

    # NOTE(review): only ``__ne__`` is overridden; ``__eq__`` is left as
    # inherited. Presumably this keeps values comparable/hashable by the
    # base-class rules -- confirm before adding ``__eq__`` here.
    def __ne__(self, other):
        """Return ``opset.Not(opset.Equal(self, other))``."""
        return self._opset.Not(self._opset.Equal(self, other))

    def __lt__(self, other):
        """Return ``opset.Less(self, other)``."""
        return self._opset.Less(self, other)

    def __le__(self, other):
        """Return ``opset.LessOrEqual(self, other)``."""
        return self._opset.LessOrEqual(self, other)

    def __ge__(self, other):
        """Return ``opset.GreaterOrEqual(self, other)``."""
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        """Return ``opset.Greater(self, other)``."""
        return self._opset.Greater(self, other)
|