# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import (
    register_annotator,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node


@register_annotator("linear")
def _annotate_linear(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate ``aten.linear.default`` nodes with the quantization specs from
    ``quantization_config``.

    For each matching node, the input activation, weight, and (optional) bias
    are annotated with their corresponding qspecs, and the node's output is
    annotated with the output activation qspec. Returns a list of annotated
    partitions, where each partition is the list of nodes marked as annotated
    for one linear op.
    """
    annotated_partitions = []

    # Resolve the qspecs once; they are shared by every linear node we annotate.
    input_act_qspec = quantization_config.get_input_act_qspec()
    output_act_qspec = quantization_config.get_output_act_qspec()
    weight_qspec = quantization_config.get_weight_qspec()
    bias_qspec = quantization_config.get_bias_qspec()

    for node in gm.graph.nodes:
        # Only annotate aten.linear calls that pass the optional filter.
        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
            continue
        if filter_fn and not filter_fn(node):
            continue

        # aten.linear args are (input, weight[, bias]).
        act_node = node.args[0]
        weight_node = node.args[1]
        bias_node = None
        if len(node.args) > 2:
            bias_node = node.args[2]

        # Skip nodes that another annotator has already handled.
        if arm_quantizer_utils.is_annotated(node) is False:  # type: ignore[list-item]
            _annotate_input_qspec_map(
                node,
                act_node,
                input_act_qspec,
            )
            _annotate_input_qspec_map(
                node,
                weight_node,
                weight_qspec,
            )
            nodes_to_mark_annotated = [node, weight_node]
            if bias_node:
                _annotate_input_qspec_map(
                    node,
                    bias_node,
                    bias_qspec,
                )
                nodes_to_mark_annotated.append(bias_node)
            _annotate_output_qspec(node, output_act_qspec)

            # Mark the whole partition so it is not annotated twice.
            arm_quantizer_utils.mark_nodes_as_annotated(nodes_to_mark_annotated)
            annotated_partitions.append(nodes_to_mark_annotated)

    return annotated_partitions