# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Registry of ``executorch_prim`` scalar ops.

Defines the ``executorch_prim`` torch library, binds one op per Python-level
scalar operator (``add``, ``mul``, ``gt``, ``ceil``, ...), and exposes tables
mapping Python operators to those ops plus a helper, ``is_sym_op``, that
recognizes any of them.
"""

import builtins
import math
import operator
from typing import Any, Dict, Set, Union

# necessary to ensure the ops are registered
import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverload
from torch.library import Library

executorch_prims_lib = Library("executorch_prim", "DEF")

# Any symbolic scalar type the ops below accept.
_SymScalar = Union[SymBool, SymFloat, SymInt]


@bind_pattern_to_op(executorch_prims_lib, "add.Scalar(Scalar a, Scalar b) -> Scalar")
def add(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a + b``."""
    return a + b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "mul.Scalar(Scalar a, Scalar b) -> Scalar")
def mul(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a * b``."""
    return a * b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sub.Scalar(Scalar a, Scalar b) -> Scalar")
def sub(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a - b``."""
    return a - b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "floordiv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def floordiv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a // b``."""
    return a // b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "truediv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def truediv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a / b``."""
    return a / b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sym_float.Scalar(Scalar a) -> Scalar")
def sym_float(a: _SymScalar) -> _SymScalar:
    """Return ``float(a)``."""
    return float(a)  # pyre-ignore


# TODO: ideally we should return SymBool in the schema, but it seems
# the schema parser does not recognize SymBool yet: P629748075
@bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool")
def gt(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a > b``."""
    return a > b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "lt.Scalar(Scalar a, Scalar b) -> bool")
def lt(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a < b``."""
    return a < b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ge.Scalar(Scalar a, Scalar b) -> bool")
def ge(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a >= b``."""
    return a >= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "le.Scalar(Scalar a, Scalar b) -> bool")
def le(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a <= b``."""
    return a <= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "eq.Scalar(Scalar a, Scalar b) -> bool")
def eq(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a == b``."""
    return a == b


@bind_pattern_to_op(executorch_prims_lib, "mod.Scalar(SymInt a, SymInt b) -> SymInt")
def mod(a: SymInt, b: SymInt) -> SymInt:
    """Return ``a % b``.

    ``%`` works for both concrete ints and SymInts. The previous
    ``SymInt(int(a) % int(b))`` wrapped a plain ``int`` as a SymInt's
    ``node`` (which must be a SymNode), yielding a malformed object, and
    ``int(...)`` on a symbolic input forces it to specialize.
    """
    return a % b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "neg.Scalar(Scalar a) -> Scalar")
def neg(a: _SymScalar) -> _SymScalar:
    """Return ``-a``."""
    return -a  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar")
def ceil(a: _SymScalar) -> _SymScalar:
    """Return ``math.ceil(a)``."""
    return math.ceil(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar")
def builtin_round(a: _SymScalar) -> _SymScalar:
    """Return ``round(a)`` (named to avoid shadowing the ``round`` builtin)."""
    return round(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
def trunc(a: _SymScalar) -> _SymScalar:
    """Return ``math.trunc(a)``."""
    return math.trunc(a)  # pyre-ignore


# Maps each Python-level scalar operator to the executorch_prim op that
# implements it.
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
    builtins.round: ops.backend.executorch_prim.round.Scalar,
    math.ceil: ops.backend.executorch_prim.ceil.Scalar,
    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
    operator.sub: ops.backend.executorch_prim.sub.Scalar,
    operator.mul: ops.backend.executorch_prim.mul.Scalar,
    operator.add: ops.backend.executorch_prim.add.Scalar,
    operator.floordiv: ops.backend.executorch_prim.floordiv.Scalar,
    operator.truediv: ops.backend.executorch_prim.truediv.Scalar,
    operator.eq: ops.backend.executorch_prim.eq.Scalar,
    operator.gt: ops.backend.executorch_prim.gt.Scalar,
    operator.lt: ops.backend.executorch_prim.lt.Scalar,
    operator.ge: ops.backend.executorch_prim.ge.Scalar,
    operator.le: ops.backend.executorch_prim.le.Scalar,
    operator.mod: ops.backend.executorch_prim.mod.Scalar,
    operator.neg: ops.backend.executorch_prim.neg.Scalar,
    torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
}

# Every executorch_prim op above, plus a fixed set of aten ops that are also
# treated as sym ops by is_sym_op.
_EXECUTORCH_SYM_OPS: Set[OpOverload] = set(
    _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values()
)
_EXECUTORCH_SYM_OPS.update(
    {
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten._local_scalar_dense.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten.sym_constrain_range.default,
    }
)


def is_sym_op(target) -> bool:
    """Return True if ``target`` is a recognized symbolic-scalar op.

    That is, ``target`` is either a Python operator that has an
    executorch_prim equivalent, or one of the ops in ``_EXECUTORCH_SYM_OPS``.
    """
    return (
        target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS
        or target in _EXECUTORCH_SYM_OPS
    )