• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast, List
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13
14from executorch.backends.apple.mps.serialization.mps_graph_schema import (
15    MPSBatchNorm,
16    MPSGraph,
17    MPSLayerNorm,
18    MPSNode,
19)
20from executorch.backends.apple.mps.utils.mps_utils import get_input_node, get_scalar_val
21from executorch.exir.sym_util import eval_shape
22
23
@register_node_visitor
class BatchNorm(NodeVisitor):
    """Serialize ``aten._native_batch_norm_legit_no_training.default`` nodes.

    Each matching FX node is lowered into a single ``MPSBatchNorm`` payload
    wrapped in an ``MPSNode`` and appended to the MPS graph.
    """

    target = "aten._native_batch_norm_legit_no_training.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSBatchNorm`` node describing ``node`` to ``mps_graph``.

        Positional args 0-4 of ``node`` are passed through ``define_tensor``
        and bound, in that order, to the input, weight, bias, mean and var
        tensor ids. Args 5 and 6 are read with ``get_scalar_val`` as momentum
        and epsilon. ``define_tensor_list`` must yield exactly three output
        ids; any other count raises ``ValueError`` on unpacking.
        """
        # Define tensors strictly in argument order 0..4. NOTE(review):
        # define_tensor presumably assigns ids / serializes as a side effect,
        # so do not reorder these calls without confirming.
        x_id, w_id, b_id, running_mean_id, running_var_id = [
            self.define_tensor(get_input_node(node, arg_idx), mps_graph)
            for arg_idx in range(5)
        ]
        momentum: float = get_scalar_val(node, 5)
        epsilon: float = get_scalar_val(node, 6)

        out0_id, out1_id, out2_id = self.define_tensor_list(node, mps_graph)

        payload = MPSBatchNorm(
            input_id=x_id,
            mean_id=running_mean_id,
            var_id=running_var_id,
            weight_id=w_id,
            bias_id=b_id,
            momentum=momentum,
            epsilon=epsilon,
            output1_id=out0_id,
            output2_id=out1_id,
            output3_id=out2_id,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=payload))
62
63
@register_node_visitor
class LayerNorm(NodeVisitor):
    """Serialize ``aten.native_layer_norm.default`` nodes.

    Each matching FX node is lowered into a single ``MPSLayerNorm`` payload
    wrapped in an ``MPSNode`` and appended to the MPS graph.
    """

    target = "aten.native_layer_norm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSLayerNorm`` node describing ``node`` to ``mps_graph``.

        Arg 0 of ``node`` is the input tensor. Arg 1 is the normalized shape,
        passed through ``eval_shape``. Args 2 and 3 are the weight and bias
        tensors, and arg 4 is read with ``get_scalar_val`` as epsilon.
        ``define_tensor_list`` must yield exactly three output ids; any other
        count raises ``ValueError`` on unpacking.
        """
        # Calls below follow the original evaluation order (input, shape,
        # weight, bias, eps, outputs). NOTE(review): define_tensor presumably
        # has graph side effects, so keep this order unless confirmed safe.
        x_id = self.define_tensor(get_input_node(node, 0), mps_graph)

        # node.args[1] may hold SymInts; eval_shape presumably resolves them to
        # concrete sizes. TODO confirm against exir.sym_util.
        shape_arg = cast(List[torch.SymInt], node.args[1])
        normalized_shape = eval_shape(shape_arg)

        w_id, b_id = [
            self.define_tensor(get_input_node(node, arg_idx), mps_graph)
            for arg_idx in (2, 3)
        ]
        epsilon: float = get_scalar_val(node, 4)

        out0_id, out1_id, out2_id = self.define_tensor_list(node, mps_graph)

        payload = MPSLayerNorm(
            input1_id=x_id,
            normalized_shape=normalized_shape,
            weight_id=w_id,
            bias_id=b_id,
            eps=epsilon,
            output1_id=out0_id,
            output2_id=out1_id,
            output3_id=out2_id,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=payload))
98