# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNSoftmax,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


@register_node_visitor
class SoftmaxVisitor(NodeVisitor):
    """Node visitor that serializes ``aten._softmax.default`` as an ``XNNSoftmax`` node."""

    target = "aten._softmax.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNSoftmax`` entry for ``node`` to ``xnn_graph.xnodes``.

        Args:
            node: The softmax node; ``node.args[1]`` is read as the softmax
                dimension.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from fx nodes to serialized tensor ids; it is
                passed to ``define_nodes_tensor_inputs_outputs`` and then used
                to look up the input and output ids.
            debug_handle: Debug handle stored on the serialized ``XNode``.

        The dimension check goes through ``check_or_raise``, which is expected
        to raise when softmax is not over the last dimension.
        """
        source = get_input_node(node, 0)
        axis = node.args[1]
        rank = source.meta["val"].dim()

        # XNNPACK's graph-level API only handles softmax over the innermost
        # dimension. The partitioner is expected to filter out anything else;
        # this is a defensive re-check. The last dimension may be spelled
        # either as -1 or as rank - 1.
        reduces_last_dim = bool(axis == -1) or bool(axis == rank - 1)
        check_or_raise(
            reduces_last_dim,
            f"XNNPACK does not support softmax_dim != -1, but got {axis} for tensor with dim() = {rank}",
        )

        # Define the node's input/output tensors before looking up their ids.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        softmax = XNNSoftmax(
            input_id=vals_to_ids[source],
            output_id=vals_to_ids[node],
            flags=0,
        )
        xnn_graph.xnodes.append(
            XNode(xnode_union=softmax, debug_handle=debug_handle)
        )