import copy
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.fx import GraphModule, Node
from torch.nn import functional as F

NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"

log = logging.getLogger(__name__)


def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
    """Attach a numeric_debug_handle_id to every node in the model except
    placeholder and output nodes.

    The graph nodes of the input model are modified in place.
    """
    unique_id = 0
    for node in graph_module.graph.nodes:
        if node.op in ["output", "placeholder"]:
            continue

        if CUSTOM_KEY not in node.meta:
            node.meta[CUSTOM_KEY] = {}

        if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
            node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
            unique_id += 1


class OutputLogger(torch.nn.Module):
    """Base class for capturing the output values of nodes in a GraphModule.

    It currently captures only Tensor outputs, but it can be extended to
    other types of outputs later if needed.
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        debug_handle: int,
        node_name: Optional[str] = None,
        nn_module_stack: Optional[object] = None,
    ) -> None:
        super().__init__()
        self.node_name = node_name
        self.nn_module_stack = nn_module_stack
        self.debug_handle = debug_handle
        self.stats: List[torch.Tensor] = []

    def forward(self, x: object) -> object:
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x

    def extra_repr(self) -> str:
        # `extra_repr` (no underscores) is the hook torch.nn.Module.__repr__
        # consults for per-module details; both pieces must be f-strings so
        # all four fields are interpolated.
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )


def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
    """For a given node, add an OutputLogger that observes the output of that
    node, and rewire all of the node's users to consume the OutputLogger's
    output instead.

    The OutputLogger carries the debug_handle, which can be used to compare
    graphs before and after transforms.
    """
    # to avoid circular dep
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    # add a logger after the node
    with model.graph.inserting_after(node):
        get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
        logger_name = get_new_attr_name(model)
        setattr(
            model,
            logger_name,
            OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
        )
        logger_node = model.graph.call_module(logger_name, (node,), {})

    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is logger_node:
            continue
        user_node.replace_input_with(node, logger_node)

    return logger_node
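

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the public API): the
# `_DemoModel` class and `_example_attach_handles` helper below are assumed
# names added for demonstration. They show how debug handles are attached to
# a traced model before any transform is applied.
# ---------------------------------------------------------------------------


def _example_attach_handles() -> GraphModule:
    """Sketch: trace a toy model and attach numeric debug handles."""

    class _DemoModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.relu(x + 1.0)

    gm = torch.fx.symbolic_trace(_DemoModel())
    generate_numeric_debug_handle(gm)
    # Each compute node now carries a stable integer id under
    # node.meta["custom"]["numeric_debug_handle"], which survives graph
    # transforms that preserve node.meta.
    return gm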


def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
    """Add output loggers to every node that has a numeric_debug_handle.

    Args:
        model (GraphModule): original model

    Returns:
        a model with output loggers for all nodes that have a
        numeric_debug_handle_id
    """
    # don't change the original model
    model = copy.deepcopy(model)
    for n in model.graph.nodes:
        if (
            CUSTOM_KEY not in n.meta
            or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
        ):
            continue
        numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        _insert_logger(model, n, numeric_debug_handle)

    model.recompile()
    return model


@dataclass(frozen=True)
class QuantizationComparisonResult:
    actual: torch.Tensor
    ref: torch.Tensor

    @property
    def mse_loss(self) -> torch.Tensor:
        return F.mse_loss(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    @property
    def sqnr(self) -> torch.Tensor:
        return compute_sqnr(
            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
        )

    def __repr__(self) -> str:
        # Don't include the tensors themselves as they are quite large to
        # print out.
        return (
            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
        )

    def __post_init__(self) -> None:
        if not isinstance(self.actual, torch.Tensor):
            raise ValueError(
                f"`self.actual` value must be a Tensor, got: {self.actual}"
            )
        if not isinstance(self.ref, torch.Tensor):
            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")


@dataclass(frozen=True)
class NodeAccuracySummary:
    handle: int
    actual_node_name: str
    actual_module_stack: str
    ref_node_name: str
    ref_module_stack: str
    results: Sequence[QuantizationComparisonResult]


def _module_stack_to_str(module_stack: object) -> str:
    """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
    to "mod.foo.0.linear".
    """
    if not isinstance(module_stack, dict):
        return str(module_stack)
    module_values_list = list(module_stack.values())
    if len(module_values_list) > 0:
        owning_module = module_values_list[-1][0]
        return str(owning_module)
    else:
        return str(module_stack)
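

# Illustrative sketch (an assumed helper, not part of the library): build one
# QuantizationComparisonResult by hand to show how the MSE and SQNR metrics
# defined above are read. The coarse rounding stands in for quantization error.


def _example_single_comparison() -> QuantizationComparisonResult:
    ref = torch.randn(8)
    # Simulate quantization error by snapping values to a coarse grid.
    actual = torch.round(ref * 16) / 16
    result = QuantizationComparisonResult(actual=actual, ref=ref)
    # `mse_loss` and `sqnr` are computed lazily from the stored tensors;
    # a higher SQNR (in dB) means `actual` tracks `ref` more closely.
    return result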


def extract_results_from_loggers(
    model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
    """For a given model, extract the tensor stats and related information for
    each debug handle.

    Returns:
        A dict keyed by debug_handle id, whose values are the logged node's
        name, its module stack, and the list of Tensors recorded by its logger.
    """
    # Results map a debug handle to a tensor list for each model being compared.
    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
    for _name, module in model.named_children():
        if isinstance(module, OutputLogger) and len(module.stats) > 0:
            handles[module.debug_handle] = (
                module.node_name,
                module.nn_module_stack,
                module.stats,
            )

    return handles


def compare_results(
    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
    """Given two dicts mapping from `debug_handle_id` (int) to lists of tensors,
    return a map from `debug_handle_id` to `NodeAccuracySummary`, which contains
    comparison information such as SQNR and MSE.

    Args:
        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]):
            reference results for each debug_handle_id
        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]):
            actual results for each debug_handle_id

    Returns:
        Dict[int, NodeAccuracySummary]
    """
    comparisons = {}
    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
        if debug_handle not in actual_results:
            log.debug(
                "Cannot compare for handle %s because it wasn't found in the transformed model",
                debug_handle,
            )
            continue
        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
        comparisons[debug_handle] = NodeAccuracySummary(
            handle=debug_handle,
            actual_node_name=actual_name,
            actual_module_stack=_module_stack_to_str(actual_stack),
            ref_node_name=ref_name,
            ref_module_stack=_module_stack_to_str(ref_stack),
            results=[
                QuantizationComparisonResult(actual=a, ref=b)
                for a, b in zip(actual_stats, ref_stats)
            ],
        )

    return comparisons
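

# ---------------------------------------------------------------------------
# Illustrative end-to-end flow (a minimal sketch under assumptions, not part
# of the library): `_example_end_to_end`, the toy model, and the deepcopy
# "transform" below are stand-ins added for demonstration. In a real flow the
# second model would be the quantized version of the first, produced by a
# pass that preserves numeric_debug_handle metadata in node.meta.
# ---------------------------------------------------------------------------


def _example_end_to_end() -> Dict[int, NodeAccuracySummary]:
    class _DemoModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.relu(x + 1.0)

    gm = torch.fx.symbolic_trace(_DemoModel())
    generate_numeric_debug_handle(gm)

    # Stand-in for a real transform (e.g. quantization); deepcopy preserves
    # node.meta and therefore the debug handles.
    transformed = copy.deepcopy(gm)

    # Instrument both models with loggers, run the same data through each,
    # then collect and compare the recorded stats per debug handle.
    ref_logged = prepare_for_propagation_comparison(gm)
    actual_logged = prepare_for_propagation_comparison(transformed)

    example_input = torch.randn(4)
    ref_logged(example_input)
    actual_logged(example_input)

    ref_results = extract_results_from_loggers(ref_logged)
    actual_results = extract_results_from_loggers(actual_logged)
    return compare_results(ref_results, actual_results)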