# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNAvgPooling2d,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
class AveragePooling2d(NodeVisitor):
    """Serializes aten.avg_pool2d.default as an XNNAvgPooling2d node in the XNNPACK graph."""

    target = "aten.avg_pool2d.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )

        # input
        input_id = vals_to_ids[cast(torch.fx.Node, node.args[0])]

        # output
        output_id = vals_to_ids[node]

        # kernel_size
        pooling_height, pooling_width = cast(List, node.args[1])

        # stride
        stride_height, stride_width = cast(List, node.args[2])

        # padding (defaults to 0 when the arg is not present in the graph)
        padding_height, padding_width = 0, 0
        if len(node.args) > 3 and node.args[3] is not None:
            padding_height, padding_width = cast(List[int], node.args[3])

        ser_node = XNode(
            xnode_union=XNNAvgPooling2d(
                padding_top=padding_height,
                padding_right=padding_width,
                padding_bottom=padding_height,
                padding_left=padding_width,
                pooling_height=pooling_height,
                pooling_width=pooling_width,
                stride_height=stride_height,
                stride_width=stride_width,
                dilation_height=0,  # Unused
                dilation_width=0,  # Unused
                input_id=input_id,
                output_id=output_id,
                flags=XNN_FLAG_KEEP_DIMS,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)