#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSAdd,
    MPSBitwiseAnd,
    MPSBitwiseOr,
    MPSBitwiseXor,
    MPSDiv,
    MPSEq,
    MPSFmod,
    MPSGe,
    MPSGraph,
    MPSGt,
    MPSLe,
    MPSLt,
    MPSMinimum,
    MPSMul,
    MPSNe,
    MPSPow,
    MPSRemainder,
    MPSSub,
)
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
class BinaryOpVisitor(NodeVisitor):
    """Lower arithmetic, bitwise and ``minimum`` binary edge ops to MPS nodes.

    Each supported edge-dialect op is mapped to an MPS graph schema node class
    in ``op_mapping``. The optional ``alpha`` and ``rounding_mode`` arguments
    are copied onto the serialized node when present.
    """

    target = [
        # Arithmetic Binary Ops
        "aten.add.Tensor",
        "aten.add.Scalar",
        "aten.sub.Tensor",
        "aten.sub.Scalar",
        "aten.div.Tensor",
        "aten.div.Tensor_mode",
        "aten.mul.Tensor",
        "aten.mul.Scalar",
        "aten.pow.Tensor_Tensor",
        "aten.pow.Tensor_Scalar",
        "aten.floor_divide.default",
        "aten.fmod.Tensor",
        "aten.fmod.Scalar",
        "aten.remainder.Tensor",
        "aten.remainder.Scalar",
        "aten.bitwise_and.Tensor",
        "aten.bitwise_and.Scalar",
        "aten.bitwise_or.Tensor",
        "aten.bitwise_or.Scalar",
        "aten.bitwise_xor.Tensor",
        "aten.bitwise_xor.Scalar",
        "aten.minimum.default",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Edge-dialect op overload -> MPS graph schema node class.
        self.op_mapping = {
            exir_ops.edge.aten.add.Tensor: MPSAdd,
            exir_ops.edge.aten.add.Scalar: MPSAdd,
            exir_ops.edge.aten.sub.Tensor: MPSSub,
            exir_ops.edge.aten.sub.Scalar: MPSSub,
            exir_ops.edge.aten.div.Tensor: MPSDiv,
            exir_ops.edge.aten.div.Tensor_mode: MPSDiv,
            exir_ops.edge.aten.mul.Tensor: MPSMul,
            exir_ops.edge.aten.mul.Scalar: MPSMul,
            exir_ops.edge.aten.pow.Tensor_Tensor: MPSPow,
            exir_ops.edge.aten.pow.Tensor_Scalar: MPSPow,
            # floor_divide reuses MPSDiv; define_node sets rounding_mode="floor".
            exir_ops.edge.aten.floor_divide.default: MPSDiv,
            exir_ops.edge.aten.fmod.Tensor: MPSFmod,
            exir_ops.edge.aten.fmod.Scalar: MPSFmod,
            exir_ops.edge.aten.remainder.Tensor: MPSRemainder,
            exir_ops.edge.aten.remainder.Scalar: MPSRemainder,
            exir_ops.edge.aten.bitwise_and.Tensor: MPSBitwiseAnd,
            exir_ops.edge.aten.bitwise_and.Scalar: MPSBitwiseAnd,
            exir_ops.edge.aten.bitwise_or.Tensor: MPSBitwiseOr,
            exir_ops.edge.aten.bitwise_or.Scalar: MPSBitwiseOr,
            exir_ops.edge.aten.bitwise_xor.Tensor: MPSBitwiseXor,
            exir_ops.edge.aten.bitwise_xor.Scalar: MPSBitwiseXor,
            exir_ops.edge.aten.minimum.default: MPSMinimum,
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` as an MPS binary node and append it to the graph.

        Args:
            node: Edge-dialect FX node whose ``target`` is a key of
                ``self.op_mapping``.
            mps_graph: Graph under construction; the new node is appended to
                ``mps_graph.mps_nodes``.
        """
        mps_node = self.create_binary_node(
            node, mps_graph, self.op_mapping[node.target]
        )

        # PyTorch's optional ``alpha`` multiplier (add/sub overloads).
        alpha = node.kwargs.get("alpha")
        if alpha is not None:
            mps_node.mpsnode_union.alpha = alpha

        if node.target == exir_ops.edge.aten.floor_divide.default:
            # floor_divide is lowered to MPSDiv, the same node that div.Tensor
            # (true division) produces, so the floor semantics must be made
            # explicit: PyTorch defines floor_divide as div(rounding_mode="floor").
            rounding_mode = "floor"
        else:
            # div.Tensor_mode carries "trunc"/"floor" or None (true division).
            rounding_mode = node.kwargs.get("rounding_mode")
        if rounding_mode is not None:
            mps_node.mpsnode_union.rounding_mode = rounding_mode

        mps_graph.mps_nodes.append(mps_node)


##
## Boolean Binary Ops
##
@register_node_visitor
class ComparasionOpVisitor(NodeVisitor):
    """Lower elementwise comparison edge ops (eq/ne/ge/gt/le/lt) to MPS nodes.

    Both the ``Tensor`` and ``Scalar`` overloads of each comparison map to the
    same MPS graph schema node class.
    """

    target = [
        "aten.eq.Tensor",
        "aten.ne.Tensor",
        "aten.ge.Tensor",
        "aten.gt.Tensor",
        "aten.le.Tensor",
        "aten.lt.Tensor",
        "aten.eq.Scalar",
        "aten.ne.Scalar",
        "aten.ge.Scalar",
        "aten.gt.Scalar",
        "aten.le.Scalar",
        "aten.lt.Scalar",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Edge-dialect comparison overload -> MPS graph schema node class.
        self.comparison_ops = {
            exir_ops.edge.aten.eq.Tensor: MPSEq,
            exir_ops.edge.aten.ne.Tensor: MPSNe,
            exir_ops.edge.aten.ge.Tensor: MPSGe,
            exir_ops.edge.aten.gt.Tensor: MPSGt,
            exir_ops.edge.aten.le.Tensor: MPSLe,
            exir_ops.edge.aten.lt.Tensor: MPSLt,
            exir_ops.edge.aten.eq.Scalar: MPSEq,
            exir_ops.edge.aten.ne.Scalar: MPSNe,
            exir_ops.edge.aten.ge.Scalar: MPSGe,
            exir_ops.edge.aten.gt.Scalar: MPSGt,
            exir_ops.edge.aten.le.Scalar: MPSLe,
            exir_ops.edge.aten.lt.Scalar: MPSLt,
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize comparison ``node`` and append it to ``mps_graph.mps_nodes``.

        Args:
            node: Edge-dialect FX node whose ``target`` is a key of
                ``self.comparison_ops``.
            mps_graph: Graph under construction.
        """
        node_cls = self.comparison_ops[node.target]
        mps_graph.mps_nodes.append(
            self.create_binary_node(node, mps_graph, node_cls)
        )