#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSAbs,
    MPSAcos,
    MPSAcosh,
    MPSAsin,
    MPSAsinh,
    MPSAtan,
    MPSAtanh,
    MPSBitwiseNot,
    MPSCeil,
    MPSCos,
    MPSCosh,
    MPSErf,
    MPSExp,
    MPSExp2,
    MPSFloor,
    MPSGraph,
    MPSIsinf,
    MPSIsnan,
    MPSLog,
    MPSLog10,
    MPSLog2,
    MPSLogicalNot,
    MPSNeg,
    MPSReciprocal,
    MPSRound,
    MPSRsqrt,
    MPSSigmoid,
    MPSSign,
    MPSSin,
    MPSSinh,
    MPSSqrt,
    MPSTan,
    MPSTanh,
)
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
class UnaryOpVisitor(NodeVisitor):
    """Lower single-input aten edge ops to MPS graph schema nodes.

    Every supported op is handled the same way: its edge-dialect overload is
    looked up in ``self.unary_op`` to find the matching MPS schema class. That
    class is then handed to ``create_unary_node`` (inherited from
    ``NodeVisitor``). The resulting node is appended to ``mps_graph.mps_nodes``.
    """

    # Single source of truth for the supported ops: (aten op name, MPS schema
    # class). Both ``target`` and ``unary_op`` are derived from this table.
    # Supporting a new op is therefore a one-line change, and the registered
    # targets can no longer drift out of sync with the dispatch table.
    _UNARY_OPS = (
        ("exp", MPSExp),
        ("exp2", MPSExp2),
        ("reciprocal", MPSReciprocal),
        ("sqrt", MPSSqrt),
        ("neg", MPSNeg),
        ("log", MPSLog),
        ("log10", MPSLog10),
        ("log2", MPSLog2),
        ("erf", MPSErf),
        ("floor", MPSFloor),
        ("ceil", MPSCeil),
        ("rsqrt", MPSRsqrt),
        ("sigmoid", MPSSigmoid),
        ("sin", MPSSin),
        ("sign", MPSSign),
        ("cos", MPSCos),
        ("tan", MPSTan),
        ("abs", MPSAbs),
        ("asin", MPSAsin),
        ("acos", MPSAcos),
        ("atan", MPSAtan),
        ("sinh", MPSSinh),
        ("cosh", MPSCosh),
        ("tanh", MPSTanh),
        ("asinh", MPSAsinh),
        ("acosh", MPSAcosh),
        ("atanh", MPSAtanh),
        ("bitwise_not", MPSBitwiseNot),
        ("isnan", MPSIsnan),
        ("isinf", MPSIsinf),
        ("round", MPSRound),
        ("logical_not", MPSLogicalNot),
    )

    # Op-name strings this visitor is registered for, e.g. "aten.exp.default".
    # Only the outermost iterable of a class-body comprehension is evaluated in
    # class scope, and that is the only class-level name referenced here.
    target = [f"aten.{op_name}.default" for op_name, _ in _UNARY_OPS]

    def __init__(self, *args) -> None:
        """Forward ``args`` to ``NodeVisitor`` and build the dispatch table."""
        super().__init__(*args)
        # Maps each edge-dialect op overload (the value seen as ``node.target``
        # in ``define_node``) to the MPS schema class used to serialize it.
        # ``getattr(exir_ops.edge.aten, "exp").default`` is the same lookup as
        # ``exir_ops.edge.aten.exp.default``.
        self.unary_op = {
            getattr(exir_ops.edge.aten, op_name).default: schema
            for op_name, schema in self._UNARY_OPS
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` and append the result to ``mps_graph.mps_nodes``.

        Args:
            node: FX node whose ``target`` is one of the edge ops listed in
                ``_UNARY_OPS``.
            mps_graph: Graph being built; its ``mps_nodes`` list is mutated.

        Raises:
            KeyError: if ``node.target`` is not a supported unary edge op.
        """
        # NOTE(review): create_unary_node is defined on NodeVisitor (not visible
        # here); presumably it wires the node's single input and output into
        # the given schema class. Confirm against node_visitor.py.
        schema = self.unary_op[node.target]
        mps_graph.mps_nodes.append(
            self.create_unary_node(node, mps_graph, schema)
        )