• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""This tool creates an html visualization of a TensorFlow Lite graph.
17
18Example usage:
19
20python visualize.py foo.tflite foo.html
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import json
28import os
29import re
30import sys
31import numpy as np
32
33from tensorflow.lite.python import schema_py_generated as schema_fb
34
35# A CSS description for making the visualizer
36_CSS = """
37<html>
38<head>
39<style>
40body {font-family: sans-serif; background-color: #fa0;}
41table {background-color: #eca;}
42th {background-color: black; color: white;}
43h1 {
44  background-color: ffaa00;
45  padding:5px;
46  color: black;
47}
48
49svg {
50  margin: 10px;
51  border: 2px;
52  border-style: solid;
53  border-color: black;
54  background: white;
55}
56
57div {
58  border-radius: 5px;
59  background-color: #fec;
60  padding:5px;
61  margin:5px;
62}
63
64.tooltip {color: blue;}
65.tooltip .tooltipcontent  {
66    visibility: hidden;
67    color: black;
68    background-color: yellow;
69    padding: 5px;
70    border-radius: 4px;
71    position: absolute;
72    z-index: 1;
73}
74.tooltip:hover .tooltipcontent {
75    visibility: visible;
76}
77
78.edges line {
79  stroke: #333;
80}
81
82text {
83  font-weight: bold;
84}
85
86.nodes text {
87  color: black;
88  pointer-events: none;
89  font-family: sans-serif;
90  font-size: 11px;
91}
92</style>
93
94<script src="https://d3js.org/d3.v4.min.js"></script>
95
96</head>
97<body>
98"""
99
100_D3_HTML_TEMPLATE = """
101  <script>
102    function buildGraph() {
103      // Build graph data
104      var graph = %s;
105
106      var svg = d3.select("#subgraph%d")
107      var width = svg.attr("width");
108      var height = svg.attr("height");
109      // Make the graph scrollable.
110      svg = svg.call(d3.zoom().on("zoom", function() {
111        svg.attr("transform", d3.event.transform);
112      })).append("g");
113
114
115      var color = d3.scaleOrdinal(d3.schemeDark2);
116
117      var simulation = d3.forceSimulation()
118          .force("link", d3.forceLink().id(function(d) {return d.id;}))
119          .force("charge", d3.forceManyBody())
120          .force("center", d3.forceCenter(0.5 * width, 0.5 * height));
121
122      var edge = svg.append("g").attr("class", "edges").selectAll("line")
123        .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
124
125      // Make the node group
126      var node = svg.selectAll(".nodes")
127        .data(graph.nodes)
128        .enter().append("g")
129        .attr("x", function(d){return d.x})
130        .attr("y", function(d){return d.y})
131        .attr("transform", function(d) {
132          return "translate( " + d.x + ", " + d.y + ")"
133        })
134        .attr("class", "nodes")
135          .call(d3.drag()
136              .on("start", function(d) {
137                if(!d3.event.active) simulation.alphaTarget(1.0).restart();
138                d.fx = d.x;d.fy = d.y;
139              })
140              .on("drag", function(d) {
141                d.fx = d3.event.x; d.fy = d3.event.y;
142              })
143              .on("end", function(d) {
144                if (!d3.event.active) simulation.alphaTarget(0);
145                d.fx = d.fy = null;
146              }));
147      // Within the group, draw a box for the node position and text
148      // on the side.
149
150      var node_width = 150;
151      var node_height = 30;
152
153      node.append("rect")
154          .attr("r", "5px")
155          .attr("width", node_width)
156          .attr("height", node_height)
157          .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
158          .attr("stroke", "#000000")
159          .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
160      node.append("text")
161          .text(function(d) { return d.name; })
162          .attr("x", 5)
163          .attr("y", 20)
164          .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
165      // Setup force parameters and update position callback
166
167
168      var node = svg.selectAll(".nodes")
169        .data(graph.nodes);
170
171      // Bind the links
172      var name_to_g = {}
173      node.each(function(data, index, nodes) {
174        console.log(data.id)
175        name_to_g[data.id] = this;
176      });
177
178      function proc(w, t) {
179        return parseInt(w.getAttribute(t));
180      }
181      edge.attr("d", function(d) {
182        function lerp(t, a, b) {
183          return (1.0-t) * a + t * b;
184        }
185        var x1 = proc(name_to_g[d.source],"x") + node_width /2;
186        var y1 = proc(name_to_g[d.source],"y") + node_height;
187        var x2 = proc(name_to_g[d.target],"x") + node_width /2;
188        var y2 = proc(name_to_g[d.target],"y");
189        var s = "M " + x1 + " " + y1
190            + " C " + x1 + " " + lerp(.5, y1, y2)
191            + " " + x2 + " " + lerp(.5, y1, y2)
192            + " " + x2  + " " + y2
193      return s;
194    });
195
196  }
197  buildGraph()
198</script>
199"""
200
201
def TensorTypeToName(tensor_type):
  """Returns the `schema_fb.TensorType` member name for a numeric type value.

  The first member (in class-dict order) whose value equals `tensor_type`
  wins; None is returned when nothing matches.
  """
  members = schema_fb.TensorType.__dict__.items()
  return next((name for name, value in members if value == tensor_type), None)
