# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging import operator import torch from torch.ao.quantization.pt2e.utils import ( _filter_sym_size_users, _is_valid_annotation, ) from torch.fx.node import map_arg from torch.fx.passes.infra.pass_base import PassBase, PassResult logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) __all__ = ["DuplicateDynamicQuantChainPass"] _QUANTIZE_OPS = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, ] def _replace_input_node_with_new_node(node, input_node, new_node): def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: if n == input_node: return new_node else: return n new_args = map_arg(node.args, maybe_replace_node) new_kwargs = map_arg(node.kwargs, maybe_replace_node) node.args = new_args node.kwargs = new_kwargs def _replicate_chose_qparam_nodes_for_q_dq( gm: torch.fx.GraphModule, chose_qparams_node, get_item_node_1, get_item_node_2 ): if ( ( chose_qparams_node.target != torch.ops.quantized_decomposed.choose_qparams.tensor ) or (get_item_node_1.target != operator.getitem) or (get_item_node_2.target != operator.getitem) ): raise RuntimeError( f"Expecting choose_qparams.tensor and getitem nodes but got {chose_qparams_node}, {get_item_node_1}, {get_item_node_2}" ) users = list(get_item_node_1.users.copy()) q_dq_pair = [] for user in users: if user.target in _QUANTIZE_OPS: if len(user.users) != 1: raise RuntimeError(f"Node {user} has more than one user") dq_node = list(user.users)[0] if dq_node.target not in _DEQUANTIZE_OPS: raise RuntimeError( f"Node {user}'s use must be a dequantize op but got {dq_node}:{dq_node.target}" ) q_dq_pair.append((user, dq_node)) for q_node, dq_node in q_dq_pair: with gm.graph.inserting_after(get_item_node_1): new_get_item_node_1 = gm.graph.node_copy(get_item_node_1) new_get_item_node_2 = gm.graph.node_copy(get_item_node_2) new_chose_qparams_node = gm.graph.node_copy(chose_qparams_node) _replace_input_node_with_new_node( new_get_item_node_1, chose_qparams_node, new_chose_qparams_node ) _replace_input_node_with_new_node( new_get_item_node_2, chose_qparams_node, new_chose_qparams_node ) _replace_input_node_with_new_node( q_node, get_item_node_1, new_get_item_node_1 ) _replace_input_node_with_new_node( dq_node, get_item_node_1, new_get_item_node_1 ) _replace_input_node_with_new_node( q_node, get_item_node_2, new_get_item_node_2 ) _replace_input_node_with_new_node( dq_node, get_item_node_2, new_get_item_node_2 ) gm.graph.eliminate_dead_code() gm.recompile() def _replicate_node_for_each_user(gm: torch.fx.GraphModule, node: torch.fx.Node): users = list(node.users.copy()) for user in users: with gm.graph.inserting_after(node): new_node = gm.graph.node_copy(node) _replace_input_node_with_new_node(user, node, new_node) gm.graph.eliminate_dead_code() gm.recompile() def _maybe_duplicate_dynamic_quantize_chain( gm: torch.fx.GraphModule, chose_qparams_node, get_item_node_1, get_item_node_2, q_node, dq_node: torch.fx.Node, ): num_dq_users = len(dq_node.users) dq_node_users = list(dq_node.users.copy()) for user in dq_node_users: annotation = user.meta.get("quantization_annotation", None) if not _is_valid_annotation(annotation): return with gm.graph.inserting_after(dq_node): new_node = gm.graph.node_copy(dq_node) _replace_input_node_with_new_node(user, dq_node, new_node) gm.graph.eliminate_dead_code() gm.recompile() if len(q_node.users) != num_dq_users: raise RuntimeError( f"Expected {num_dq_users} users of {q_node}, but got {len(q_node.users)}" ) _replicate_node_for_each_user(gm, q_node) # *2 because scale/zp are used both with q and dq nodes if len(get_item_node_1.users) != num_dq_users * 2: raise RuntimeError( f"Expected {num_dq_users} users of {get_item_node_1}, but got {len(get_item_node_1.users)}" ) if len(get_item_node_2.users) != num_dq_users * 2: raise RuntimeError( f"Expected {num_dq_users} users of {get_item_node_2}, but got {len(get_item_node_2.users)}" ) _replicate_chose_qparam_nodes_for_q_dq( gm, chose_qparams_node, get_item_node_1, get_item_node_2 ) class DuplicateDynamicQuantChainPass(PassBase): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in graph_module.graph.nodes: if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: dq_users = _filter_sym_size_users(node) if len(dq_users) <= 1: continue # Do not duplicate dq for dynamic quantization # Pattern: choose_qparam - getitem - q - dq q_node = node.args[0] if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: getitem_1_node = q_node.args[1] getitem_2_node = q_node.args[2] if ( isinstance(getitem_1_node, torch.fx.node.Node) and getitem_1_node.op == "call_function" and getitem_1_node.target == operator.getitem ): choose_qparam_node = getitem_1_node.args[0] if ( isinstance(choose_qparam_node, torch.fx.node.Node) and choose_qparam_node.op == "call_function" and choose_qparam_node.target == torch.ops.quantized_decomposed.choose_qparams.tensor ): _maybe_duplicate_dynamic_quantize_chain( graph_module, choose_qparam_node, getitem_1_node, getitem_2_node, q_node, node, ) continue graph_module.graph.eliminate_dead_code() graph_module.recompile() return PassResult(graph_module, True)