# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from collections import Counter
from pprint import pformat
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union

import executorch.backends.xnnpack.test.tester.tester as tester
import numpy as np
import serializer.tosa_serializer as ts
import torch.fx

from executorch.backends.arm.arm_backend import (
    get_intermediate_path,
    is_permute_memory,
)
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test.common import (
    arm_test_options,
    current_time_formated,
    get_option,
)
from executorch.backends.arm.test.runner_utils import (
    _get_input_quantization_params,
    _get_output_node,
    _get_output_quantization_params,
    dbg_tosa_fb_to_json,
    RunnerUtil,
)
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.lowered_backend_module import LoweredBackendModule
from tabulate import tabulate
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
from torch.fx import Graph

logger = logging.getLogger(__name__)


def _dump_lowered_modules_artifact(
    path_to_dump: Optional[str],
    artifact: ExecutorchProgramManager,
    graph_module: torch.fx.GraphModule,
):
    output = "Formatted Graph Signature:\n"
    output += _format_export_graph_signature(
        artifact.exported_program().graph_signature
    )

    def get_output_format(lowered_module) -> str | None:
        for spec in lowered_module.compile_specs:
            if spec.key == "output_format":
                return spec.value.decode()
        return None

    for node in graph_module.graph.nodes:
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = getattr(graph_module, node.name)
            assert isinstance(
                lowered_module, LoweredBackendModule
            ), f"Attribute {node.name} must be of type LoweredBackendModule."

            output_format = get_output_format(lowered_module)
            if output_format == "tosa":
                tosa_fb = lowered_module.processed_bytes
                to_print = dbg_tosa_fb_to_json(tosa_fb)
                to_print = pformat(to_print, compact=True, indent=1)
                output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
            elif output_format == "vela":
                vela_cmd_stream = lowered_module.processed_bytes
                output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
            else:
                logger.warning(
                    f"No TOSA nor Vela compile spec found in compile specs of {node.name}."
                )
                continue

    if not output:
        logger.warning("No output to print generated from artifact.")
        return

    _dump_str(output, path_to_dump)


class Partition(tester.Partition):
    def dump_artifact(self, path_to_dump: Optional[str]):
        super().dump_artifact(path_to_dump)
        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)


class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
    def dump_artifact(self, path_to_dump: Optional[str]):
        super().dump_artifact(path_to_dump)
        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)


class Serialize(tester.Serialize):
    def __init__(self, runner_util: RunnerUtil, timeout: int = 1):
        super().__init__()
        self.runner = runner_util
        self.runner.set_timeout(timeout)

    def run_artifact(self, inputs):
        return self.runner.run_corstone(inputs)

    def dump_artifact(self, path_to_dump: Optional[str]):
        if not path_to_dump:
            path_to_dump = self.path + "/program.pte"
        super().dump_artifact(path_to_dump)


class ToExecutorch(tester.ToExecutorch):
    def __init__(
        self,
        tosa_test_util: RunnerUtil,
        dynamic_shapes: Optional[Tuple[Any]] = None,
    ):
        super().__init__(dynamic_shapes)
        self.tosa_test_util = tosa_test_util

    def run_artifact(self, inputs):
        tosa_output = self.tosa_test_util.run_tosa_ref_model(
            inputs=inputs,
        )
        return tosa_output


class InitialModel(tester.Stage):
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def run(self, artifact, inputs=None) -> None:
        pass

    @property
    def artifact(self) -> torch.nn.Module:
        return self.model

    @property
    def graph_module(self) -> None:
        return None

    def artifact_str(self) -> str:
        return str(self.model)

    def run_artifact(self, inputs):
        return self.model.forward(*inputs)


