1#!/usr/bin/env python 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""This tool creates an html visualization of a TensorFlow Lite graph. 17 18Example usage: 19 20python visualize.py foo.tflite foo.html 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import json 28import os 29import re 30import sys 31import numpy as np 32 33# pylint: disable=g-import-not-at-top 34if not os.path.splitext(__file__)[0].endswith( 35 os.path.join("tflite_runtime", "visualize")): 36 # This file is part of tensorflow package. 37 from tensorflow.lite.python import schema_py_generated as schema_fb 38else: 39 # This file is part of tflite_runtime package. 40 from tflite_runtime import schema_py_generated as schema_fb 41 42# A CSS description for making the visualizer 43_CSS = """ 44<html> 45<head> 46<style> 47body {font-family: sans-serif; background-color: #fa0;} 48table {background-color: #eca;} 49th {background-color: black; color: white;} 50h1 { 51 background-color: ffaa00; 52 padding:5px; 53 color: black; 54} 55 56svg { 57 margin: 10px; 58 border: 2px; 59 border-style: solid; 60 border-color: black; 61 background: white; 62} 63 64div { 65 border-radius: 5px; 66 background-color: #fec; 67 padding:5px; 68 margin:5px; 69} 70 71.tooltip {color: blue;} 72.tooltip .tooltipcontent { 73 visibility: hidden; 74 color: black; 75 background-color: yellow; 76 padding: 5px; 77 border-radius: 4px; 78 position: absolute; 79 z-index: 1; 80} 81.tooltip:hover .tooltipcontent { 82 visibility: visible; 83} 84 85.edges line { 86 stroke: #333; 87} 88 89text { 90 font-weight: bold; 91} 92 93.nodes text { 94 color: black; 95 pointer-events: none; 96 font-family: sans-serif; 97 font-size: 11px; 98} 99</style> 100 101<script src="https://d3js.org/d3.v4.min.js"></script> 102 103</head> 104<body> 105""" 106 107_D3_HTML_TEMPLATE = """ 108 <script> 109 function buildGraph() { 110 // Build graph data 111 var graph = %s; 112 113 var svg = d3.select("#subgraph%d") 114 var width = svg.attr("width"); 115 var height = svg.attr("height"); 116 // Make the graph scrollable. 117 svg = svg.call(d3.zoom().on("zoom", function() { 118 svg.attr("transform", d3.event.transform); 119 })).append("g"); 120 121 122 var color = d3.scaleOrdinal(d3.schemeDark2); 123 124 var simulation = d3.forceSimulation() 125 .force("link", d3.forceLink().id(function(d) {return d.id;})) 126 .force("charge", d3.forceManyBody()) 127 .force("center", d3.forceCenter(0.5 * width, 0.5 * height)); 128 129 var edge = svg.append("g").attr("class", "edges").selectAll("line") 130 .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none") 131 132 // Make the node group 133 var node = svg.selectAll(".nodes") 134 .data(graph.nodes) 135 .enter().append("g") 136 .attr("x", function(d){return d.x}) 137 .attr("y", function(d){return d.y}) 138 .attr("transform", function(d) { 139 return "translate( " + d.x + ", " + d.y + ")" 140 }) 141 .attr("class", "nodes") 142 .call(d3.drag() 143 .on("start", function(d) { 144 if(!d3.event.active) simulation.alphaTarget(1.0).restart(); 145 d.fx = d.x;d.fy = d.y; 146 }) 147 .on("drag", function(d) { 148 d.fx = d3.event.x; d.fy = d3.event.y; 149 }) 150 .on("end", function(d) { 151 if (!d3.event.active) simulation.alphaTarget(0); 152 d.fx = d.fy = null; 153 })); 154 // Within the group, draw a box for the node position and text 155 // on the side. 156 157 var node_width = 150; 158 var node_height = 30; 159 160 node.append("rect") 161 .attr("r", "5px") 162 .attr("width", node_width) 163 .attr("height", node_height) 164 .attr("rx", function(d) { return d.group == 1 ? 1 : 10; }) 165 .attr("stroke", "#000000") 166 .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; }) 167 node.append("text") 168 .text(function(d) { return d.name; }) 169 .attr("x", 5) 170 .attr("y", 20) 171 .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; }) 172 // Setup force parameters and update position callback 173 174 175 var node = svg.selectAll(".nodes") 176 .data(graph.nodes); 177 178 // Bind the links 179 var name_to_g = {} 180 node.each(function(data, index, nodes) { 181 console.log(data.id) 182 name_to_g[data.id] = this; 183 }); 184 185 function proc(w, t) { 186 return parseInt(w.getAttribute(t)); 187 } 188 edge.attr("d", function(d) { 189 function lerp(t, a, b) { 190 return (1.0-t) * a + t * b; 191 } 192 var x1 = proc(name_to_g[d.source],"x") + node_width /2; 193 var y1 = proc(name_to_g[d.source],"y") + node_height; 194 var x2 = proc(name_to_g[d.target],"x") + node_width /2; 195 var y2 = proc(name_to_g[d.target],"y"); 196 var s = "M " + x1 + " " + y1 197 + " C " + x1 + " " + lerp(.5, y1, y2) 198 + " " + x2 + " " + lerp(.5, y1, y2) 199 + " " + x2 + " " + y2 200 return s; 201 }); 202 203 } 204 buildGraph() 205</script> 206""" 207 208 209def TensorTypeToName(tensor_type): 210 """Converts a numerical enum to a readable tensor type.""" 211 for name, value in schema_fb.TensorType.__dict__.items(): 212 if value == tensor_type: 213 return name 214 return None 215 216 217def BuiltinCodeToName(code): 218 """Converts a builtin op code enum to a readable name.""" 219 for name, value in schema_fb.BuiltinOperator.__dict__.items(): 220 if value == code: 221 return name 222 return None 223 224 225def NameListToString(name_list): 226 """Converts a list of integers to the equivalent ASCII string.""" 227 if isinstance(name_list, str): 228 return name_list 229 else: 230 result = "" 231 if name_list is not None: 232 for val in name_list: 233 result = result + chr(int(val)) 234 return result 235 236 237class OpCodeMapper(object): 238 """Maps an opcode index to an op name.""" 239 240 def __init__(self, data): 241 self.code_to_name = {} 242 for idx, d in enumerate(data["operator_codes"]): 243 self.code_to_name[idx] = BuiltinCodeToName(d["builtin_code"]) 244 if self.code_to_name[idx] == "CUSTOM": 245 self.code_to_name[idx] = NameListToString(d["custom_code"]) 246 247 def __call__(self, x): 248 if x not in self.code_to_name: 249 s = "<UNKNOWN>" 250 else: 251 s = self.code_to_name[x] 252 return "%s (%d)" % (s, x) 253 254 255class DataSizeMapper(object): 256 """For buffers, report the number of bytes.""" 257 258 def __call__(self, x): 259 if x is not None: 260 return "%d bytes" % len(x) 261 else: 262 return "--" 263 264 265class TensorMapper(object): 266 """Maps a list of tensor indices to a tooltip hoverable indicator of more.""" 267 268 def __init__(self, subgraph_data): 269 self.data = subgraph_data 270 271 def __call__(self, x): 272 html = "" 273 html += "<span class='tooltip'><span class='tooltipcontent'>" 274 for i in x: 275 tensor = self.data["tensors"][i] 276 html += str(i) + " " 277 html += NameListToString(tensor["name"]) + " " 278 html += TensorTypeToName(tensor["type"]) + " " 279 html += (repr(tensor["shape"]) if "shape" in tensor else "[]") 280 html += (repr(tensor["shape_signature"]) 281 if "shape_signature" in tensor else "[]") + "<br>" 282 html += "</span>" 283 html += repr(x) 284 html += "</span>" 285 return html 286 287 288def GenerateGraph(subgraph_idx, g, opcode_mapper): 289 """Produces the HTML required to have a d3 visualization of the dag.""" 290 291 def TensorName(idx): 292 return "t%d" % idx 293 294 def OpName(idx): 295 return "o%d" % idx 296 297 edges = [] 298 nodes = [] 299 first = {} 300 second = {} 301 pixel_mult = 200 # TODO(aselle): multiplier for initial placement 302 width_mult = 170 # TODO(aselle): multiplier for initial placement 303 for op_index, op in enumerate(g["operators"] or []): 304 305 for tensor_input_position, tensor_index in enumerate(op["inputs"]): 306 if tensor_index not in first: 307 first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult, 308 (tensor_input_position + 1) * width_mult) 309 edges.append({ 310 "source": TensorName(tensor_index), 311 "target": OpName(op_index) 312 }) 313 for tensor_output_position, tensor_index in enumerate(op["outputs"]): 314 if tensor_index not in second: 315 second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult, 316 (tensor_output_position + 1) * width_mult) 317 edges.append({ 318 "target": TensorName(tensor_index), 319 "source": OpName(op_index) 320 }) 321 322 nodes.append({ 323 "id": OpName(op_index), 324 "name": opcode_mapper(op["opcode_index"]), 325 "group": 2, 326 "x": pixel_mult, 327 "y": (op_index + 1) * pixel_mult 328 }) 329 for tensor_index, tensor in enumerate(g["tensors"]): 330 initial_y = ( 331 first[tensor_index] if tensor_index in first else 332 second[tensor_index] if tensor_index in second else (0, 0)) 333 334 nodes.append({ 335 "id": TensorName(tensor_index), 336 "name": "%r (%d)" % (getattr(tensor, "shape", []), tensor_index), 337 "group": 1, 338 "x": initial_y[1], 339 "y": initial_y[0] 340 }) 341 graph_str = json.dumps({"nodes": nodes, "edges": edges}) 342 343 html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx) 344 return html 345 346 347def GenerateTableHtml(items, keys_to_print, display_index=True): 348 """Given a list of object values and keys to print, make an HTML table. 349 350 Args: 351 items: Items to print an array of dicts. 352 keys_to_print: (key, display_fn). `key` is a key in the object. i.e. 353 items[0][key] should exist. display_fn is the mapping function on display. 354 i.e. the displayed html cell will have the string returned by 355 `mapping_fn(items[0][key])`. 356 display_index: add a column which is the index of each row in `items`. 357 358 Returns: 359 An html table. 360 """ 361 html = "" 362 # Print the list of items 363 html += "<table><tr>\n" 364 html += "<tr>\n" 365 if display_index: 366 html += "<th>index</th>" 367 for h, mapper in keys_to_print: 368 html += "<th>%s</th>" % h 369 html += "</tr>\n" 370 for idx, tensor in enumerate(items): 371 html += "<tr>\n" 372 if display_index: 373 html += "<td>%d</td>" % idx 374 # print tensor.keys() 375 for h, mapper in keys_to_print: 376 val = tensor[h] if h in tensor else None 377 val = val if mapper is None else mapper(val) 378 html += "<td>%s</td>\n" % val 379 380 html += "</tr>\n" 381 html += "</table>\n" 382 return html 383 384 385def CamelCaseToSnakeCase(camel_case_input): 386 """Converts an identifier in CamelCase to snake_case.""" 387 s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input) 388 return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 389 390 391def FlatbufferToDict(fb, preserve_as_numpy): 392 """Converts a hierarchy of FB objects into a nested dict. 393 394 We avoid transforming big parts of the flat buffer into python arrays. This 395 speeds conversion from ten minutes to a few seconds on big graphs. 396 397 Args: 398 fb: a flat buffer structure. (i.e. ModelT) 399 preserve_as_numpy: true if all downstream np.arrays should be preserved. 400 false if all downstream np.array should become python arrays 401 Returns: 402 A dictionary representing the flatbuffer rather than a flatbuffer object. 403 """ 404 if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str): 405 return fb 406 elif hasattr(fb, "__dict__"): 407 result = {} 408 for attribute_name in dir(fb): 409 attribute = fb.__getattribute__(attribute_name) 410 if not callable(attribute) and attribute_name[0] != "_": 411 snake_name = CamelCaseToSnakeCase(attribute_name) 412 preserve = True if attribute_name == "buffers" else preserve_as_numpy 413 result[snake_name] = FlatbufferToDict(attribute, preserve) 414 return result 415 elif isinstance(fb, np.ndarray): 416 return fb if preserve_as_numpy else fb.tolist() 417 elif hasattr(fb, "__len__"): 418 return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb] 419 else: 420 return fb 421 422 423def CreateDictFromFlatbuffer(buffer_data): 424 model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0) 425 model = schema_fb.ModelT.InitFromObj(model_obj) 426 return FlatbufferToDict(model, preserve_as_numpy=False) 427 428 429def create_html(tflite_input, input_is_filepath=True): # pylint: disable=invalid-name 430 """Returns html description with the given tflite model. 431 432 Args: 433 tflite_input: TFLite flatbuffer model path or model object. 434 input_is_filepath: Tells if tflite_input is a model path or a model object. 435 436 Returns: 437 Dump of the given tflite model in HTML format. 438 439 Raises: 440 RuntimeError: If the input is not valid. 441 """ 442 443 # Convert the model into a JSON flatbuffer using flatc (build if doesn't 444 # exist. 445 if input_is_filepath: 446 if not os.path.exists(tflite_input): 447 raise RuntimeError("Invalid filename %r" % tflite_input) 448 if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"): 449 with open(tflite_input, "rb") as file_handle: 450 file_data = bytearray(file_handle.read()) 451 data = CreateDictFromFlatbuffer(file_data) 452 elif tflite_input.endswith(".json"): 453 data = json.load(open(tflite_input)) 454 else: 455 raise RuntimeError("Input file was not .tflite or .json") 456 else: 457 data = CreateDictFromFlatbuffer(tflite_input) 458 html = "" 459 html += _CSS 460 html += "<h1>TensorFlow Lite Model</h2>" 461 462 data["filename"] = tflite_input # Avoid special case 463 toplevel_stuff = [("filename", None), ("version", None), 464 ("description", None)] 465 466 html += "<table>\n" 467 for key, mapping in toplevel_stuff: 468 if not mapping: 469 mapping = lambda x: x 470 html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data.get(key))) 471 html += "</table>\n" 472 473 # Spec on what keys to display 474 buffer_keys_to_display = [("data", DataSizeMapper())] 475 operator_keys_to_display = [("builtin_code", BuiltinCodeToName), 476 ("custom_code", NameListToString), 477 ("version", None)] 478 479 # Update builtin code fields. 480 for d in data["operator_codes"]: 481 d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"]) 482 483 for subgraph_idx, g in enumerate(data["subgraphs"]): 484 # Subgraph local specs on what to display 485 html += "<div class='subgraph'>" 486 tensor_mapper = TensorMapper(g) 487 opcode_mapper = OpCodeMapper(data) 488 op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper), 489 ("builtin_options", None), 490 ("opcode_index", opcode_mapper)] 491 tensor_keys_to_display = [("name", NameListToString), 492 ("type", TensorTypeToName), ("shape", None), 493 ("shape_signature", None), ("buffer", None), 494 ("quantization", None)] 495 496 html += "<h2>Subgraph %d</h2>\n" % subgraph_idx 497 498 # Inputs and outputs. 499 html += "<h3>Inputs/Outputs</h3>\n" 500 html += GenerateTableHtml([{ 501 "inputs": g["inputs"], 502 "outputs": g["outputs"] 503 }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)], 504 display_index=False) 505 506 # Print the tensors. 507 html += "<h3>Tensors</h3>\n" 508 html += GenerateTableHtml(g["tensors"], tensor_keys_to_display) 509 510 # Print the ops. 511 if g["operators"]: 512 html += "<h3>Ops</h3>\n" 513 html += GenerateTableHtml(g["operators"], op_keys_to_display) 514 515 # Visual graph. 516 html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % ( 517 subgraph_idx,) 518 html += GenerateGraph(subgraph_idx, g, opcode_mapper) 519 html += "</div>" 520 521 # Buffers have no data, but maybe in the future they will 522 html += "<h2>Buffers</h2>\n" 523 html += GenerateTableHtml(data["buffers"], buffer_keys_to_display) 524 525 # Operator codes 526 html += "<h2>Operator Codes</h2>\n" 527 html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display) 528 529 html += "</body></html>\n" 530 531 return html 532 533 534def main(argv): 535 try: 536 tflite_input = argv[1] 537 html_output = argv[2] 538 except IndexError: 539 print("Usage: %s <input tflite> <output html>" % (argv[0])) 540 else: 541 html = create_html(tflite_input) 542 with open(html_output, "w") as output_file: 543 output_file.write(html) 544 545 546if __name__ == "__main__": 547 main(sys.argv) 548