208
209
def BuiltinCodeToName(code):
  """Returns the `schema_fb.BuiltinOperator` member name for an op code.

  The first member (in class-dict order) whose value equals `code` wins;
  None is returned when nothing matches.
  """
  matching_names = (
      name for name, value in schema_fb.BuiltinOperator.__dict__.items()
      if value == code)
  return next(matching_names, None)
216
217
def NameListToString(name_list):
  """Converts a list of integers to the equivalent ASCII string.

  Strings read out of a flatbuffer may arrive as sequences of character
  codes, while JSON input already provides `str`.

  Args:
    name_list: a `str` (returned unchanged), an iterable of values accepted
      by `int()` that are used as character codes, or None.

  Returns:
    The decoded string; "" for None or an empty sequence.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  # join() instead of repeated `+=`, which is quadratic for long names.
  return "".join(chr(int(val)) for val in name_list)
228
229
class OpCodeMapper(object):
  """Maps an opcode index to an op name.

  Builtin ops are named after their BuiltinOperator enum member; ops whose
  builtin name is "CUSTOM" are named by their custom code instead. Labels
  have the form "NAME (index)".
  """

  def __init__(self, data):
    # `data` is the model dict; only its "operator_codes" list is read.
    self.code_to_name = {}
    for idx, d in enumerate(data["operator_codes"]):
      # The op code may live in either "builtin_code" or
      # "deprecated_builtin_code"; use the larger, as CreateHtmlFile does.
      # Either field may be missing (e.g. JSON input), which counts as 0.
      code = max(d.get("builtin_code", 0), d.get("deprecated_builtin_code", 0))
      name = BuiltinCodeToName(code)
      if name == "CUSTOM":
        name = NameListToString(d.get("custom_code"))
      self.code_to_name[idx] = name

  def __call__(self, x):
    s = self.code_to_name[x] if x in self.code_to_name else "<UNKNOWN>"
    # "%s" rather than "%d" so a missing index (None) is rendered instead of
    # raising TypeError; integer indices format identically either way.
    return "%s (%s)" % (s, x)
246
247
class DataSizeMapper(object):
  """Renders a buffer's payload as its length in bytes, or "--" if absent."""

  def __call__(self, x):
    if x is None:
      return "--"
    return "%d bytes" % len(x)
