# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNSoftmax,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


@register_node_visitor
class SoftmaxVisitor(NodeVisitor):
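    """Serializes an aten._softmax.default node into an XNNSoftmax XNode."""
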
    target = "aten._softmax.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
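        """Appends an XNNSoftmax node for `node` to `xnn_graph`, validating that
        the softmax is taken over the last dimension."""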
        # XNNPACK does not support softmax_dim != -1, at least through the graph-level APIs.
        # The XNNPACK partitioner should never let such a node through, but verify here to be safe.
        softmax_dim = node.args[1]
        input_dim = get_input_node(node, 0).meta["val"].dim()
        check_or_raise(
            bool(softmax_dim == -1) or bool(softmax_dim == input_dim - 1),
            f"XNNPACK does not support softmax_dim != -1, but got {softmax_dim} for tensor with dim() = {input_dim}",
        )

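        # Define XNNPACK tensor values for this node's input and output and
        # record their ids in vals_to_ids.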
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

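        # Serialize the softmax op and append it to the XNNPACK graph.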
        ser_node = XNode(
            xnode_union=XNNSoftmax(input_id=input_id, output_id=output_id, flags=0),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)