#!/usr/bin/python3 # Copyright 2019, The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Spec Visualizer Visualize python spec file for test generator. Invoked by ml/nn/runtime/test/specs/visualize_spec.sh; See that script for details on how this script is used. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import json import os import sys from string import Template # Stuff from test generator import test_generator as tg from test_generator import ActivationConverter from test_generator import BoolScalar from test_generator import Configuration from test_generator import DataTypeConverter from test_generator import DataLayoutConverter from test_generator import Example from test_generator import Float16Scalar from test_generator import Float32Scalar from test_generator import Float32Vector from test_generator import GetJointStr from test_generator import IgnoredOutput from test_generator import Input from test_generator import Int32Scalar from test_generator import Int32Vector from test_generator import Internal from test_generator import Model from test_generator import Operand from test_generator import Output from test_generator import Parameter from test_generator import RelaxedModeConverter from test_generator import SymmPerChannelQuantParams TEMPLATE_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "spec_viz_template.html") global_graphs = dict() 
def FormatArray(data, is_scalar=False):
    """Format a sequence of values as display text.

    Args:
        data: A sequence of values. When is_scalar is True it must contain
            exactly one element.
        is_scalar: If True, return that single element on its own instead of
            a bracketed list.

    Returns:
        str(data[0]) for a scalar, otherwise "[v0, v1, ...]".
    """
    if is_scalar:
        assert len(data) == 1
        return str(data[0])
    return "[%s]" % ", ".join(str(i) for i in data)


def FormatDict(data):
    """Render a dict as "Key: value" entries separated by HTML line breaks.

    Keys are passed through str.capitalize(); entries keep the dict's
    iteration order. FormatOperand uses the result as tooltip content.
    """
    # Entries are joined with <br/> rather than "\n" because the string is
    # embedded in the generated HTML page, where a bare newline does not break
    # the line. NOTE(review): assumes the page inserts this text as HTML —
    # confirm against spec_viz_template.html.
    return "<br/>".join("%s: %s" % (k.capitalize(), v) for k, v in data.items())


def GetOperandInfo(op):
    """Collect an operand's type information as a dict of display fields.

    Always contains "lifetime" and "type". Adds "dimensions" for non-scalar
    operands, and "scale"/"zero point" when the type's scale is non-zero. For
    TENSOR_QUANT8_SYMM_PER_CHANNEL, "scale" is replaced by the per-channel
    scales and "channel dim" is added.
    """
    op_info = {"lifetime": op.lifetime, "type": op.type.type}
    if not op.type.IsScalar():
        op_info["dimensions"] = FormatArray(op.type.dimensions)
    if op.type.scale != 0:
        op_info["scale"] = op.type.scale
        op_info["zero point"] = op.type.zeroPoint
    if op.type.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        op_info["scale"] = FormatArray(op.type.extraParams.scales)
        op_info["channel dim"] = op.type.extraParams.channelDim
    return op_info


def FormatOperand(op):
    """Return an HTML snippet for an operand: a linked name plus a hover tooltip.

    The name links to "#details-operands-<model_index>". The tooltip lists the
    fields from GetOperandInfo and, for Parameters with at most 10 values, the
    data itself.
    """
    # All keys and values in op_info will appear in the tooltip. We only display
    # the operand data if it has at most 10 elements. This should be convenient
    # enough for most parameters.
    op_info = GetOperandInfo(op)
    if isinstance(op, Parameter) and len(op.value) <= 10:
        op_info["data"] = FormatArray(op.value, op.type.IsScalar())
    # Every keyword passed to format() below has a placeholder here; previously
    # the markup was missing and inpage_link was silently dropped.
    # NOTE(review): the class names must match the tooltip CSS in
    # spec_viz_template.html — confirm.
    template = ("<span class='tooltip'>"
                "<span class='tooltipcontent'>{tooltip_content}</span>"
                "<a href=\"{inpage_link}\">{op_name}</a>"
                "</span>")
    return template.format(
        op_name=str(op),
        tooltip_content=FormatDict(op_info),
        inpage_link="#details-operands-%d" % (op.model_index),
    )


