1#!/usr/bin/env python 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""This tool creates an html visualization of a TensorFlow Lite graph. 17 18Example usage: 19 20python visualize.py foo.tflite foo.html 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import json 28import os 29import re 30import sys 31import numpy as np 32 33from tensorflow.lite.python import schema_py_generated as schema_fb 34 35# A CSS description for making the visualizer 36_CSS = """ 37<html> 38<head> 39<style> 40body {font-family: sans-serif; background-color: #fa0;} 41table {background-color: #eca;} 42th {background-color: black; color: white;} 43h1 { 44 background-color: ffaa00; 45 padding:5px; 46 color: black; 47} 48 49svg { 50 margin: 10px; 51 border: 2px; 52 border-style: solid; 53 border-color: black; 54 background: white; 55} 56 57div { 58 border-radius: 5px; 59 background-color: #fec; 60 padding:5px; 61 margin:5px; 62} 63 64.tooltip {color: blue;} 65.tooltip .tooltipcontent { 66 visibility: hidden; 67 color: black; 68 background-color: yellow; 69 padding: 5px; 70 border-radius: 4px; 71 position: absolute; 72 z-index: 1; 73} 74.tooltip:hover .tooltipcontent { 75 visibility: visible; 76} 77 78.edges line { 79 stroke: #333; 80} 81 82text { 83 font-weight: bold; 84} 85 86.nodes text { 87 color: black; 88 pointer-events: none; 89 font-family: sans-serif; 90 font-size: 11px; 91} 92</style> 93 94<script src="https://d3js.org/d3.v4.min.js"></script> 95 96</head> 97<body> 98""" 99 100_D3_HTML_TEMPLATE = """ 101 <script> 102 function buildGraph() { 103 // Build graph data 104 var graph = %s; 105 106 var svg = d3.select("#subgraph%d") 107 var width = svg.attr("width"); 108 var height = svg.attr("height"); 109 // Make the graph scrollable. 110 svg = svg.call(d3.zoom().on("zoom", function() { 111 svg.attr("transform", d3.event.transform); 112 })).append("g"); 113 114 115 var color = d3.scaleOrdinal(d3.schemeDark2); 116 117 var simulation = d3.forceSimulation() 118 .force("link", d3.forceLink().id(function(d) {return d.id;})) 119 .force("charge", d3.forceManyBody()) 120 .force("center", d3.forceCenter(0.5 * width, 0.5 * height)); 121 122 var edge = svg.append("g").attr("class", "edges").selectAll("line") 123 .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none") 124 125 // Make the node group 126 var node = svg.selectAll(".nodes") 127 .data(graph.nodes) 128 .enter().append("g") 129 .attr("x", function(d){return d.x}) 130 .attr("y", function(d){return d.y}) 131 .attr("transform", function(d) { 132 return "translate( " + d.x + ", " + d.y + ")" 133 }) 134 .attr("class", "nodes") 135 .call(d3.drag() 136 .on("start", function(d) { 137 if(!d3.event.active) simulation.alphaTarget(1.0).restart(); 138 d.fx = d.x;d.fy = d.y; 139 }) 140 .on("drag", function(d) { 141 d.fx = d3.event.x; d.fy = d3.event.y; 142 }) 143 .on("end", function(d) { 144 if (!d3.event.active) simulation.alphaTarget(0); 145 d.fx = d.fy = null; 146 })); 147 // Within the group, draw a box for the node position and text 148 // on the side. 149 150 var node_width = 150; 151 var node_height = 30; 152 153 node.append("rect") 154 .attr("r", "5px") 155 .attr("width", node_width) 156 .attr("height", node_height) 157 .attr("rx", function(d) { return d.group == 1 ? 1 : 10; }) 158 .attr("stroke", "#000000") 159 .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; }) 160 node.append("text") 161 .text(function(d) { return d.name; }) 162 .attr("x", 5) 163 .attr("y", 20) 164 .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; }) 165 // Setup force parameters and update position callback 166 167 168 var node = svg.selectAll(".nodes") 169 .data(graph.nodes); 170 171 // Bind the links 172 var name_to_g = {} 173 node.each(function(data, index, nodes) { 174 console.log(data.id) 175 name_to_g[data.id] = this; 176 }); 177 178 function proc(w, t) { 179 return parseInt(w.getAttribute(t)); 180 } 181 edge.attr("d", function(d) { 182 function lerp(t, a, b) { 183 return (1.0-t) * a + t * b; 184 } 185 var x1 = proc(name_to_g[d.source],"x") + node_width /2; 186 var y1 = proc(name_to_g[d.source],"y") + node_height; 187 var x2 = proc(name_to_g[d.target],"x") + node_width /2; 188 var y2 = proc(name_to_g[d.target],"y"); 189 var s = "M " + x1 + " " + y1 190 + " C " + x1 + " " + lerp(.5, y1, y2) 191 + " " + x2 + " " + lerp(.5, y1, y2) 192 + " " + x2 + " " + y2 193 return s; 194 }); 195 196 } 197 buildGraph() 198</script> 199""" 200 201 202def TensorTypeToName(tensor_type): 203 """Converts a numerical enum to a readable tensor type.""" 204 for name, value in schema_fb.TensorType.__dict__.items(): 205 if value == tensor_type: 206 return name 207 return None 208 209 210def BuiltinCodeToName(code): 211 """Converts a builtin op code enum to a readable name.""" 212 for name, value in schema_fb.BuiltinOperator.__dict__.items(): 213 if value == code: 214 return name 215 return None 216 217 218def NameListToString(name_list): 219 """Converts a list of integers to the equivalent ASCII string.""" 220 if isinstance(name_list, str): 221 return name_list 222 else: 223 result = "" 224 if name_list is not None: 225 for val in name_list: 226 result = result + chr(int(val)) 227 return result 228 229 230class OpCodeMapper(object): 231 """Maps an opcode index to an op name.""" 232 233 def __init__(self, data): 234 self.code_to_name = {} 235 for idx, d in enumerate(data["operator_codes"]): 236 self.code_to_name[idx] = BuiltinCodeToName(d["builtin_code"]) 237 if self.code_to_name[idx] == "CUSTOM": 238 self.code_to_name[idx] = NameListToString(d["custom_code"]) 239 240 def __call__(self, x): 241 if x not in self.code_to_name: 242 s = "<UNKNOWN>" 243 else: 244 s = self.code_to_name[x] 245 return "%s (%d)" % (s, x) 246 247 248class DataSizeMapper(object): 249 """For buffers, report the number of bytes.""" 250 251 def __call__(self, x): 252 if x is not None: 253 return "%d bytes" % len(x) 254 else: 255 return "--" 256 257 258class TensorMapper(object): 259 """Maps a list of tensor indices to a tooltip hoverable indicator of more.""" 260 261 def __init__(self, subgraph_data): 262 self.data = subgraph_data 263 264 def __call__(self, x): 265 html = "" 266 html += "<span class='tooltip'><span class='tooltipcontent'>" 267 for i in x: 268 tensor = self.data["tensors"][i] 269 html += str(i) + " " 270 html += NameListToString(tensor["name"]) + " " 271 html += TensorTypeToName(tensor["type"]) + " " 272 html += (repr(tensor["shape"]) if "shape" in tensor else "[]") 273 html += (repr(tensor["shape_signature"]) 274 if "shape_signature" in tensor else "[]") + "<br>" 275 html += "</span>" 276 html += repr(x) 277 html += "</span>" 278 return html 279 280 281def GenerateGraph(subgraph_idx, g, opcode_mapper): 282 """Produces the HTML required to have a d3 visualization of the dag.""" 283 284 def TensorName(idx): 285 return "t%d" % idx 286 287 def OpName(idx): 288 return "o%d" % idx 289 290 edges = [] 291 nodes = [] 292 first = {} 293 second = {} 294 pixel_mult = 200 # TODO(aselle): multiplier for initial placement 295 width_mult = 170 # TODO(aselle): multiplier for initial placement 296 for op_index, op in enumerate(g["operators"]): 297 298 for tensor_input_position, tensor_index in enumerate(op["inputs"]): 299 if tensor_index not in first: 300 first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult, 301 (tensor_input_position + 1) * width_mult) 302 edges.append({ 303 "source": TensorName(tensor_index), 304 "target": OpName(op_index) 305 }) 306 for tensor_output_position, tensor_index in enumerate(op["outputs"]): 307 if tensor_index not in second: 308 second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult, 309 (tensor_output_position + 1) * width_mult) 310 edges.append({ 311 "target": TensorName(tensor_index), 312 "source": OpName(op_index) 313 }) 314 315 nodes.append({ 316 "id": OpName(op_index), 317 "name": opcode_mapper(op["opcode_index"]), 318 "group": 2, 319 "x": pixel_mult, 320 "y": (op_index + 1) * pixel_mult 321 }) 322 for tensor_index, tensor in enumerate(g["tensors"]): 323 initial_y = ( 324 first[tensor_index] if tensor_index in first else 325 second[tensor_index] if tensor_index in second else (0, 0)) 326 327 nodes.append({ 328 "id": TensorName(tensor_index), 329 "name": "%r (%d)" % (getattr(tensor, "shape", []), tensor_index), 330 "group": 1, 331 "x": initial_y[1], 332 "y": initial_y[0] 333 }) 334 graph_str = json.dumps({"nodes": nodes, "edges": edges}) 335 336 html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx) 337 return html 338 339 340def GenerateTableHtml(items, keys_to_print, display_index=True): 341 """Given a list of object values and keys to print, make an HTML table. 342 343 Args: 344 items: Items to print an array of dicts. 345 keys_to_print: (key, display_fn). `key` is a key in the object. i.e. 346 items[0][key] should exist. display_fn is the mapping function on display. 347 i.e. the displayed html cell will have the string returned by 348 `mapping_fn(items[0][key])`. 349 display_index: add a column which is the index of each row in `items`. 350 351 Returns: 352 An html table. 353 """ 354 html = "" 355 # Print the list of items 356 html += "<table><tr>\n" 357 html += "<tr>\n" 358 if display_index: 359 html += "<th>index</th>" 360 for h, mapper in keys_to_print: 361 html += "<th>%s</th>" % h 362 html += "</tr>\n" 363 for idx, tensor in enumerate(items): 364 html += "<tr>\n" 365 if display_index: 366 html += "<td>%d</td>" % idx 367 # print tensor.keys() 368 for h, mapper in keys_to_print: 369 val = tensor[h] if h in tensor else None 370 val = val if mapper is None else mapper(val) 371 html += "<td>%s</td>\n" % val 372 373 html += "</tr>\n" 374 html += "</table>\n" 375 return html 376 377 378def CamelCaseToSnakeCase(camel_case_input): 379 """Converts an identifier in CamelCase to snake_case.""" 380 s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input) 381 return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 382 383 384def FlatbufferToDict(fb, preserve_as_numpy): 385 """Converts a hierarchy of FB objects into a nested dict. 386 387 We avoid transforming big parts of the flat buffer into python arrays. This 388 speeds conversion from ten minutes to a few seconds on big graphs. 389 390 Args: 391 fb: a flat buffer structure. (i.e. ModelT) 392 preserve_as_numpy: true if all downstream np.arrays should be preserved. 393 false if all downstream np.array should become python arrays 394 Returns: 395 A dictionary representing the flatbuffer rather than a flatbuffer object. 396 """ 397 if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str): 398 return fb 399 elif hasattr(fb, "__dict__"): 400 result = {} 401 for attribute_name in dir(fb): 402 attribute = fb.__getattribute__(attribute_name) 403 if not callable(attribute) and attribute_name[0] != "_": 404 snake_name = CamelCaseToSnakeCase(attribute_name) 405 preserve = True if attribute_name == "buffers" else preserve_as_numpy 406 result[snake_name] = FlatbufferToDict(attribute, preserve) 407 return result 408 elif isinstance(fb, np.ndarray): 409 return fb if preserve_as_numpy else fb.tolist() 410 elif hasattr(fb, "__len__"): 411 return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb] 412 else: 413 return fb 414 415 416def CreateDictFromFlatbuffer(buffer_data): 417 model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0) 418 model = schema_fb.ModelT.InitFromObj(model_obj) 419 return FlatbufferToDict(model, preserve_as_numpy=False) 420 421 422def CreateHtmlFile(tflite_input, html_output): 423 """Given a tflite model in `tflite_input` file, produce html description.""" 424 425 # Convert the model into a JSON flatbuffer using flatc (build if doesn't 426 # exist. 427 if not os.path.exists(tflite_input): 428 raise RuntimeError("Invalid filename %r" % tflite_input) 429 if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"): 430 with open(tflite_input, "rb") as file_handle: 431 file_data = bytearray(file_handle.read()) 432 data = CreateDictFromFlatbuffer(file_data) 433 elif tflite_input.endswith(".json"): 434 data = json.load(open(tflite_input)) 435 else: 436 raise RuntimeError("Input file was not .tflite or .json") 437 html = "" 438 html += _CSS 439 html += "<h1>TensorFlow Lite Model</h2>" 440 441 data["filename"] = tflite_input # Avoid special case 442 toplevel_stuff = [("filename", None), ("version", None), 443 ("description", None)] 444 445 html += "<table>\n" 446 for key, mapping in toplevel_stuff: 447 if not mapping: 448 mapping = lambda x: x 449 html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data.get(key))) 450 html += "</table>\n" 451 452 # Spec on what keys to display 453 buffer_keys_to_display = [("data", DataSizeMapper())] 454 operator_keys_to_display = [("builtin_code", BuiltinCodeToName), 455 ("custom_code", NameListToString), 456 ("version", None)] 457 458 # Update builtin code fields. 459 for idx, d in enumerate(data["operator_codes"]): 460 d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"]) 461 462 for subgraph_idx, g in enumerate(data["subgraphs"]): 463 # Subgraph local specs on what to display 464 html += "<div class='subgraph'>" 465 tensor_mapper = TensorMapper(g) 466 opcode_mapper = OpCodeMapper(data) 467 op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper), 468 ("builtin_options", None), 469 ("opcode_index", opcode_mapper)] 470 tensor_keys_to_display = [("name", NameListToString), 471 ("type", TensorTypeToName), ("shape", None), 472 ("shape_signature", None), ("buffer", None), 473 ("quantization", None)] 474 475 html += "<h2>Subgraph %d</h2>\n" % subgraph_idx 476 477 # Inputs and outputs. 478 html += "<h3>Inputs/Outputs</h3>\n" 479 html += GenerateTableHtml([{ 480 "inputs": g["inputs"], 481 "outputs": g["outputs"] 482 }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)], 483 display_index=False) 484 485 # Print the tensors. 486 html += "<h3>Tensors</h3>\n" 487 html += GenerateTableHtml(g["tensors"], tensor_keys_to_display) 488 489 # Print the ops. 490 html += "<h3>Ops</h3>\n" 491 html += GenerateTableHtml(g["operators"], op_keys_to_display) 492 493 # Visual graph. 494 html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % ( 495 subgraph_idx,) 496 html += GenerateGraph(subgraph_idx, g, opcode_mapper) 497 html += "</div>" 498 499 # Buffers have no data, but maybe in the future they will 500 html += "<h2>Buffers</h2>\n" 501 html += GenerateTableHtml(data["buffers"], buffer_keys_to_display) 502 503 # Operator codes 504 html += "<h2>Operator Codes</h2>\n" 505 html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display) 506 507 html += "</body></html>\n" 508 509 with open(html_output, "w") as output_file: 510 output_file.write(html) 511 512 513def main(argv): 514 try: 515 tflite_input = argv[1] 516 html_output = argv[2] 517 except IndexError: 518 print("Usage: %s <input tflite> <output html>" % (argv[0])) 519 else: 520 CreateHtmlFile(tflite_input, html_output) 521 522 523if __name__ == "__main__": 524 main(sys.argv) 525