• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
5
6import torch
7from executorch.backends.apple.mps.operators.node_visitor import (
8    NodeVisitor,
9    register_node_visitor,
10)
11from executorch.backends.apple.mps.serialization.mps_graph_schema import (
12    MPSAbs,
13    MPSAcos,
14    MPSAcosh,
15    MPSAsin,
16    MPSAsinh,
17    MPSAtan,
18    MPSAtanh,
19    MPSBitwiseNot,
20    MPSCeil,
21    MPSCos,
22    MPSCosh,
23    MPSErf,
24    MPSExp,
25    MPSExp2,
26    MPSFloor,
27    MPSGraph,
28    MPSIsinf,
29    MPSIsnan,
30    MPSLog,
31    MPSLog10,
32    MPSLog2,
33    MPSLogicalNot,
34    MPSNeg,
35    MPSReciprocal,
36    MPSRound,
37    MPSRsqrt,
38    MPSSigmoid,
39    MPSSign,
40    MPSSin,
41    MPSSinh,
42    MPSSqrt,
43    MPSTan,
44    MPSTanh,
45)
46from executorch.exir.dialects._ops import ops as exir_ops
47
48
# Single source of truth for the unary ops lowered by UnaryOpVisitor: maps the
# ATen op name (as it appears in ``aten.<name>.default``) to the MPS graph node
# class emitted for it. Both ``UnaryOpVisitor.target`` and
# ``UnaryOpVisitor.unary_op`` are derived from this table so the registration
# keys and the lowering table cannot drift out of sync.
_UNARY_OP_TO_MPS = {
    "exp": MPSExp,
    "exp2": MPSExp2,
    "reciprocal": MPSReciprocal,
    "sqrt": MPSSqrt,
    "neg": MPSNeg,
    "log": MPSLog,
    "log10": MPSLog10,
    "log2": MPSLog2,
    "erf": MPSErf,
    "floor": MPSFloor,
    "ceil": MPSCeil,
    "rsqrt": MPSRsqrt,
    "sigmoid": MPSSigmoid,
    "sin": MPSSin,
    "sign": MPSSign,
    "cos": MPSCos,
    "tan": MPSTan,
    "abs": MPSAbs,
    "asin": MPSAsin,
    "acos": MPSAcos,
    "atan": MPSAtan,
    "sinh": MPSSinh,
    "cosh": MPSCosh,
    "tanh": MPSTanh,
    "asinh": MPSAsinh,
    "acosh": MPSAcosh,
    "atanh": MPSAtanh,
    "bitwise_not": MPSBitwiseNot,
    "isnan": MPSIsnan,
    "isinf": MPSIsinf,
    "round": MPSRound,
    "logical_not": MPSLogicalNot,
}


@register_node_visitor
class UnaryOpVisitor(NodeVisitor):
    """Lower single-input ATen ops to their matching MPS graph node.

    The node class for each supported op is looked up in ``self.unary_op``.
    The serialized node is produced by ``self.create_unary_node``.
    """

    # Registration keys read by ``register_node_visitor``; same strings, in the
    # same order, as the keys of ``_UNARY_OP_TO_MPS``.
    target = [f"aten.{name}.default" for name in _UNARY_OP_TO_MPS]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Edge-dialect op overload -> MPS node class, resolved from the same
        # names used to build ``target`` above.
        self.unary_op = {
            getattr(exir_ops.edge.aten, name).default: mps_op_cls
            for name, mps_op_cls in _UNARY_OP_TO_MPS.items()
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the MPS node for ``node`` to ``mps_graph.mps_nodes``.

        Args:
            node: FX node whose ``target`` is one of the edge ops in
                ``self.unary_op``.
            mps_graph: graph being built; its ``mps_nodes`` list is mutated
                in place.

        Raises:
            KeyError: if ``node.target`` is not a supported unary op.
        """
        mps_graph.mps_nodes.append(
            self.create_unary_node(node, mps_graph, self.unary_op[node.target])
        )
131