256
257
class TensorMapper(object):
  """Maps a list of tensor indices to a tooltip hoverable indicator of more.

  The table cell shows `repr()` of the index list; hovering it reveals one
  line per index with the tensor's index, name, type, shape and shape
  signature.
  """

  def __init__(self, subgraph_data):
    # Subgraph dict; only its "tensors" list is read.
    self.data = subgraph_data

  def __call__(self, x):
    parts = ["<span class='tooltip'><span class='tooltipcontent'>"]
    for i in (x if x is not None else []):
      if i < 0:
        # A negative index names no tensor (TFLite uses -1 for omitted
        # optional inputs); plain list indexing would silently describe a
        # tensor from the end of the list instead.
        parts.append("%s (none)<br>" % i)
        continue
      tensor = self.data["tensors"][i]
      # A missing "type" is read as 0 (presumably the flatbuffer default when
      # a JSON dump omits it); an unrecognized value maps to None, which must
      # not crash the string building.
      type_name = TensorTypeToName(tensor.get("type", 0)) or "UNKNOWN"
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      signature = (repr(tensor["shape_signature"])
                   if "shape_signature" in tensor else "[]")
      parts.append("%s %s %s %s %s<br>" % (
          str(i), NameListToString(tensor.get("name")), type_name, shape,
          signature))
    parts.append("</span>")
    parts.append(repr(x))
    parts.append("</span>")
    return "".join(parts)
279
280
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Every op and every tensor of subgraph `g` becomes a node. Edges run from
  each input tensor to the op that reads it and from each op to the tensors
  it writes.

  Args:
    subgraph_idx: index substituted into the template's "#subgraph%d"
      selector, i.e. which <svg> element the script draws into.
    g: subgraph dict with "operators" and "tensors" lists.
    opcode_mapper: callable turning an op's "opcode_index" into its label.

  Returns:
    `_D3_HTML_TEMPLATE` filled in with the graph (as JSON) and subgraph index.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  def TensorIndices(op, key):
    # An absent or null index list is treated as empty.
    indices = op.get(key)
    return [] if indices is None else indices

  edges = []
  nodes = []
  # Initial (y, x) position of each tensor, derived from the first op that
  # reads it (`first`) or, failing that, the first op that writes it
  # (`second`).
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"]):

    for tensor_input_position, tensor_index in enumerate(
        TensorIndices(op, "inputs")):
      # Negative indices name no tensor node (nodes are t0..tN-1 below; the
      # TFLite schema uses -1 for omitted optional inputs). An edge to a
      # missing node makes the generated script fail, so skip them.
      if tensor_index < 0:
        continue
      if tensor_index not in first:
        first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                               (tensor_input_position + 1) * width_mult)
      edges.append({
          "source": TensorName(tensor_index),
          "target": OpName(op_index)
      })
    for tensor_output_position, tensor_index in enumerate(
        TensorIndices(op, "outputs")):
      if tensor_index < 0:
        continue
      if tensor_index not in second:
        second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                (tensor_output_position + 1) * width_mult)
      edges.append({
          "target": TensorName(tensor_index),
          "source": OpName(op_index)
      })

    nodes.append({
        "id": OpName(op_index),
        # A missing opcode_index (JSON input may omit it) is read as 0.
        "name": opcode_mapper(op.get("opcode_index", 0)),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g["tensors"]):
    initial_pos = first.get(tensor_index, second.get(tensor_index, (0, 0)))
    # `tensor` is a dict, so the shape must be read by key; getattr() would
    # always return its default and label every tensor "[]".
    shape = tensor.get("shape")
    if shape is None:
      shape = []
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % (shape, tensor_index),
        "group": 1,
        "x": initial_pos[1],
        "y": initial_pos[0]
    })
  graph_str = json.dumps({"nodes": nodes, "edges": edges})
  # The JSON is embedded inside a <script> element and labels can carry
  # model-supplied text (custom op names), so escape "</" to keep such text
  # from closing the element early; "\/" still decodes to "/" in JavaScript.
  graph_str = graph_str.replace("</", "<\\/")

  html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
  return html
