• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import cast, Dict, List
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    check_or_raise,
12    get_tensor_value,
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
17    XNNGlobalAvgPooling2d,
18    XNNGraph,
19    XNode,
20)
21from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS
22
23
@register_node_visitor
class MeanDim(NodeVisitor):
    """
    Lower ``aten.mean.dim`` to an XNNPACK Global Average Pooling 2d node.

    XNNPACK only supports the special case of mean.dim that can be written as
    Global Average Pooling. To be handled by XNNPACK:

    * the input tensor must be 4-dimensional,
    * the dimensions to reduce must be the two innermost ones, i.e. (-1, -2) or
      (-2, -1) (the equivalent non-negative indices (2, 3) / (3, 2) are accepted
      too), and
    * keepdim must be True, whether passed positionally or as a keyword.
    """

    target = "aten.mean.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Serialize ``node`` into ``xnn_graph`` as an ``XNNGlobalAvgPooling2d``.

        Args:
            node: The ``aten.mean.dim`` FX node. ``node.args`` is
                ``(input, dims[, keepdim])``; ``keepdim`` may instead be in
                ``node.kwargs``.
            xnn_graph: Graph that receives the serialized node.
            vals_to_ids: Mapping from FX nodes to XNNPACK value ids; expected to
                hold this node's input and output once
                ``define_nodes_tensor_inputs_outputs`` has run.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            Whatever ``check_or_raise`` raises when the input is not 4-D, the
            reduced dims are not the two innermost ones, or keepdim is not set.
        """
        # Define the input/output tensor values; the input is requested in
        # NHWC (channels-last) layout.
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )
        input_id = vals_to_ids[cast(torch.fx.Node, node.args[0])]
        output_id = vals_to_ids[node]

        # Check rank first: the dim normalization below depends on it.
        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
        )
        rank = len(input_shape)

        # Normalize negative indices so that (-1, -2), (-2, -1), (2, 3) and
        # (3, 2) are all recognized as "the two innermost dims". A None dim
        # (reduce over everything) normalizes to [] and is rejected.
        raw_dims = cast(List[int], node.args[1]) or []
        mean_dims = sorted(d + rank if d < 0 else d for d in raw_dims)
        check_or_raise(
            mean_dims == [rank - 2, rank - 1],
            "XNNPACK only supports mean.dim across the innermost dimensions",
        )

        # keepdim may be passed positionally (args[2]) or as a keyword; the
        # aten default is False.
        keepdim = (
            node.args[2]
            if len(node.args) > 2
            else node.kwargs.get("keepdim", False)
        )
        check_or_raise(
            bool(keepdim),
            "XNNPACK only supports mean.dim that keeps dims",
        )

        ser_node = XNode(
            xnode_union=XNNGlobalAvgPooling2d(
                input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
79