#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSGELU,
    MPSGraph,
    MPSHardTanh,
    MPSLeakyReLU,
    MPSLogSoftmax,
    MPSReLU,
    MPSSoftmax,
)
from executorch.backends.apple.mps.utils.mps_utils import get_scalar_val
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
class HardTanhVisitor(NodeVisitor):
    """Serialize ``aten.hardtanh.default`` nodes into ``MPSHardTanh`` graph nodes."""

    target = "aten.hardtanh.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSHardTanh`` node describing ``node`` to ``mps_graph``.

        Args:
            node: The FX node for the ``hardtanh`` call.
            mps_graph: The serialized graph being built; mutated in place.

        The lower and upper bounds are taken from the node's argument
        positions 1 and 2 through ``get_scalar_val`` (presumably it extracts a
        Python scalar from ``node.args[index]`` -- TODO confirm in mps_utils).
        """
        # NOTE(review): create_unary_node is a NodeVisitor helper defined
        # elsewhere; it appears to build the node with input/output ids wired up.
        mps_node = self.create_unary_node(node, mps_graph, MPSHardTanh)
        mps_node.mpsnode_union.min_value = get_scalar_val(node, 1)
        mps_node.mpsnode_union.max_value = get_scalar_val(node, 2)

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class ReLU_LeakyReLU_GELU_Visitor(NodeVisitor):
    """Serialize ReLU, LeakyReLU and GELU activations into MPS graph nodes."""

    target = ["aten.relu.default", "aten.leaky_relu.default", "aten.gelu.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Maps each supported edge-dialect op to its MPS schema node type.
        self.activation_ops = {
            exir_ops.edge.aten.relu.default: MPSReLU,
            exir_ops.edge.aten.leaky_relu.default: MPSLeakyReLU,
            exir_ops.edge.aten.gelu.default: MPSGELU,
        }

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the activation node matching ``node.target`` to ``mps_graph``.

        Args:
            node: An FX node whose target is one of the ops in ``activation_ops``.
            mps_graph: The serialized graph being built; mutated in place.

        Raises:
            KeyError: If ``node.target`` is not one of the supported activations.

        Optional operator parameters are copied onto the node only when present:
        LeakyReLU's ``negative_slope`` from the second positional argument, and
        GELU's ``approximate`` from the keyword arguments.
        """
        node_type = self.activation_ops[node.target]
        mps_node = self.create_unary_node(node, mps_graph, node_type)

        if node_type is MPSLeakyReLU and len(node.args) == 2:
            # typing.cast is a static-typing hint only; no runtime conversion.
            mps_node.mpsnode_union.negative_slope = cast(float, node.args[1])
        elif node_type is MPSGELU and node.kwargs:
            # ``approximate`` is keyword-only in the aten.gelu schema. Use .get so
            # a kwargs dict that lacks the key is skipped instead of raising
            # KeyError.
            approximate = node.kwargs.get("approximate")
            if approximate is not None:
                mps_node.mpsnode_union.approximate = approximate

        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class Softmax_LogSoftmax_Visitor(NodeVisitor):
    """Serialize ``aten._softmax`` / ``aten._log_softmax`` into MPS graph nodes."""

    target = ["aten._softmax.default", "aten._log_softmax.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSSoftmax`` or ``MPSLogSoftmax`` node to ``mps_graph``.

        Args:
            node: The FX node for the softmax-family call.
            mps_graph: The serialized graph being built; mutated in place.

        Any target other than ``_softmax.default`` is treated as log-softmax.
        The reduction ``dim`` and the ``half_to_float`` flag are read from
        positional arguments 1 and 2.
        """
        node_type = (
            MPSSoftmax
            if node.target == exir_ops.edge.aten._softmax.default
            else MPSLogSoftmax
        )
        mps_node = self.create_unary_node(node, mps_graph, node_type)

        # typing.cast is a static-typing hint only; no runtime conversion.
        mps_node.mpsnode_union.dim = cast(int, node.args[1])
        mps_node.mpsnode_union.half_to_float = cast(bool, node.args[2])

        mps_graph.mps_nodes.append(mps_node)