# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    check_or_raise,
    get_tensor_value,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGlobalAvgPooling2d,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS

# Accepted spellings of the two innermost axes. The non-negative forms are
# only equivalent for a rank-4 input, which define_node enforces separately.
_LAST_DIM = (-1, 3)
_SECOND_TO_LAST_DIM = (-2, 2)


def _reduces_two_innermost_dims(dims: object) -> bool:
    """Return True if *dims* names exactly the two innermost axes of a 4-D tensor.

    Either order is accepted, and each axis may be given as a negative or a
    non-negative index (e.g. ``[-1, -2]``, ``[2, 3]``, ``[3, -2]``). Anything
    that is not a two-element list/tuple (``None``, a bare int, ...) yields False.
    """
    if not isinstance(dims, (list, tuple)) or len(dims) != 2:
        return False
    first, second = dims
    return (first in _LAST_DIM and second in _SECOND_TO_LAST_DIM) or (
        first in _SECOND_TO_LAST_DIM and second in _LAST_DIM
    )


@register_node_visitor
class MeanDim(NodeVisitor):
    """
    XNNPACK only supports a special case of mean dim in which the operation can be
    written as Global Average Pooling. In order to be handled by XNNPACK, the input
    tensor must be 4d, the dimensions to reduce must be the two innermost ones
    ((-1, -2) or (-2, -1), or their non-negative equivalents (3, 2) / (2, 3)), and
    the flag for keepdim must be set to True.
    """

    target = "aten.mean.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize an ``aten.mean.dim`` node as an XNNGlobalAvgPooling2d node.

        Args:
            node: The ``aten.mean.dim`` fx node; ``args`` are read as
                ``(input, dims, keepdim)``.
            xnn_graph: Graph being built; the new XNode is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from fx nodes to tensor value ids, used to
                look up the ids of this node's input and output.
            debug_handle: Debug handle stored on the serialized node.

        Validation failures are reported through ``check_or_raise`` when the
        reduced dims are not the two innermost ones, when keepdim is not passed
        positionally as a truthy value, or when the input is not 4-dimensional.
        """
        # Register the node's input/output tensors first: the lookups below
        # rely on vals_to_ids and xnn_graph.xvalues being populated.
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )
        # input
        input_id = vals_to_ids[cast(torch.fx.Node, node.args[0])]

        # output
        output_id = vals_to_ids[node]

        # mean dims
        mean_dims = cast(List[int], node.args[1])
        check_or_raise(
            _reduces_two_innermost_dims(mean_dims),
            "XNNPACK only supports mean.dim across the innermost dimensions",
        )

        # keep dims: must be supplied explicitly as the third positional arg
        check_or_raise(
            len(node.args) == 3 and bool(node.args[2]),
            "XNNPACK only supports mean.dim that keeps dims",
        )

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
        )

        ser_node = XNode(
            xnode_union=XNNGlobalAvgPooling2d(
                input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)