1#!/usr/bin/python3 2 3# Copyright 2019, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Spec Visualizer 18 19Visualize python spec file for test generator. 20 21Modified from TFLite graph visualizer -- instead of flatbuffer, takes spec file as input. 22(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py) 23 24""" 25 26from __future__ import absolute_import 27from __future__ import division 28from __future__ import print_function 29import argparse 30import fnmatch 31import json 32import math 33import os 34import re 35import sys 36import traceback 37 38# Stuff from test generator 39import test_generator as tg 40from test_generator import ActivationConverter 41from test_generator import BoolScalar 42from test_generator import Configuration 43from test_generator import DataTypeConverter 44from test_generator import DataLayoutConverter 45from test_generator import Example 46from test_generator import Float32Scalar 47from test_generator import Float32Vector 48from test_generator import GetJointStr 49from test_generator import IgnoredOutput 50from test_generator import Input 51from test_generator import Int32Scalar 52from test_generator import Int32Vector 53from test_generator import Internal 54from test_generator import Model 55from test_generator import Operand 56from test_generator import Output 57from test_generator import Parameter 58from test_generator import ParameterAsInputConverter 59from test_generator import RelaxedModeConverter 60from test_generator import SmartOpen 61 62# A CSS description for making the visualizer 63_CSS = """ 64<html> 65<head> 66<style> 67body {font-family: sans-serif; background-color: #ffaa00;} 68table {background-color: #eeccaa;} 69th {background-color: black; color: white;} 70h1 { 71 background-color: ffaa00; 72 padding:5px; 73 color: black; 74} 75 76div { 77 border-radius: 5px; 78 background-color: #ffeecc; 79 padding:5px; 80 margin:5px; 81} 82 83.tooltip {color: blue;} 84.tooltip .tooltipcontent { 85 visibility: hidden; 86 color: black; 87 background-color: yellow; 88 padding: 5px; 89 border-radius: 4px; 90 position: absolute; 91 z-index: 1; 92} 93.tooltip:hover .tooltipcontent { 94 visibility: visible; 95} 96 97.edges line { 98 stroke: #333333; 99} 100 101.nodes text { 102 color: black; 103 pointer-events: none; 104 font-family: sans-serif; 105 font-size: 11px; 106} 107</style> 108 109<script src="https://d3js.org/d3.v4.min.js"></script> 110 111</head> 112<body> 113""" 114 115_D3_HTML_TEMPLATE = """ 116 <script> 117 // Build graph data 118 var graph = %s; 119 120 var svg = d3.select("#subgraph_%s"); 121 var width = svg.attr("width"); 122 var height = svg.attr("height"); 123 var color = d3.scaleOrdinal(d3.schemeCategory20); 124 125 var simulation = d3.forceSimulation() 126 .force("link", d3.forceLink().id(function(d) {return d.id;})) 127 .force("charge", d3.forceManyBody()) 128 .force("center", d3.forceCenter(0.5 * width, 0.5 * height)); 129 130 131 function buildGraph() { 132 var edge = svg.append("g").attr("class", "edges").selectAll("line") 133 .data(graph.edges).enter().append("line") 134 // Make the node group 135 var node = svg.selectAll(".nodes") 136 .data(graph.nodes) 137 .enter().append("g") 138 .attr("class", "nodes") 139 .call(d3.drag() 140 .on("start", function(d) { 141 if(!d3.event.active) simulation.alphaTarget(1.0).restart(); 142 d.fx = d.x;d.fy = d.y; 143 }) 144 .on("drag", function(d) { 145 d.fx = d3.event.x; d.fy = d3.event.y; 146 }) 147 .on("end", function(d) { 148 if (!d3.event.active) simulation.alphaTarget(0); 149 d.fx = d.fy = null; 150 })); 151 // Within the group, draw a circle for the node position and text 152 // on the side. 153 node.append("circle") 154 .attr("r", "5px") 155 .attr("fill", function(d) { return color(d.group); }) 156 node.append("text") 157 .attr("dx", 8).attr("dy", 5).text(function(d) { return d.name; }); 158 // Setup force parameters and update position callback 159 simulation.nodes(graph.nodes).on("tick", forceSimulationUpdated); 160 simulation.force("link").links(graph.edges); 161 162 function forceSimulationUpdated() { 163 // Update edges. 164 edge.attr("x1", function(d) {return d.source.x;}) 165 .attr("y1", function(d) {return d.source.y;}) 166 .attr("x2", function(d) {return d.target.x;}) 167 .attr("y2", function(d) {return d.target.y;}); 168 // Update node positions 169 node.attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; }); 170 } 171 } 172 buildGraph() 173</script> 174""" 175 176class OpCodeMapper(object): 177 """Maps an opcode index to an op name.""" 178 179 def __init__(self, data): 180 self.code_to_name = {} 181 for idx, d in enumerate(data["operator_codes"]): 182 self.code_to_name[idx] = d["builtin_code"] 183 184 def __call__(self, x): 185 if x not in self.code_to_name: 186 s = "<UNKNOWN>" 187 else: 188 s = self.code_to_name[x] 189 return "%s (opcode=%d)" % (s, x) 190 191 192class DataSizeMapper(object): 193 """For buffers, report the number of bytes.""" 194 195 def __call__(self, x): 196 if x is not None: 197 return "%d bytes" % len(x) 198 else: 199 return "--" 200 201 202class TensorMapper(object): 203 """Maps a list of tensor indices to a tooltip hoverable indicator of more.""" 204 205 def __init__(self, subgraph_data): 206 self.data = subgraph_data 207 208 def __call__(self, x): 209 html = "" 210 html += "<span class='tooltip'><span class='tooltipcontent'>" 211 for i in x: 212 tensor = self.data["operands"][i] 213 html += str(i) + " " 214 html += tensor["name"] + " " 215 html += str(tensor["type"]) + " " 216 html += (repr(tensor["dimensions"]) if "dimensions" in tensor else "[]") + "<br>" 217 html += "</span>" 218 html += repr(x) 219 html += "</span>" 220 return html 221 222def GenerateGraph(g): 223 """Produces the HTML required to have a d3 visualization of the dag.""" 224 225# def TensorName(idx): 226# return "t%d" % idx 227 228 def OpName(idx): 229 return "o%d" % idx 230 231 edges = [] 232 nodes = [] 233 first = {} 234 pixel_mult = 50 235 for op_index, op in enumerate(g["operations"]): 236 for tensor in op["inputs"]: 237 if tensor not in first: 238 first[str(tensor)] = ( 239 op_index * pixel_mult, 240 len(first) * pixel_mult - pixel_mult / 2) 241 edges.append({ 242 "source": str(tensor), 243 "target": OpName(op_index) 244 }) 245 for tensor in op["outputs"]: 246 edges.append({ 247 "target": str(tensor), 248 "source": OpName(op_index) 249 }) 250 nodes.append({ 251 "id": OpName(op_index), 252 "name": op["opcode"], 253 "group": 2, 254 "x": pixel_mult, 255 "y": op_index * pixel_mult 256 }) 257 for tensor_index, tensor in enumerate(g["operands"]): 258 initial_y = ( 259 first[tensor["name"]] if tensor["name"] in first else len(g["operations"])) 260 261 nodes.append({ 262 "id": tensor["name"], 263 "name": "%s (%d)" % (tensor["name"], tensor_index), 264 "group": 1, 265 "x": 2, 266 "y": initial_y 267 }) 268 graph_str = json.dumps({"nodes": nodes, "edges": edges}) 269 270 html = _D3_HTML_TEMPLATE % (graph_str, g["name"]) 271 return html 272 273def GenerateTableHtml(items, keys_to_print, display_index=True): 274 """Given a list of object values and keys to print, make an HTML table. 275 276 Args: 277 items: Items to print an array of dicts. 278 keys_to_print: (key, display_fn). `key` is a key in the object. i.e. 279 items[0][key] should exist. display_fn is the mapping function on display. 280 i.e. the displayed html cell will have the string returned by 281 `mapping_fn(items[0][key])`. 282 display_index: add a column which is the index of each row in `items`. 283 Returns: 284 An html table. 285 """ 286 html = "" 287 # Print the list of items 288 html += "<table><tr>\n" 289 html += "<tr>\n" 290 if display_index: 291 html += "<th>index</th>" 292 for h, mapper in keys_to_print: 293 html += "<th>%s</th>" % h 294 html += "</tr>\n" 295 for idx, tensor in enumerate(items): 296 html += "<tr>\n" 297 if display_index: 298 html += "<td>%d</td>" % idx 299 # print tensor.keys() 300 for h, mapper in keys_to_print: 301 val = tensor[h] if h in tensor else None 302 val = val if mapper is None else mapper(val) 303 html += "<td>%s</td>\n" % val 304 305 html += "</tr>\n" 306 html += "</table>\n" 307 return html 308 309 310def CreateHtmlFile(g, fd): 311 """Given a tflite model in `tflite_input` file, produce html description.""" 312 html = "" 313 314 # Subgraph local specs on what to display 315 html += "<div class='subgraph'>" 316 tensor_mapper = lambda l: ", ".join(str(op) for op in l) 317 op_keys_to_display = [("opcode", None), ("inputs", tensor_mapper), ("outputs", tensor_mapper)] 318 tensor_keys_to_display = [("name", None), ("type", None), ("dimensions", None), ("scale", None), 319 ("zero_point", None), ("lifetime", None)] 320 html += "<h2>%s</h2>\n" % g["name"] 321 322 # Configurations. 323 html += "<h3>Configurations</h3>\n" 324 html += GenerateTableHtml( 325 [g["options"]], [(k, None) for k in g["options"].keys()], display_index=False) 326 327 # Inputs and outputs. 328 html += "<h3>Inputs/Outputs</h3>\n" 329 html += GenerateTableHtml( 330 [{ 331 "inputs": g["inputs"], 332 "outputs": g["outputs"] 333 }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)], 334 display_index=False) 335 336 # Print the operands. 337 html += "<h3>Operands</h3>\n" 338 html += GenerateTableHtml(g["operands"], tensor_keys_to_display) 339 340 # Print the operations. 341 html += "<h3>Operations</h3>\n" 342 html += GenerateTableHtml(g["operations"], op_keys_to_display) 343 344 # Visual graph. 345 html += "<h3>Visual Graph</h3>\n" 346 html += "<svg id='subgraph_%s' width='%d' height='%d'></svg>\n"%( 347 g["name"], max(min(len(g["operations"])*100, 1600), 200), len(g["operations"])*100) 348 html += GenerateGraph(g) 349 html += "</div>" 350 351 fd.write(html) 352 353def InitializeHtml(fd): 354 html = "" 355 html += _CSS 356 html += "<h1>%s</h1>"%(tg.FileNames.specName) 357 fd.write(html) 358 359def FinalizeHtml(fd): 360 fd.write("</body></html>\n") 361 362def VisualizeModel(example, fd): 363 if varName is not None and not fnmatch.fnmatch(str(example.testName), varName): 364 print(" Skip variation %s" % example.testName) 365 return 366 print(" Visualizing variation %s" % example.testName) 367 model = example.model 368 g = {} 369 g["options"] = {"relaxed": str(model.isRelaxed), "useSHM": str(tg.Configuration.useSHM())} 370 g["name"] = str(example.testName) 371 g["inputs"] = model.GetInputs() 372 g["outputs"] = model.GetOutputs() 373 g["operands"] = [{ 374 "name": str(op), "type": op.type.type, "dimensions": op.type.GetDimensionsString(), 375 "scale": op.type.scale, "zero_point": op.type.zeroPoint, "lifetime": op.lifetime 376 } for op in model.operands] 377 g["operations"] = [{ 378 "inputs": op.ins, "outputs": op.outs, "opcode": op.optype 379 } for op in model.operations] 380 CreateHtmlFile(g, fd) 381 382# Take a model from command line 383def ParseCmdLine(): 384 parser = argparse.ArgumentParser() 385 parser.add_argument("spec", help="the spec file") 386 parser.add_argument( 387 "-v", "--variation", help="the target variation name/pattern", default=None) 388 parser.add_argument( 389 "-o", "--out", help="the output html path", default="out.html") 390 args = parser.parse_args() 391 tg.FileNames.InitializeFileLists( 392 args.spec, "-", "-", "-", "-", "-") 393 tg.FileNames.NextFile() 394 return os.path.abspath(args.spec), args.variation, os.path.abspath(args.out) 395 396if __name__ == '__main__': 397 specFile, varName, outFile = ParseCmdLine() 398 print("Visualizing from spec: %s" % specFile) 399 exec(open(specFile, "r").read()) 400 with SmartOpen(outFile) as fd: 401 InitializeHtml(fd) 402 Example.DumpAllExamples( 403 DumpModel=None, model_fd=None, 404 DumpExample=VisualizeModel, example_fd=fd, 405 DumpTest=None, test_fd=None) 406 FinalizeHtml(fd) 407 print("Output HTML file: %s" % outFile) 408 409