#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List, Optional, Sequence, Tuple, Union

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSAvgPool2D,
    MPSGraph,
    MPSMaxPool2DWithIndices,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node


def _expand_pair(
    value: Union[int, Sequence[int]],
    name: str,
    default: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Normalize an ATen ``int[2]`` pooling argument to a ``(height, width)`` pair.

    Args:
        value: A bare int, or a sequence of one or two ints. A single value is
            applied to both spatial dimensions, mirroring how PyTorch's 2-D
            pooling ops accept ``kernel_size`` / ``stride`` / ``padding``.
        name: Argument name, used only in the error message.
        default: Returned when ``value`` is an empty sequence. ATen encodes an
            omitted ``stride`` as ``[]``, meaning "same as kernel_size".

    Returns:
        The ``(height, width)`` tuple.

    Raises:
        AssertionError: If ``value`` cannot be interpreted as a pair.
    """
    if isinstance(value, int):
        return value, value
    values = list(value)
    if not values and default is not None:
        return default
    if len(values) == 1:
        return values[0], values[0]
    if len(values) == 2:
        return values[0], values[1]
    raise AssertionError(f"Expected 1 or 2 values for {name}, got {values}")


@register_node_visitor
class MaxPool2DWithIndicesVisitor(NodeVisitor):
    """Serialize ``aten.max_pool2d_with_indices`` into an MPSMaxPool2DWithIndices node."""

    target = "aten.max_pool2d_with_indices.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an MPSMaxPool2DWithIndices node for ``node`` to ``mps_graph``.

        Positional args follow the ATen schema
        ``(self, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False)``.
        Trailing args may be omitted and then take their schema defaults.

        Raises:
            AssertionError: If ``node`` has more than 6 positional args, or an
                ``int[2]`` argument is malformed.
        """
        n_args = len(node.args)
        if n_args > 6:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)

        kernel_height, kernel_width = _expand_pair(
            cast(List[int], node.args[1]), "kernel_size"
        )
        # An omitted or empty stride means "stride == kernel_size" in ATen.
        stride_height, stride_width = _expand_pair(
            cast(List[int], node.args[2]) if n_args >= 3 else [],
            "stride",
            default=(kernel_height, kernel_width),
        )
        padding_top, padding_left = (
            _expand_pair(cast(List[int], node.args[3]), "padding")
            if n_args >= 4
            else (0, 0)
        )
        dilation_height, dilation_width = (
            _expand_pair(cast(List[int], node.args[4]), "dilation")
            if n_args >= 5
            else (1, 1)
        )
        ceil_mode = cast(bool, node.args[5]) if n_args >= 6 else False

        # In ceil mode the trailing (bottom/right) padding is scaled by the
        # stride. NOTE(review): presumably this gives the MPS kernel enough
        # trailing room to round the output size up; confirm against the
        # MPS runtime before changing.
        padding_bottom = padding_top * stride_height if ceil_mode else padding_top
        padding_right = padding_left * stride_width if ceil_mode else padding_left

        # The op has two outputs (pooled values and their indices).
        output1_id, output2_id = self.define_tensor_list(node, mps_graph)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSMaxPool2DWithIndices(
                    input1_id=input1_id,
                    kernel_height=kernel_height,
                    kernel_width=kernel_width,
                    stride_height=stride_height,
                    stride_width=stride_width,
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation_height,
                    dilation_width=dilation_width,
                    ceil_mode=ceil_mode,
                    output1_id=output1_id,
                    output2_id=output2_id,
                )
            )
        )


@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """Serialize ``aten.avg_pool2d`` into an MPSAvgPool2D node."""

    target = "aten.avg_pool2d.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an MPSAvgPool2D node for ``node`` to ``mps_graph``.

        Positional args follow the ATen schema
        ``(self, kernel_size, stride=[], padding=0, ceil_mode=False,
        count_include_pad=True, divisor_override=None)``. Trailing args may be
        omitted and then take their schema defaults.

        Raises:
            AssertionError: If ``node`` has more than 7 positional args, or an
                ``int[2]`` argument is malformed.
        """
        n_args = len(node.args)
        if n_args > 7:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output1_id = self.define_tensor(node, mps_graph)

        kernel_height, kernel_width = _expand_pair(
            cast(List[int], node.args[1]), "kernel_size"
        )
        # An omitted or empty stride means "stride == kernel_size" in ATen.
        stride_height, stride_width = _expand_pair(
            cast(List[int], node.args[2]) if n_args >= 3 else [],
            "stride",
            default=(kernel_height, kernel_width),
        )
        padding_top, padding_left = (
            _expand_pair(cast(List[int], node.args[3]), "padding")
            if n_args >= 4
            else (0, 0)
        )
        ceil_mode = cast(bool, node.args[4]) if n_args >= 5 else False
        # `>=` (not `==`) so count_include_pad is still honoured when
        # divisor_override is also supplied (7 args).
        count_include_pad = cast(bool, node.args[5]) if n_args >= 6 else True
        # 0 is the serialized "no override" value. The schema allows None
        # here, which must map to 0 rather than be passed through.
        divisor_override = 0
        if n_args >= 7 and node.args[6] is not None:
            divisor_override = cast(int, node.args[6])
        # avg_pool2d has no dilation argument; the MPS node still needs one.
        dilation_height, dilation_width = 1, 1

        # In ceil mode the trailing (bottom/right) padding is scaled by the
        # stride. NOTE(review): presumably this gives the MPS kernel enough
        # trailing room to round the output size up; confirm against the
        # MPS runtime before changing.
        padding_bottom = padding_top * stride_height if ceil_mode else padding_top
        padding_right = padding_left * stride_width if ceil_mode else padding_left

        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSAvgPool2D(
                    input1_id=input1_id,
                    kernel_height=kernel_height,
                    kernel_width=kernel_width,
                    stride_height=stride_height,
                    stride_width=stride_width,
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation_height,
                    dilation_width=dilation_width,
                    ceil_mode=ceil_mode,
                    count_include_pad=count_include_pad,
                    divisor_override=divisor_override,
                    output1_id=output1_id,
                )
            )
        )