#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSAdd,
    MPSBitwiseAnd,
    MPSBitwiseOr,
    MPSBitwiseXor,
    MPSDiv,
    MPSEq,
    MPSFmod,
    MPSGe,
    MPSGraph,
    MPSGt,
    MPSLe,
    MPSLt,
    MPSMinimum,
    MPSMul,
    MPSNe,
    MPSPow,
    MPSRemainder,
    MPSSub,
)
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
class BinaryOpVisitor(NodeVisitor):
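    """Lowers elementwise binary arithmetic ops to serialized MPS graph nodes."""
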
    target = [
        # Arithmetic Binary Ops
        "aten.add.Tensor",
        "aten.add.Scalar",
        "aten.sub.Tensor",
        "aten.sub.Scalar",
        "aten.div.Tensor",
        "aten.div.Tensor_mode",
        "aten.mul.Tensor",
        "aten.mul.Scalar",
        "aten.pow.Tensor_Tensor",
        "aten.pow.Tensor_Scalar",
        "aten.floor_divide.default",
        "aten.fmod.Tensor",
        "aten.fmod.Scalar",
        "aten.remainder.Tensor",
        "aten.remainder.Scalar",
        "aten.bitwise_and.Tensor",
        "aten.bitwise_and.Scalar",
        "aten.bitwise_or.Tensor",
        "aten.bitwise_or.Scalar",
        "aten.bitwise_xor.Tensor",
        "aten.bitwise_xor.Scalar",
        "aten.minimum.default",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Map each supported edge-dialect op to its serialized MPS schema node.
        self.op_mapping = {
            exir_ops.edge.aten.add.Tensor: MPSAdd,
            exir_ops.edge.aten.add.Scalar: MPSAdd,
            exir_ops.edge.aten.sub.Tensor: MPSSub,
            exir_ops.edge.aten.sub.Scalar: MPSSub,
            exir_ops.edge.aten.div.Tensor: MPSDiv,
            exir_ops.edge.aten.div.Tensor_mode: MPSDiv,
            exir_ops.edge.aten.mul.Tensor: MPSMul,
            exir_ops.edge.aten.mul.Scalar: MPSMul,
            exir_ops.edge.aten.pow.Tensor_Tensor: MPSPow,
            exir_ops.edge.aten.pow.Tensor_Scalar: MPSPow,
            exir_ops.edge.aten.floor_divide.default: MPSDiv,
            exir_ops.edge.aten.fmod.Tensor: MPSFmod,
            exir_ops.edge.aten.fmod.Scalar: MPSFmod,
            exir_ops.edge.aten.remainder.Tensor: MPSRemainder,
            exir_ops.edge.aten.remainder.Scalar: MPSRemainder,
            exir_ops.edge.aten.bitwise_and.Tensor: MPSBitwiseAnd,
            exir_ops.edge.aten.bitwise_and.Scalar: MPSBitwiseAnd,
            exir_ops.edge.aten.bitwise_or.Tensor: MPSBitwiseOr,
            exir_ops.edge.aten.bitwise_or.Scalar: MPSBitwiseOr,
            exir_ops.edge.aten.bitwise_xor.Tensor: MPSBitwiseXor,
            exir_ops.edge.aten.bitwise_xor.Scalar: MPSBitwiseXor,
            exir_ops.edge.aten.minimum.default: MPSMinimum,
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Build the serialized node from the schema class mapped to this op.
        mps_node = self.create_binary_node(
            node, mps_graph, self.op_mapping[node.target]
        )

        # add/sub accept an optional `alpha` multiplier on the second operand.
        if node.kwargs and "alpha" in node.kwargs and node.kwargs["alpha"] is not None:
            mps_node.mpsnode_union.alpha = node.kwargs["alpha"]

        # div.Tensor_mode carries a rounding mode ("trunc" or "floor").
        if (
            node.kwargs
            and "rounding_mode" in node.kwargs
            and node.kwargs["rounding_mode"] is not None
        ):
            mps_node.mpsnode_union.rounding_mode = node.kwargs["rounding_mode"]

        mps_graph.mps_nodes.append(mps_node)
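
# Note: for aten.add/aten.sub, `alpha` scales the second operand; e.g.
# torch.add(x, y, alpha=2) computes x + 2 * y. That value is carried on the
# serialized node via `mpsnode_union.alpha` above.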


##
## Boolean Binary Ops
##
@register_node_visitor
class ComparisonOpVisitor(NodeVisitor):
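    """Lowers elementwise comparison ops; each emits a boolean-valued MPS node."""
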
    target = [
        "aten.eq.Tensor",
        "aten.ne.Tensor",
        "aten.ge.Tensor",
        "aten.gt.Tensor",
        "aten.le.Tensor",
        "aten.lt.Tensor",
        "aten.eq.Scalar",
        "aten.ne.Scalar",
        "aten.ge.Scalar",
        "aten.gt.Scalar",
        "aten.le.Scalar",
        "aten.lt.Scalar",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        self.comparison_ops = {
            exir_ops.edge.aten.eq.Tensor: MPSEq,
            exir_ops.edge.aten.ne.Tensor: MPSNe,
            exir_ops.edge.aten.ge.Tensor: MPSGe,
            exir_ops.edge.aten.gt.Tensor: MPSGt,
            exir_ops.edge.aten.le.Tensor: MPSLe,
            exir_ops.edge.aten.lt.Tensor: MPSLt,
            exir_ops.edge.aten.eq.Scalar: MPSEq,
            exir_ops.edge.aten.ne.Scalar: MPSNe,
            exir_ops.edge.aten.ge.Scalar: MPSGe,
            exir_ops.edge.aten.gt.Scalar: MPSGt,
            exir_ops.edge.aten.le.Scalar: MPSLe,
            exir_ops.edge.aten.lt.Scalar: MPSLt,
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Comparison ops carry no extra kwargs, so the node is appended directly.
        mps_graph.mps_nodes.append(
            self.create_binary_node(node, mps_graph, self.comparison_ops[node.target])
        )
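

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the backend). The visitors above
# are registered through @register_node_visitor and looked up by operator name
# when the MPS backend serializes an exported graph, so they are exercised
# indirectly through the standard ExecuTorch lowering flow. A minimal flow is
# sketched below; the `to_edge` / `MPSPartitioner` entry points are assumed from
# that flow and their import paths may differ across ExecuTorch versions.
#
#     import torch
#     from executorch.exir import to_edge
#     from executorch.backends.apple.mps.partition import MPSPartitioner
#
#     class AddMulGt(torch.nn.Module):
#         def forward(self, x, y):
#             # (x + y) * y lowers to MPSAdd and MPSMul via BinaryOpVisitor;
#             # x > y lowers to MPSGt via ComparisonOpVisitor.
#             return (x + y) * y, x > y
#
#     exported = torch.export.export(AddMulGt(), (torch.randn(4), torch.randn(4)))
#     edge = to_edge(exported)
#     delegated = edge.to_backend(MPSPartitioner(compile_specs=[]))
# ------------------------------------------------------------------------------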