# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

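# This module registers scalar symbolic ops (add, mul, sub, division, comparisons,
# rounding, etc.) as "Scalar" overloads in the "executorch_prim" library, and maps
# the corresponding Python/torch symbolic ops to those overloads (see
# _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS and is_sym_op below).
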
import builtins
import math
import operator
from typing import Any, Dict, Set, Union

# necessary to ensure the ops are registered
import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverload
from torch.library import Library

# pyre-unsafe


executorch_prims_lib = Library("executorch_prim", "DEF")

_SymScalar = Union[SymBool, SymFloat, SymInt]


@bind_pattern_to_op(executorch_prims_lib, "add.Scalar(Scalar a, Scalar b) -> Scalar")
def add(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    return a + b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "mul.Scalar(Scalar a, Scalar b) -> Scalar")
def mul(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    return a * b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sub.Scalar(Scalar a, Scalar b) -> Scalar")
def sub(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    return a - b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "floordiv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def floordiv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    return a // b  # pyre-ignore


@bind_pattern_to_op(
    executorch_prims_lib, "truediv.Scalar(Scalar a, Scalar b) -> Scalar"
)
def truediv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
    return a / b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sym_float.Scalar(Scalar a) -> Scalar")
def sym_float(a: _SymScalar) -> _SymScalar:
    return float(a)  # pyre-ignore


# TODO: ideally we should return SymBool in the schema, but it seems
# the schema parser does not recognize SymBool yet: P629748075
@bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool")
def gt(a: _SymScalar, b: _SymScalar) -> bool:
    return a > b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "lt.Scalar(Scalar a, Scalar b) -> bool")
def lt(a: _SymScalar, b: _SymScalar) -> bool:
    return a < b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ge.Scalar(Scalar a, Scalar b) -> bool")
def ge(a: _SymScalar, b: _SymScalar) -> bool:
    return a >= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "le.Scalar(Scalar a, Scalar b) -> bool")
def le(a: _SymScalar, b: _SymScalar) -> bool:
    return a <= b  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "eq.Scalar(Scalar a, Scalar b) -> bool")
def eq(a: _SymScalar, b: _SymScalar) -> bool:
    return a == b


@bind_pattern_to_op(executorch_prims_lib, "mod.Scalar(SymInt a, SymInt b) -> SymInt")
def mod(a: SymInt, b: SymInt) -> SymInt:
    return SymInt(int(a) % int(b))


@bind_pattern_to_op(executorch_prims_lib, "neg.Scalar(Scalar a) -> Scalar")
def neg(a: _SymScalar) -> _SymScalar:
    return -a  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar")
def ceil(a: _SymScalar) -> _SymScalar:
    return math.ceil(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar")
def builtin_round(a: _SymScalar) -> _SymScalar:
    return round(a)  # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
def trunc(a: _SymScalar) -> _SymScalar:
    return math.trunc(a)  # pyre-ignore


_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
    builtins.round: ops.backend.executorch_prim.round.Scalar,
    math.ceil: ops.backend.executorch_prim.ceil.Scalar,
    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
    operator.sub: ops.backend.executorch_prim.sub.Scalar,
    operator.mul: ops.backend.executorch_prim.mul.Scalar,
    operator.add: ops.backend.executorch_prim.add.Scalar,
    operator.floordiv: ops.backend.executorch_prim.floordiv.Scalar,
    operator.truediv: ops.backend.executorch_prim.truediv.Scalar,
    operator.eq: ops.backend.executorch_prim.eq.Scalar,
    operator.gt: ops.backend.executorch_prim.gt.Scalar,
    operator.lt: ops.backend.executorch_prim.lt.Scalar,
    operator.ge: ops.backend.executorch_prim.ge.Scalar,
    operator.le: ops.backend.executorch_prim.le.Scalar,
    operator.mod: ops.backend.executorch_prim.mod.Scalar,
    operator.neg: ops.backend.executorch_prim.neg.Scalar,
    torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
}
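
# Illustrative sketch (not part of the registry): a graph pass that encounters one
# of these Python symbolic ops as a node target could look up its executorch_prim
# replacement through this mapping, e.g.
#
#     replacement = _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS[operator.add]
#     # replacement is ops.backend.executorch_prim.add.Scalar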


_EXECUTORCH_SYM_OPS: Set[OpOverload] = set(
    _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values()
)
_EXECUTORCH_SYM_OPS.update(
    {
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten._local_scalar_dense.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten.sym_constrain_range.default,
    }
)


def is_sym_op(target) -> bool:
    return (
        target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.keys()
        or target in _EXECUTORCH_SYM_OPS
    )
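
# Usage sketch (hypothetical call sites; results follow directly from the
# definitions above):
#
#     is_sym_op(operator.add)                  # True: key in the Python-op mapping
#     is_sym_op(torch.ops.aten.sym_size.int)   # True: member of _EXECUTORCH_SYM_OPS
#     is_sym_op(torch.ops.aten.add.Tensor)     # False: a regular tensor op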