338
339
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: the rows to print, as a list of dicts.
    keys_to_print: list of (key, display_fn) pairs, one per column. `key` is
      looked up in each row (a missing key yields None). display_fn maps that
      value to the cell contents, i.e. the displayed html cell holds the
      string returned by `display_fn(row[key])`; a display_fn of None shows
      the raw value.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for key, _ in keys_to_print:
    parts.append("<th>%s</th>" % key)
  parts.append("</tr>\n")
  for idx, item in enumerate(items):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for key, mapper in keys_to_print:
      val = item[key] if key in item else None
      if mapper is not None:
        val = mapper(val)
      # Wrap in a 1-tuple so tuple values are printed, not unpacked by `%`.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)
370      val = val if mapper is None else mapper(val)
371      html += "<td>%s</td>\n" % val
372
373    html += "</tr>\n"
374  html += "</table>\n"
375  return html
376
377
def CamelCaseToSnakeCase(camel_case_input):
  """Returns `camel_case_input` rewritten from CamelCase into snake_case.

  For example "deprecatedBuiltinCode" becomes "deprecated_builtin_code".
  """
  # First insert "_" before each capitalized word that follows any character,
  words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  # then between a lowercase letter/digit and a capital that follows it.
  fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", words_split)
  return fully_split.lower()
382
383
def FlatbufferToDict(fb, preserve_as_numpy):
  """Recursively converts a flatbuffer object tree (e.g. a ModelT) to dicts.

  Ints, floats and strings are returned as-is. Objects with a `__dict__`
  become dicts of their public, non-callable attributes keyed by snake_case
  name. numpy arrays are kept or turned into lists, and other sized
  containers become lists. Anything else is returned unchanged. Keeping big
  parts of the flatbuffer as numpy arrays rather than python lists is what
  brings conversion of big graphs from ten minutes down to a few seconds.

  Args:
    fb: a flat buffer structure. (i.e. ModelT)
    preserve_as_numpy: if true, np.ndarrays below this point are returned
      unchanged; otherwise they are converted with `tolist()`. Everything
      under an attribute named "buffers" is always preserved.

  Returns:
    A dictionary representing the flatbuffer rather than a flatbuffer object.
  """
  if isinstance(fb, (int, float, str)):
    return fb

  if hasattr(fb, "__dict__"):
    converted = {}
    for name in dir(fb):
      member = fb.__getattribute__(name)
      if callable(member) or name[0] == "_":
        continue
      # Buffer payloads can be huge; never expand them into python lists.
      keep_numpy = preserve_as_numpy if name != "buffers" else True
      converted[CamelCaseToSnakeCase(name)] = FlatbufferToDict(
          member, keep_numpy)
    return converted

  if isinstance(fb, np.ndarray):
    if preserve_as_numpy:
      return fb
    return fb.tolist()

  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(item, preserve_as_numpy) for item in fb]

  return fb
414
415
def CreateDictFromFlatbuffer(buffer_data):
  """Parses a serialized TFLite model into nested python dicts.

  Args:
    buffer_data: the raw flatbuffer contents (e.g. a bytearray holding a
      .tflite file).

  Returns:
    The unpacked model converted by FlatbufferToDict with
    `preserve_as_numpy=False`, so numpy arrays become lists except under a
    "buffers" attribute.
  """
  unpacked_model = schema_fb.ModelT.InitFromObj(
      schema_fb.Model.GetRootAsModel(buffer_data, 0))
  return FlatbufferToDict(unpacked_model, preserve_as_numpy=False)
