#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)

from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSBatchNorm,
    MPSGraph,
    MPSLayerNorm,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node, get_scalar_val
from executorch.exir.sym_util import eval_shape


def _define_input_tensors(
    visitor: NodeVisitor,
    node: torch.fx.Node,
    mps_graph: MPSGraph,
    *arg_indices: int,
) -> list:
    """Define the tensor arguments of ``node`` found at ``arg_indices``.

    Each argument is fetched with ``get_input_node`` and handed to
    ``visitor.define_tensor``. The resulting ids come back in the same order
    as ``arg_indices``.

    The list comprehension issues the ``define_tensor`` calls strictly left
    to right. This keeps the call order identical to defining each tensor on
    its own line, in case ``define_tensor`` assigns ids by call order. That
    behavior is not visible in this file.
    """
    return [
        visitor.define_tensor(get_input_node(node, arg_index), mps_graph)
        for arg_index in arg_indices
    ]


@register_node_visitor
class BatchNorm(NodeVisitor):
    """Serialize ``aten._native_batch_norm_legit_no_training`` as ``MPSBatchNorm``."""

    target = "aten._native_batch_norm_legit_no_training.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSBatchNorm`` node describing ``node`` to ``mps_graph``.

        Positional arguments read from ``node``:

        * 0-4: input, weight, bias, mean and variance tensors
        * 5: momentum (scalar)
        * 6: epsilon (scalar)

        The op's three outputs are defined with ``define_tensor_list``.

        NOTE(review): weight/bias are optional in the aten schema. This code
        assumes ``define_tensor`` copes with a ``None`` argument — confirm.
        """
        # Tensor operands are defined in argument order 0..4.
        (
            source_id,
            scale_id,
            shift_id,
            mean_tensor_id,
            var_tensor_id,
        ) = _define_input_tensors(self, node, mps_graph, 0, 1, 2, 3, 4)
        momentum: float = get_scalar_val(node, 5)
        epsilon: float = get_scalar_val(node, 6)

        first_out, second_out, third_out = self.define_tensor_list(node, mps_graph)

        batch_norm = MPSBatchNorm(
            input_id=source_id,
            mean_id=mean_tensor_id,
            var_id=var_tensor_id,
            weight_id=scale_id,
            bias_id=shift_id,
            momentum=momentum,
            epsilon=epsilon,
            output1_id=first_out,
            output2_id=second_out,
            output3_id=third_out,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=batch_norm))


@register_node_visitor
class LayerNorm(NodeVisitor):
    """Serialize ``aten.native_layer_norm`` as ``MPSLayerNorm``."""

    target = "aten.native_layer_norm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSLayerNorm`` node describing ``node`` to ``mps_graph``.

        Positional arguments read from ``node``:

        * 0: input tensor
        * 1: normalized shape
        * 2: weight tensor
        * 3: bias tensor
        * 4: epsilon (scalar)

        The op's three outputs are defined with ``define_tensor_list``.

        NOTE(review): weight/bias are optional in the aten schema. This code
        assumes ``define_tensor`` copes with a ``None`` argument — confirm.
        """
        source_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        # The normalized shape is a list of (possibly symbolic) ints rather
        # than a tensor. It is therefore read directly from node.args and
        # passed through eval_shape instead of define_tensor.
        shape = eval_shape(cast(List[torch.SymInt], node.args[1]))
        scale_id, shift_id = _define_input_tensors(self, node, mps_graph, 2, 3)
        eps: float = get_scalar_val(node, 4)

        first_out, second_out, third_out = self.define_tensor_list(node, mps_graph)

        layer_norm = MPSLayerNorm(
            input1_id=source_id,
            normalized_shape=shape,
            weight_id=scale_id,
            bias_id=shift_id,
            eps=eps,
            output1_id=first_out,
            output2_id=second_out,
            output3_id=third_out,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=layer_norm))