# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Scalar ("prim") ops for the ``executorch_prim`` library.

Each function below is bound to an op schema in the ``executorch_prim``
library through ``bind_pattern_to_op``. The module also maps the Python /
torch callables listed in ``_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS`` to their
``executorch_prim`` counterparts. It exposes :func:`is_sym_op` to recognize
either form, or one of the extra aten ops collected in ``_EXECUTORCH_SYM_OPS``.
"""

import builtins
import math
import operator
from typing import Any, Dict, Set, Union

# necessary to ensure the ops are registered
import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverload
from torch.library import Library


executorch_prims_lib = Library("executorch_prim", "DEF")

# Any symbolic (or plain Python) scalar accepted by the ops below.
_SymScalar = Union[SymBool, SymFloat, SymInt]


@bind_pattern_to_op(executorch_prims_lib, "add.Scalar(Scalar a, Scalar b) -> Scalar")
def add(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a + b``."""
    return a + b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "mul.Scalar(Scalar a, Scalar b) -> Scalar")
def mul(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a * b``."""
    return a * b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sub.Scalar(Scalar a, Scalar b) -> Scalar")
def sub(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a - b``."""
    return a - b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "floordiv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def floordiv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a // b`` (Python floor division)."""
    return a // b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "truediv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def truediv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    """Return ``a / b`` (Python true division)."""
    return a / b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sym_float.Scalar(Scalar a) -> Scalar")
def sym_float(a: _SymScalar) -> _SymScalar:
    """Return ``float(a)``."""
    return float(a)  # pyre-ignore


# TODO: ideally we should return SymBool in the schema, but it seems
# the schema parser does not recognize SymBool yet: P629748075
@bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool")
def gt(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a > b``."""
    return a > b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "lt.Scalar(Scalar a, Scalar b) -> bool")
def lt(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a < b``."""
    return a < b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ge.Scalar(Scalar a, Scalar b) -> bool")
def ge(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a >= b``."""
    return a >= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "le.Scalar(Scalar a, Scalar b) -> bool")
def le(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a <= b``."""
    return a <= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "eq.Scalar(Scalar a, Scalar b) -> bool")
def eq(a: _SymScalar, b: _SymScalar) -> bool:
    """Return ``a == b``."""
    return a == b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "mod.Scalar(SymInt a, SymInt b) -> SymInt")
def mod(a: SymInt, b: SymInt) -> SymInt:
    """Return ``a % b`` with Python ``%`` semantics.

    The previous body, ``SymInt(int(a) % int(b))``, had two problems:

    - ``SymInt``'s constructor expects a symbolic node, not a plain ``int``.
      Wrapping an ``int`` in it produced a malformed ``SymInt``.
    - ``int()`` on a symbolic value specializes it to a concrete number.

    The ``%`` operator works for both plain ints and ``SymInt``. It also
    matches how every other binary op in this module is implemented.
    """
    return a % b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "neg.Scalar(Scalar a) -> Scalar")
def neg(a: _SymScalar) -> _SymScalar:
    """Return ``-a``."""
    return -a  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar")
def ceil(a: _SymScalar) -> _SymScalar:
    """Return ``math.ceil(a)``."""
    return math.ceil(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar")
def builtin_round(a: _SymScalar) -> _SymScalar:
    """Return ``round(a)`` (builtin rounding, no ``ndigits``)."""
    return round(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
def trunc(a: _SymScalar) -> _SymScalar:
    """Return ``math.trunc(a)``."""
    return math.trunc(a)  # pyre-ignore


# Python / torch callables that act on symbolic scalars, mapped to the
# executorch_prim op overload that replaces each of them.
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
    builtins.round: ops.backend.executorch_prim.round.Scalar,
    math.ceil: ops.backend.executorch_prim.ceil.Scalar,
    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
    operator.sub: ops.backend.executorch_prim.sub.Scalar,
    operator.mul: ops.backend.executorch_prim.mul.Scalar,
    operator.add: ops.backend.executorch_prim.add.Scalar,
    operator.floordiv: ops.backend.executorch_prim.floordiv.Scalar,
    operator.truediv: ops.backend.executorch_prim.truediv.Scalar,
    operator.eq: ops.backend.executorch_prim.eq.Scalar,
    operator.gt: ops.backend.executorch_prim.gt.Scalar,
    operator.lt: ops.backend.executorch_prim.lt.Scalar,
    operator.ge: ops.backend.executorch_prim.ge.Scalar,
    operator.le: ops.backend.executorch_prim.le.Scalar,
    operator.mod: ops.backend.executorch_prim.mod.Scalar,
    operator.neg: ops.backend.executorch_prim.neg.Scalar,
    torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
}


# Every executorch_prim op from the mapping above, plus the aten ops that
# also produce or constrain symbolic scalars.
_EXECUTORCH_SYM_OPS: Set[OpOverload] = set(
    _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values()
)
_EXECUTORCH_SYM_OPS.update(
    {
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten._local_scalar_dense.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten.sym_constrain_range.default,
    }
)


def is_sym_op(target: Any) -> bool:
    """Return True if ``target`` is a recognized symbolic scalar op.

    Matches either a Python/torch callable that is a key of
    ``_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS`` or an op overload contained
    in ``_EXECUTORCH_SYM_OPS``.
    """
    return (
        target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS
        or target in _EXECUTORCH_SYM_OPS
    )