def GetSubgraph(example):
    """Produce the nodes and edges information for d3 visualization.

    Operations and operands both become nodes; each operation gets an edge
    from every input operand and an edge to every output operand.

    Args:
        example: A test_generator Example. Only example.model.operations and
            example.model.operands are read.

    Returns:
        {"nodes": [...], "edges": [...]}. Operation nodes (group 2) have id
        "operation<i>"; operand nodes (group 1) use str(operand) as id. Every
        node carries precomputed "x"/"y" pixel coordinates.
    """
    visited = set()
    topological_order = []

    def AddToTopologicalOrder(node):
        if node not in visited:
            visited.add(node)
            topological_order.append(node)

    # Get the topological order, treating operands and operations the same.
    # Given that example.model.operations is already topologically sorted, it
    # suffices to walk it, inserting each operation's inputs before it and its
    # outputs after it.
    for op in example.model.operations:
        for i in op.ins:
            AddToTopologicalOrder(i)
        AddToTopologicalOrder(op)
        for o in op.outs:
            AddToTopologicalOrder(o)

    # Assign layers. Forward pass: each node goes one layer after its latest
    # input. Backward pass: each node moves to the layer just before its
    # earliest consumer (nodes without consumers keep their layer).
    layers = {}
    for node in topological_order:
        layers[node] = max([layers[i] for i in node.ins], default=-1) + 1
    for node in reversed(topological_order):
        layers[node] = min([layers[o] for o in node.outs], default=layers[node] + 1) - 1
    num_layers = max(layers.values()) + 1

    # Assign coordinates. Each layer is a row (100px apart); nodes within a
    # layer are spaced 200px apart in the order they were visited.
    def CoordX(index):
        return (index + 0.5) * 200

    def CoordY(index):
        return (index + 0.5) * 100

    coords = {}
    layer_cnt = [0] * num_layers
    for node in topological_order:
        coords[node] = (CoordX(layer_cnt[layers[node]]), CoordY(layers[node]))
        layer_cnt[layers[node]] += 1

    # Create edges and nodes dictionaries for d3 visualization.
    def OpName(idx):
        return "operation%d" % idx

    edges = []
    nodes = []
    for ind, op in enumerate(example.model.operations):
        for tensor in op.ins:
            edges.append({"source": str(tensor), "target": OpName(ind)})
        for tensor in op.outs:
            edges.append({"target": str(tensor), "source": OpName(ind)})
        nodes.append({
            "index": ind,
            "id": OpName(ind),
            "name": op.optype,
            "group": 2,
            "x": coords[op][0],
            "y": coords[op][1],
        })
    for ind, op in enumerate(example.model.operands):
        nodes.append({
            "index": ind,
            "id": str(op),
            "name": str(op),
            "group": 1,
            "x": coords[op][0],
            "y": coords[op][1],
        })
    return {"nodes": nodes, "edges": edges}


# The following Get**Info methods will each return a list of dictionaries,
# whose content will appear in the tables and sidebar views.

def GetConfigurationsInfo(example):
    """Return the example's run configuration as a one-row table."""
    return [{
        "relaxed": str(example.model.isRelaxed),
        "use shared memory": str(tg.Configuration.useSHM()),
        "expect failure": str(example.expectFailure),
    }]


def GetOperandsInfo(example):
    """Return one row per operand: index, name, type info, and data if known.

    Data is included for Parameter, Input and Output operands.
    """
    ret = []
    for index, op in enumerate(example.model.operands):
        info = {"index": index, "name": str(op), "group": "operand"}
        info.update(GetOperandInfo(op))
        if isinstance(op, (Parameter, Input, Output)):
            info["data"] = FormatArray(op.value, op.type.IsScalar())
        ret.append(info)
    return ret


def GetOperationsInfo(example):
    """Return one row per operation, with inputs/outputs as operand snippets."""
    return [{
        "index": index,
        "name": op.optype,
        "group": "operation",
        "opcode": op.optype,
        "inputs": ", ".join(FormatOperand(i) for i in op.ins),
        "outputs": ", ".join(FormatOperand(o) for o in op.outs),
    } for index, op in enumerate(example.model.operations)]


# TODO: Remove the unused fd from the parameter.
def ProcessExample(example, fd):
    """Process an example and save the information into global_graphs.

    Passed to Example.DumpAllExamples as the DumpExample callback. The entry is
    keyed by str(example.testName).

    Args:
        example: The test_generator Example variation to visualize.
        fd: Unused. Presumably the example_fd forwarded by DumpAllExamples;
            kept so the callback signature still matches.
    """
    print(" Processing variation %s" % example.testName)
    # global_graphs is only mutated (never rebound), so no `global` is needed.
    global_graphs[str(example.testName)] = {
        "subgraph": GetSubgraph(example),
        "details": {
            "configurations": GetConfigurationsInfo(example),
            "operands": GetOperandsInfo(example),
            "operations": GetOperationsInfo(example),
        },
    }


def DumpHtml(spec_file, out_file):
    """Write the final HTML file by substituting into the template file.

    $spec_name becomes the spec file's basename and $graph_dump becomes the
    JSON encoding of global_graphs.

    Raises:
        KeyError / ValueError: from string.Template.substitute if the template
            has unknown or malformed placeholders. In that case out_file is
            left untouched.
    """
    with open(TEMPLATE_FILE, "r") as template_fd:
        html_template = template_fd.read()
    # Render fully before opening out_file: opening with "w" truncates it, so
    # a substitution error would otherwise leave an empty output behind.
    html = Template(html_template).substitute(
        spec_name=os.path.basename(spec_file),
        graph_dump=json.dumps(global_graphs),
    )
    with open(out_file, "w") as out_fd:
        out_fd.write(html)


def ParseCmdLine():
    """Parse the command line and prime test_generator's file list.

    Returns:
        (absolute path of the spec file, absolute path of the output HTML).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument("-o", "--out", help="the output html path", default="out.html")
    args = parser.parse_args()
    # NOTE(review): presumably registers the spec as the current file in
    # test_generator; the meaning of the "-" argument is defined in
    # test_generator.FileNames — confirm there.
    tg.FileNames.InitializeFileLists(args.spec, "-")
    tg.FileNames.NextFile()
    return os.path.abspath(args.spec), os.path.abspath(args.out)


if __name__ == '__main__':
    spec_file, out_file = ParseCmdLine()
    print("Visualizing from spec: %s" % spec_file)
    # The spec is Python code that builds models and examples using the names
    # imported from test_generator at the top of this file, so it is exec'd in
    # this module's global namespace.
    # NOTE: exec runs arbitrary code; only visualize trusted spec files.
    with open(spec_file, "r") as spec_fd:
        exec(spec_fd.read())
    Example.DumpAllExamples(DumpExample=ProcessExample, example_fd=0)
    DumpHtml(spec_file, out_file)
    print("Output HTML file: %s" % out_file)