class ArmTester(Tester):
    def __init__(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[torch.Tensor],
        compile_spec: Optional[List[CompileSpec]] = None,
        tosa_ref_model_path: str | None = None,
    ):
        """
        Args:
            model (torch.nn.Module): The model to test
            example_inputs (Tuple[torch.Tensor]): Example inputs to the model
            compile_spec (List[CompileSpec]): The compile spec to use
        """

        # Initialize runner_util
        intermediate_path = get_intermediate_path(compile_spec)
        self.runner_util = RunnerUtil(
            intermediate_path=intermediate_path,
            tosa_ref_model_path=tosa_ref_model_path,
        )

        self.compile_spec = compile_spec

        super().__init__(model, example_inputs)
        self.pipeline[self.stage_name(InitialModel)] = [
            self.stage_name(tester.Quantize),
            self.stage_name(tester.Export),
        ]

        # The initial model needs to be set as a *possible* but not yet added Stage,
        # therefore add a None entry.
        self.stages[self.stage_name(InitialModel)] = None
        self._run_stage(InitialModel(self.original_module))

    def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
        if quantize_stage is None:
            quantize_stage = tester.Quantize(
                ArmQuantizer(),
                get_symmetric_quantization_config(is_per_channel=False),
            )
        return super().quantize(quantize_stage)

    def to_edge(
        self,
        to_edge_stage: Optional[tester.ToEdge] = None,
        config: Optional[EdgeCompileConfig] = None,
    ):
        if to_edge_stage is None:
            to_edge_stage = tester.ToEdge(config)
        else:
            if config is not None:
                to_edge_stage.edge_compile_conf = config

        # TODO(T182928844): Delegate dim order op to backend.
        to_edge_stage.edge_compile_conf._skip_dim_order = True

        return super().to_edge(to_edge_stage)

    def partition(self, partition_stage: Optional[Partition] = None):
        if partition_stage is None:
            arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec)
            partition_stage = Partition(arm_partitioner)
        return super().partition(partition_stage)

    def to_edge_transform_and_lower(
        self,
        to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
    ):
        if to_edge_and_lower_stage is None:
            if partitioners is None:
                partitioners = [ArmPartitioner(compile_spec=self.compile_spec)]
            to_edge_and_lower_stage = ToEdgeTransformAndLower(
                partitioners, edge_compile_config
            )
        else:
            if partitioners is not None:
                to_edge_and_lower_stage.partitioners = partitioners
            if edge_compile_config is not None:
                to_edge_and_lower_stage.edge_compile_conf = edge_compile_config
        to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True
        return super().to_edge_transform_and_lower(to_edge_and_lower_stage)

    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None):
        if to_executorch_stage is None:
            to_executorch_stage = ToExecutorch(self.runner_util)
        return super().to_executorch(to_executorch_stage)

    def serialize(
        self, serialize_stage: Optional[Serialize] = None, timeout: int = 120
    ):
        if serialize_stage is None:
            serialize_stage = Serialize(self.runner_util, timeout=timeout)
        assert (
            get_intermediate_path(self.compile_spec) is not None
        ), "Can't dump serialized file when compile specs do not contain an artifact path."

        return (
            super()
            .serialize(serialize_stage)
            .dump_artifact(get_intermediate_path(self.compile_spec) + "/program.pte")
        )

    def run_method_and_compare_outputs(
        self,
        inputs: Optional[Tuple[torch.Tensor]] = None,
        stage: Optional[str] = None,
        target_board: Optional[str] = "corstone-300",
        num_runs=1,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        """
        Compares the run_artifact output of 'stage' with the output of a reference stage.
        If the model is quantized, the reference stage is the Quantize stage output.
        Otherwise, the reference stage is the initial PyTorch module.

        Asserts that the outputs are equal (within tolerances).
        Returns self to allow the function to be run in a test chain.

        Args:
            stage (Optional[str]): The name of the stage to compare.
                The default is the latest run stage.
            inputs (Optional[Tuple[torch.Tensor]]): Allows you to pass custom input data.
                The default is random data.
        """
        edge_stage = self.stages[self.stage_name(tester.ToEdge)]
        if edge_stage is None:
            edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
        assert (
            self.runner_util is not None
        ), "self.runner_util is not initialized, cannot use run_method()"
        assert (
            edge_stage is not None
        ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
        stage = stage or self.cur
        test_stage = self.stages[stage]
        is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None

        exported_program = self.stages[self.stage_name(tester.Export)].artifact
        edge_program = edge_stage.artifact.exported_program()
        self.runner_util.init_run(
            exported_program,
            edge_program,
            is_quantized,
            target_board,
        )

        if is_quantized:
            reference_stage = self.stages[self.stage_name(tester.Quantize)]
            quantization_scale = self.runner_util.qp_output.scale
        else:
            reference_stage = self.stages[self.stage_name(InitialModel)]
            quantization_scale = None

        logger.info(
            f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
        )
        is_nhwc = is_permute_memory(self.compile_spec)

        # Loop over inputs and compare the reference stage with the tested stage.
        for run_iteration in range(num_runs):
            reference_input = inputs if inputs else next(self.generate_random_inputs())

            # Test parameters can include constants that are used in eager mode but are
            # already set as attributes in TOSA. Therefore, only accept torch.Tensor inputs.
            test_input: list[torch.Tensor] = []
            for arg in reference_input:
                if isinstance(arg, torch.Tensor):
                    test_input.append(arg.clone())
                if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor):
                    test_input.extend([tensor.clone() for tensor in arg])

            if (
                is_nhwc
                and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
            ):
                test_input = self.transpose_data_format(test_input, "NHWC")

            input_shapes = [
                generated_input.shape if hasattr(generated_input, "shape") else (1,)
                for generated_input in reference_input
            ]
            input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
            logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")

            reference_output = reference_stage.run_artifact(reference_input)
            test_output = tuple(test_stage.run_artifact(test_input))
            if (
                is_nhwc
                and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
            ):
                test_output = self.transpose_data_format(test_output, "NCHW")

            self._compare_outputs(
                reference_output, test_output, quantization_scale, atol, rtol, qtol
            )

        return self

    def get_graph(self, stage: str | None = None) -> Graph:
        if stage is None:
            stage = self.cur
        artifact = self.get_artifact(stage)
        if (
            self.cur == self.stage_name(tester.ToEdge)
            or self.cur == self.stage_name(Partition)
            or self.cur == self.stage_name(ToEdgeTransformAndLower)
        ):
            graph = artifact.exported_program().graph
        elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
            tester.Quantize
        ):
            graph = artifact.graph
        else:
            raise RuntimeError(
                "Can only get a graph from Quantize, ToEdge, Export, and Partition stages."
            )

        return graph

    def dump_operator_distribution(
        self, path_to_dump: Optional[str] = None, print_table: bool = True
    ):
        """Dump the distribution of operators in the current stage.

        In the partition stage, additional information is included such as the number of
        delegates and the distribution of TOSA operators.
        Set parameter print_table to False to dump in a parseable format.

        Returns self for daisy-chaining.
""" line = "#" * 10 to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n" if ( self.cur in ( self.stage_name(tester.Partition), self.stage_name(ToEdgeTransformAndLower), ) and print_table ): graph_module = self.get_artifact().exported_program().graph_module if print_table: delegation_info = get_delegation_info(graph_module) op_dist = delegation_info.get_operator_delegation_dataframe() else: op_dist = dict(_get_operator_distribution(graph_module.graph)) to_print += _format_dict(op_dist, print_table) to_print += "\n" + _get_tosa_operator_distribution( graph_module, print_table ) to_print += "\n" to_print += delegation_info.get_summary() else: graph = self.get_graph(self.cur) op_dist = dict(_get_operator_distribution(graph)) if print_table: op_dist = { "Operator": list(op_dist), "Count": [op_dist[key] for key in op_dist], } to_print += _format_dict(op_dist, print_table) + "\n" _dump_str(to_print, path_to_dump) return self def dump_dtype_distribution( self, path_to_dump: Optional[str] = None, print_table: bool = True ): """Dump a the distributions of dtypes of nodes and placeholders in the current stage. Set parameter print_table to False to dump in a parseable format. Returns self for daisy-chaining. """ line = "#" * 10 to_print = ( f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n" ) graph = self.get_graph(self.cur) dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(graph) all_dtypes = set(dtype_dist_placeholders.keys()) | set( dtype_dirst_tensors.keys() ) if print_table: dtype_dist = { "Dtype": all_dtypes, "Placeholder Count": [ ( dtype_dist_placeholders[key] if key in dtype_dist_placeholders else 0 ) for key in all_dtypes ], "Tensor Count": [ (dtype_dirst_tensors[key] if key in dtype_dirst_tensors else 0) for key in all_dtypes ], } else: dtype_dist = dict(dtype_dist_placeholders + dtype_dirst_tensors) to_print += _format_dict(dtype_dist, print_table) + "\n" _dump_str(to_print, path_to_dump) return self @staticmethod def _calculate_reference_output( module: Union[torch.fx.GraphModule, torch.nn.Module], inputs ) -> torch.Tensor: """ Note: I'd prefer to use the base class method here, but since it use the exported program, I can't. The partitioner stage clears the state_dict of the exported program, which causes an issue when evaluating the module. 
""" return module.forward(*inputs) def transpose_data_format( self, data: Tuple[torch.Tensor], to: Literal["NHWC", "NCHW"] ): if to == "NCHW": dim_order = (0, 3, 1, 2) if to == "NHWC": dim_order = (0, 2, 3, 1) inputs_transposed = list(data) for i in range(len(data)): if hasattr(data[i], "shape") and len(data[i].shape) == 4: inputs_transposed[i] = np.transpose(data[i], dim_order) return tuple(inputs_transposed) def _compare_outputs( self, reference_output, stage_output, quantization_scale=None, atol=1e-03, rtol=1e-03, qtol=0, ): try: super()._compare_outputs( reference_output, stage_output, quantization_scale, atol, rtol, qtol ) except AssertionError as e: # Capture assertion error and print more info banner = "=" * 40 + "TOSA debug info" + "=" * 40 logger.error(banner) path_to_tosa_files = self.runner_util.intermediate_path export_stage = self.stages.get(self.stage_name(tester.Export), None) quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None) if export_stage is not None and quantize_stage is not None: output_node = _get_output_node(export_stage.artifact) qp_input = _get_input_quantization_params(export_stage.artifact) qp_output = _get_output_quantization_params( export_stage.artifact, output_node ) logger.error(f"{qp_input=}") logger.error(f"{qp_output=}") logger.error(f"{path_to_tosa_files=}") import os torch.save( stage_output, os.path.join(path_to_tosa_files, "torch_tosa_output.pt"), ) torch.save( reference_output, os.path.join(path_to_tosa_files, "torch_ref_output.pt"), ) logger.error(f"{atol=}, {rtol=}, {qtol=}") raise e def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]: """Counts the occurences of placeholder and call_function dtypes in a graph. The result is a tuple of Counters (placeholder_distribution, call_function_distribution) """ placeholder_dtypes = [] call_function_dtypes = [] for node in graph.nodes: if node.op == "placeholder": placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": if "val" in node.meta: dtype, _, _ = extract_tensor_meta(node.meta) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) def _get_operator_distribution(graph: Graph) -> dict[str, int]: """Counts the occurences of operator names in a graph. The result is a dict {'operator name':'number of nodes'} """ return Counter( [str(node.target) for node in list(graph.nodes) if node.op == "call_function"] ) def _format_export_graph_signature(signature: ExportGraphSignature) -> str: def specs_dict(specs: list[InputSpec | OutputSpec], title: str): _dict: dict[str, list] = {title: [], "arg": [], "kind": [], "target": []} for i, spec in enumerate(specs): _dict[title].append(i) _dict["arg"].append(spec.arg) _dict["kind"].append(spec.kind) _dict["target"].append(spec.target if spec.target else "-") return _dict input_dict = specs_dict(signature.input_specs, "Inputs") output_dict = specs_dict(signature.output_specs, "Outputs") return f"{_format_dict(input_dict)}\n{_format_dict(output_dict)}" def _get_tosa_operator_distribution( graph_module: torch.fx.GraphModule, print_table=False ) -> str: """Counts the occurences of operator names of all lowered modules containing a TOSA flatbuffer. The result is a string with the operator distribution or an error message. 
""" op_list = [] id = 0 while lowered_module := getattr(graph_module, f"lowered_module_{id}", None): for spec in lowered_module.compile_specs: if spec.key != "output_format": continue if spec.value == b"tosa": tosa_fb = lowered_module.processed_bytes tosa_json = dbg_tosa_fb_to_json(tosa_fb) for region in tosa_json["regions"]: for block in region["blocks"]: op_list.extend( [operator["op"] for operator in block["operators"]] ) break elif spec.value == b"vela": return "Can not get operator distribution for Vela command stream." else: return f"Unknown output format '{spec.value}'." id += 1 if id == 0: return "No delegate with name 'lowered_module_0 found in graph module." op_dist = dict(Counter(op_list)) op_dist = { "Operator": list(op_dist.keys()), "Count": [item[1] for item in op_dist.items()], } return "TOSA operators:\n" + _format_dict(dict(op_dist), print_table) def _dump_str(to_print: str, path_to_dump: Optional[str] = None): default_dump_path = get_option(arm_test_options.dump_path) if not path_to_dump and default_dump_path: path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log" if path_to_dump: with open(path_to_dump, "a") as fp: fp.write(to_print) else: logger.info(to_print) def _format_dict(to_print: dict, print_table: bool = True) -> str: if isinstance(list(to_print.items())[0], Iterable) and print_table: return tabulate( to_print, headers="keys", tablefmt="fancy_grid", maxcolwidths=35 ) else: return pformat(to_print, compact=True, indent=1)