420
421
def CreateHtmlFile(tflite_input, html_output):
  """Given a tflite model in `tflite_input` file, produce html description.

  Args:
    tflite_input: path of the model, either a serialized flatbuffer
      (".tflite" or ".bin") or a JSON rendering of one (".json").
    html_output: path of the HTML file to (over)write.

  Raises:
    RuntimeError: if `tflite_input` does not exist or has an unsupported
      extension.
  """

  # Load the model as a nested dict.
  if not os.path.exists(tflite_input):
    raise RuntimeError("Invalid filename %r" % tflite_input)
  if tflite_input.endswith((".tflite", ".bin")):
    with open(tflite_input, "rb") as file_handle:
      file_data = bytearray(file_handle.read())
    data = CreateDictFromFlatbuffer(file_data)
  elif tflite_input.endswith(".json"):
    # Read bytes and let json detect the (UTF-8/16/32) encoding instead of
    # depending on the platform's default text encoding.
    with open(tflite_input, "rb") as file_handle:
      data = json.load(file_handle)
  else:
    raise RuntimeError("Input file was not .tflite, .bin or .json")
  html = ""
  html += _CSS
  html += "<h1>TensorFlow Lite Model</h1>"

  data["filename"] = tflite_input  # Avoid special case
  # The description may arrive as a list of character codes (flatbuffer
  # input) or as a str (JSON input); NameListToString renders both as text.
  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", NameListToString)]

  html += "<table>\n"
  for key, mapping in toplevel_stuff:
    value = data.get(key)
    if mapping:
      value = mapping(value)
    html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, value)
  html += "</table>\n"

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]
  tensor_keys_to_display = [("name", NameListToString),
                            ("type", TensorTypeToName), ("shape", None),
                            ("shape_signature", None), ("buffer", None),
                            ("quantization", None)]

  # Update builtin code fields: the op code may be stored in either
  # "builtin_code" or "deprecated_builtin_code", so keep the larger value.
  # Either field may be absent (e.g. from JSON input); treat that as 0.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d.get("builtin_code", 0),
                            d.get("deprecated_builtin_code", 0))

  # Operator codes are model-wide, so one mapper serves every subgraph.
  opcode_mapper = OpCodeMapper(data)

  for subgraph_idx, g in enumerate(data["subgraphs"]):
    # Subgraph local specs on what to display
    html += "<div class='subgraph'>"
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index", opcode_mapper)]

    html += "<h2>Subgraph %d</h2>\n" % subgraph_idx

    # Inputs and outputs.
    html += "<h3>Inputs/Outputs</h3>\n"
    html += GenerateTableHtml([{
        "inputs": g["inputs"],
        "outputs": g["outputs"]
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                              display_index=False)

    # Print the tensors.
    html += "<h3>Tensors</h3>\n"
    html += GenerateTableHtml(g["tensors"], tensor_keys_to_display)

    # Print the ops.
    html += "<h3>Ops</h3>\n"
    html += GenerateTableHtml(g["operators"], op_keys_to_display)

    # Visual graph.
    html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
        subgraph_idx,)
    html += GenerateGraph(subgraph_idx, g, opcode_mapper)
    html += "</div>"

  # Buffers have no data, but maybe in the future they will
  html += "<h2>Buffers</h2>\n"
  html += GenerateTableHtml(data["buffers"], buffer_keys_to_display)

  # Operator codes
  html += "<h2>Operator Codes</h2>\n"
  html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)

  html += "</body></html>\n"

  with open(html_output, "w") as output_file:
    output_file.write(html)
511
512
def main(argv):
  """Command-line entry point: `<prog> <input tflite> <output html>`.

  Args:
    argv: argument vector with the program name at argv[0]; arguments after
      the output path are ignored.

  Returns:
    0 after writing the HTML file, or 1 when the input/output arguments are
    missing (a usage message is printed to stderr in that case).
  """
  if len(argv) < 3:
    prog = argv[0] if argv else "visualize.py"
    print("Usage: %s <input tflite> <output html>" % (prog), file=sys.stderr)
    return 1
  CreateHtmlFile(argv[1], argv[2])
  return 0
521
522
if __name__ == "__main__":
  # Propagate main()'s status (None is treated as success) as the exit code.
  sys.exit(main(sys.argv))
525