• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast, List
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSAvgPool2D,
15    MPSGraph,
16    MPSMaxPool2DWithIndices,
17    MPSNode,
18)
19from executorch.backends.apple.mps.utils.mps_utils import get_input_node
20
21
@register_node_visitor
class MaxPool2DWithIndicesVisitor(NodeVisitor):
    """Lower ``aten.max_pool2d_with_indices.default`` to ``MPSMaxPool2DWithIndices``.

    The ATen schema is ``(self, kernel_size, stride=[], padding=0,
    dilation=1, ceil_mode=False)``; trailing arguments left at their
    defaults may be absent from ``node.args``.
    """

    target = "aten.max_pool2d_with_indices.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _as_pair(value, fallback: List[int]) -> List[int]:
        """Normalize an ATen ``int[2]`` argument to ``[height, width]``.

        A bare int or a one-element list is broadcast to both spatial
        dimensions; an empty list means "use ``fallback``".

        Raises:
            AssertionError: if the value does not describe exactly two ints.
        """
        if isinstance(value, int):
            pair = [value, value]
        else:
            pair = list(value) or list(fallback)
            if len(pair) == 1:
                pair = pair * 2
        if len(pair) != 2:
            raise AssertionError(f"Expected a 2D pooling parameter, got {value!r}")
        return pair

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSMaxPool2DWithIndices`` node for ``node`` to ``mps_graph``.

        Args:
            node: The FX node calling ``aten.max_pool2d_with_indices.default``.
            mps_graph: The graph being serialized; its input/output tensors
                and the new pooling node are registered on it.

        Raises:
            AssertionError: if ``node`` has fewer than 2 or more than 6
                positional arguments, or a pooling parameter is malformed.
        """
        args = node.args
        n_args = len(args)
        if not 2 <= n_args <= 6:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)

        kernel_height, kernel_width = self._as_pair(args[1], fallback=[])
        # An omitted or empty stride means "same as the kernel" in ATen.
        stride_height, stride_width = self._as_pair(
            args[2] if n_args >= 3 else [],
            fallback=[kernel_height, kernel_width],
        )
        padding_top, padding_left = self._as_pair(
            args[3] if n_args >= 4 else [], fallback=[0, 0]
        )
        dilation_height, dilation_width = self._as_pair(
            args[4] if n_args >= 5 else [], fallback=[1, 1]
        )
        ceil_mode = bool(args[5]) if n_args >= 6 else False

        # NOTE(review): with ceil_mode the trailing (bottom/right) padding is
        # scaled by the stride, presumably so the MPS kernel can emit the extra
        # partial window that ceil rounding admits — confirm against the MPS
        # runtime before changing.
        padding_bottom = padding_top * stride_height if ceil_mode else padding_top
        padding_right = padding_left * stride_width if ceil_mode else padding_left

        # The op has two outputs (pooled values and indices).
        output1_id, output2_id = self.define_tensor_list(node, mps_graph)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSMaxPool2DWithIndices(
                    input1_id=input1_id,
                    kernel_height=kernel_height,
                    kernel_width=kernel_width,
                    stride_height=stride_height,
                    stride_width=stride_width,
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation_height,
                    dilation_width=dilation_width,
                    ceil_mode=ceil_mode,
                    output1_id=output1_id,
                    output2_id=output2_id,
                )
            )
        )
79
80
@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """Lower ``aten.avg_pool2d.default`` to an ``MPSAvgPool2D`` graph node.

    The ATen schema is ``(self, kernel_size, stride=[], padding=0,
    ceil_mode=False, count_include_pad=True, divisor_override=None)``;
    trailing arguments left at their defaults may be absent from
    ``node.args``.
    """

    target = "aten.avg_pool2d.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _as_pair(value, fallback: List[int]) -> List[int]:
        """Normalize an ATen ``int[2]`` argument to ``[height, width]``.

        A bare int or a one-element list is broadcast to both spatial
        dimensions; an empty list means "use ``fallback``".

        Raises:
            AssertionError: if the value does not describe exactly two ints.
        """
        if isinstance(value, int):
            pair = [value, value]
        else:
            pair = list(value) or list(fallback)
            if len(pair) == 1:
                pair = pair * 2
        if len(pair) != 2:
            raise AssertionError(f"Expected a 2D pooling parameter, got {value!r}")
        return pair

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSAvgPool2D`` node for ``node`` to ``mps_graph``.

        Args:
            node: The FX node calling ``aten.avg_pool2d.default``.
            mps_graph: The graph being serialized; its input/output tensors
                and the new pooling node are registered on it.

        Raises:
            AssertionError: if ``node`` has fewer than 2 or more than 7
                positional arguments, or a pooling parameter is malformed.
        """
        args = node.args
        n_args = len(args)
        if not 2 <= n_args <= 7:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output1_id = self.define_tensor(node, mps_graph)

        kernel_height, kernel_width = self._as_pair(args[1], fallback=[])
        # An omitted or empty stride means "same as the kernel" in ATen.
        stride_height, stride_width = self._as_pair(
            args[2] if n_args >= 3 else [],
            fallback=[kernel_height, kernel_width],
        )
        padding_top, padding_left = self._as_pair(
            args[3] if n_args >= 4 else [], fallback=[0, 0]
        )
        # avg_pool2d has no dilation argument; the MPS node is always undilated.
        dilation_height, dilation_width = 1, 1

        ceil_mode = bool(args[4]) if n_args >= 5 else False
        # ``>=`` so the flag is still honored when divisor_override is also given.
        count_include_pad = bool(args[5]) if n_args >= 6 else True
        # ATen spells "no override" as None; this visitor encodes it as 0.
        divisor_override = 0
        if n_args >= 7 and args[6] is not None:
            divisor_override = int(args[6])

        # NOTE(review): with ceil_mode the trailing (bottom/right) padding is
        # scaled by the stride, presumably so the MPS kernel can emit the extra
        # partial window that ceil rounding admits — confirm against the MPS
        # runtime before changing.
        padding_bottom = padding_top * stride_height if ceil_mode else padding_top
        padding_right = padding_left * stride_width if ceil_mode else padding_left

        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSAvgPool2D(
                    input1_id=input1_id,
                    kernel_height=kernel_height,
                    kernel_width=kernel_width,
                    stride_height=stride_height,
                    stride_width=stride_width,
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation_height,
                    dilation_width=dilation_width,
                    ceil_mode=ceil_mode,
                    count_include_pad=count_include_pad,
                    divisor_override=divisor_override,
                    output1_id=output1_id,
                )
            )
        )
143