# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
from executorch.backends.xnnpack.utils.quant_utils import (
    is_per_channel_group,
    is_per_token,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_input_node,
    is_param_node,
)


@register_node_visitor
class OpDynamicDequantizePerTensor(NodeVisitor):
    """
    Dequantize Per Tensor Node visitor
    """

    target = "quantized_decomposed.dequantize_per_tensor.tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Always skip this node: dynamic dequantize is implicit in XNNPACK, so the
        dequantize output simply aliases the value id of its input.
        """
        dq_input = get_input_node(node, 0)
        if dq_input in vals_to_ids:
            vals_to_ids[node] = vals_to_ids[dq_input]


@register_node_visitor
class OpDynamicDequantizePerToken(NodeVisitor):
    """
    Dequantize Per Token Node visitor
    """

    target = "quantized_decomposed.dequantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Always skip this node: dynamic dequantize is implicit in XNNPACK, so the
        dequantize output simply aliases the value id of its input.
        """
        dq_input = get_input_node(node, 0)
        if dq_input in vals_to_ids:
            vals_to_ids[node] = vals_to_ids[dq_input]


@register_node_visitor
class OpDequantizeAffine(NodeVisitor):
    """
    Dequantize Affine Node visitor
    """

    target = "quant.dequantize_affine.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Handle affine dequantize nodes explicitly: per-channel-group nodes dequantize
        a statically quantized weight and are skipped, while per-token nodes are
        implicit dynamic dequantizes whose output aliases the input's value id.
        """
        if is_per_channel_group(node):
            check_or_raise(
                is_param_node(self._exported_program, node.all_input_nodes[0]),
                f"Expected dequantize affine node with per-channel-group semantics "
                f"to be applied to a weight node, but found node {node.all_input_nodes[0]}",
            )
            # The affine dequantize was recognized as per-channel-group, which means it
            # sits in front of a weight node, so it can be skipped here.
            return

        check_or_raise(
            is_per_token(node),
            "Expected affine dequantize op to have per-token semantics",
        )
        # This is a per-token affine dequantize node, so treat it as implicit and
        # alias its output to its input's value id.
        dq_input = get_input_node(node, 0)
        if dq_input in vals_to_ids:
            vals_to_ids[node] = vals_to_ids[dq_input]
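

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): the visitors above
# treat a dynamic dequantize node as implicit by aliasing the node's entry in
# `vals_to_ids` to the value id of its input, so no extra XNNPACK tensor value
# is serialized for the dequantize output. The guarded toy example below uses
# plain strings in place of torch.fx nodes; `dq_node`, `dq_input`, and the id
# value 7 are made-up placeholders, not XNNPACK backend API.
if __name__ == "__main__":
    vals_to_ids = {"dq_input": 7}  # the dequantize input already has a value id
    dq_node = "dq_node"
    dq_input = "dq_input"
    # Same skip/alias pattern as define_node above: reuse the input's value id.
    if dq_input in vals_to_ids:
        vals_to_ids[dq_node] = vals_to_ids[dq_input]
    print(vals_to_ids)  # -> {'dq_input': 7, 'dq_node': 7}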