• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""This tool creates an html visualization of a TensorFlow Lite graph.
17
18Example usage:
19
20python visualize.py foo.tflite foo.html
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import json
28import os
29import re
30import sys
31import numpy as np
32
33# pylint: disable=g-import-not-at-top
34if not os.path.splitext(__file__)[0].endswith(
35    os.path.join("tflite_runtime", "visualize")):
36  # This file is part of tensorflow package.
37  from tensorflow.lite.python import schema_py_generated as schema_fb
38else:
39  # This file is part of tflite_runtime package.
40  from tflite_runtime import schema_py_generated as schema_fb
41
42# A CSS description for making the visualizer
43_CSS = """
44<html>
45<head>
46<style>
47body {font-family: sans-serif; background-color: #fa0;}
48table {background-color: #eca;}
49th {background-color: black; color: white;}
50h1 {
51  background-color: ffaa00;
52  padding:5px;
53  color: black;
54}
55
56svg {
57  margin: 10px;
58  border: 2px;
59  border-style: solid;
60  border-color: black;
61  background: white;
62}
63
64div {
65  border-radius: 5px;
66  background-color: #fec;
67  padding:5px;
68  margin:5px;
69}
70
71.tooltip {color: blue;}
72.tooltip .tooltipcontent  {
73    visibility: hidden;
74    color: black;
75    background-color: yellow;
76    padding: 5px;
77    border-radius: 4px;
78    position: absolute;
79    z-index: 1;
80}
81.tooltip:hover .tooltipcontent {
82    visibility: visible;
83}
84
85.edges line {
86  stroke: #333;
87}
88
89text {
90  font-weight: bold;
91}
92
93.nodes text {
94  color: black;
95  pointer-events: none;
96  font-family: sans-serif;
97  font-size: 11px;
98}
99</style>
100
101<script src="https://d3js.org/d3.v4.min.js"></script>
102
103</head>
104<body>
105"""
106
107_D3_HTML_TEMPLATE = """
108  <script>
109    function buildGraph() {
110      // Build graph data
111      var graph = %s;
112
113      var svg = d3.select("#subgraph%d")
114      var width = svg.attr("width");
115      var height = svg.attr("height");
116      // Make the graph scrollable.
117      svg = svg.call(d3.zoom().on("zoom", function() {
118        svg.attr("transform", d3.event.transform);
119      })).append("g");
120
121
122      var color = d3.scaleOrdinal(d3.schemeDark2);
123
124      var simulation = d3.forceSimulation()
125          .force("link", d3.forceLink().id(function(d) {return d.id;}))
126          .force("charge", d3.forceManyBody())
127          .force("center", d3.forceCenter(0.5 * width, 0.5 * height));
128
129      var edge = svg.append("g").attr("class", "edges").selectAll("line")
130        .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
131
132      // Make the node group
133      var node = svg.selectAll(".nodes")
134        .data(graph.nodes)
135        .enter().append("g")
136        .attr("x", function(d){return d.x})
137        .attr("y", function(d){return d.y})
138        .attr("transform", function(d) {
139          return "translate( " + d.x + ", " + d.y + ")"
140        })
141        .attr("class", "nodes")
142          .call(d3.drag()
143              .on("start", function(d) {
144                if(!d3.event.active) simulation.alphaTarget(1.0).restart();
145                d.fx = d.x;d.fy = d.y;
146              })
147              .on("drag", function(d) {
148                d.fx = d3.event.x; d.fy = d3.event.y;
149              })
150              .on("end", function(d) {
151                if (!d3.event.active) simulation.alphaTarget(0);
152                d.fx = d.fy = null;
153              }));
154      // Within the group, draw a box for the node position and text
155      // on the side.
156
157      var node_width = 150;
158      var node_height = 30;
159
160      node.append("rect")
161          .attr("r", "5px")
162          .attr("width", node_width)
163          .attr("height", node_height)
164          .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
165          .attr("stroke", "#000000")
166          .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
167      node.append("text")
168          .text(function(d) { return d.name; })
169          .attr("x", 5)
170          .attr("y", 20)
171          .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
172      // Setup force parameters and update position callback
173
174
175      var node = svg.selectAll(".nodes")
176        .data(graph.nodes);
177
178      // Bind the links
179      var name_to_g = {}
180      node.each(function(data, index, nodes) {
181        console.log(data.id)
182        name_to_g[data.id] = this;
183      });
184
185      function proc(w, t) {
186        return parseInt(w.getAttribute(t));
187      }
188      edge.attr("d", function(d) {
189        function lerp(t, a, b) {
190          return (1.0-t) * a + t * b;
191        }
192        var x1 = proc(name_to_g[d.source],"x") + node_width /2;
193        var y1 = proc(name_to_g[d.source],"y") + node_height;
194        var x2 = proc(name_to_g[d.target],"x") + node_width /2;
195        var y2 = proc(name_to_g[d.target],"y");
196        var s = "M " + x1 + " " + y1
197            + " C " + x1 + " " + lerp(.5, y1, y2)
198            + " " + x2 + " " + lerp(.5, y1, y2)
199            + " " + x2  + " " + y2
200      return s;
201    });
202
203  }
204  buildGraph()
205</script>
206"""
207
208
def TensorTypeToName(tensor_type):
  """Converts a numerical enum to a readable tensor type.

  Args:
    tensor_type: Value of a `schema_fb.TensorType` member.

  Returns:
    The name of the first `schema_fb.TensorType` member whose value equals
    `tensor_type`, or None if there is no such member.
  """
  for name, value in schema_fb.TensorType.__dict__.items():
    # The class __dict__ also holds plumbing such as __module__ and __doc__;
    # skip private/dunder names so they can never be returned as a type name
    # (e.g. an input of None matching a None __doc__).
    if not name.startswith("_") and value == tensor_type:
      return name
  return None
215
216
def BuiltinCodeToName(code):
  """Converts a builtin op code enum to a readable name.

  Args:
    code: Value of a `schema_fb.BuiltinOperator` member.

  Returns:
    The name of the first `schema_fb.BuiltinOperator` member whose value
    equals `code`, or None if there is no such member.
  """
  for name, value in schema_fb.BuiltinOperator.__dict__.items():
    # The class __dict__ also holds plumbing such as __module__ and __doc__;
    # skip private/dunder names so they can never be returned as an op name.
    if not name.startswith("_") and value == code:
      return name
  return None
223
224
def NameListToString(name_list):
  """Converts a sequence of character codes (or a string) to a string.

  Flatbuffer strings reach this function as sequences of byte values. They are
  decoded as UTF-8 so non-ASCII names survive intact; undecodable bytes become
  U+FFFD. Sequences holding values outside 0-255 are treated as code points.

  Args:
    name_list: A str (returned unchanged), None (gives ""), or an iterable of
      integer-convertible values such as a list of ints, bytes, or an array.

  Returns:
    The decoded string.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  codes = [int(val) for val in name_list]
  if all(0 <= code < 256 for code in codes):
    # Byte values: decode as UTF-8 rather than mapping each byte to its own
    # character, which would garble multi-byte characters.
    return bytes(bytearray(codes)).decode("utf-8", errors="replace")
  return "".join(chr(code) for code in codes)
235
236
class OpCodeMapper(object):
  """Maps an opcode index to an op name."""

  def __init__(self, data):
    # data: model dict whose "operator_codes" list is indexed by opcode index.
    self.code_to_name = {}
    for idx, op_code in enumerate(data["operator_codes"]):
      op_name = BuiltinCodeToName(op_code["builtin_code"])
      # Custom ops share the CUSTOM builtin code; show their own name instead.
      if op_name == "CUSTOM":
        op_name = NameListToString(op_code["custom_code"])
      self.code_to_name[idx] = op_name

  def __call__(self, x):
    """Returns "<op name> (<x>)", using "<UNKNOWN>" for unmapped indices."""
    label = self.code_to_name.get(x, "<UNKNOWN>")
    return "%s (%d)" % (label, x)
253
254
class DataSizeMapper(object):
  """For buffers, report the number of bytes."""

  def __call__(self, x):
    """Returns "<len(x)> bytes", or "--" when `x` is None."""
    if x is None:
      return "--"
    return "%d bytes" % len(x)
263
264
class TensorMapper(object):
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    # subgraph_data: subgraph dict whose "tensors" list is indexed by the
    # tensor indices passed to __call__.
    self.data = subgraph_data

  def __call__(self, x):
    """Renders tensor indices as hoverable HTML.

    Args:
      x: Sequence of tensor indices into this subgraph's "tensors" list, or
        None.

    Returns:
      A `<span class='tooltip'>` whose visible text is `repr(x)` and whose
      hidden tooltip has one line per index: index, name, type, shape and
      shape_signature.
    """
    parts = ["<span class='tooltip'><span class='tooltipcontent'>"]
    for i in (x if x is not None else []):
      if i < 0:
        # A negative index refers to no tensor (TFLite uses -1 for omitted
        # optional inputs). Indexing the list with it would silently describe
        # the last tensor instead.
        parts.append("%d (no tensor)<br>" % i)
        continue
      tensor = self.data["tensors"][i]
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      shape_signature = (repr(tensor["shape_signature"])
                         if "shape_signature" in tensor else "[]")
      # "type" defaults to 0 when absent (presumably omitted from a JSON dump
      # because it equals the schema default); "%s" keeps an unknown type
      # (None) from raising.
      parts.append("%s %s %s %s %s<br>" % (
          i, NameListToString(tensor.get("name")),
          TensorTypeToName(tensor.get("type", 0)), shape, shape_signature))
    parts.append("</span>")
    parts.append(repr(x))
    parts.append("</span>")
    return "".join(parts)
286
287
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: Subgraph index; the generated script draws into the element
      with id `subgraph<subgraph_idx>`.
    g: Subgraph dict with "operators" and "tensors" lists.
    opcode_mapper: Callable turning an operator's "opcode_index" into the
      label shown on its node.

  Returns:
    `_D3_HTML_TEMPLATE` filled with the graph's nodes and edges as JSON.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  # Initial (y, x) placement per tensor index: `first` comes from the first op
  # reading the tensor, `second` from the first op writing it.
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"] or []):

    for tensor_input_position, tensor_index in enumerate(
        op.get("inputs") or []):
      # Negative indices name no tensor (TFLite uses -1 for omitted optional
      # inputs). There is no node "t-1", and an edge to it would make the
      # browser-side edge layout fail on the missing element.
      if tensor_index < 0:
        continue
      if tensor_index not in first:
        first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                               (tensor_input_position + 1) * width_mult)
      edges.append({
          "source": TensorName(tensor_index),
          "target": OpName(op_index)
      })
    for tensor_output_position, tensor_index in enumerate(
        op.get("outputs") or []):
      if tensor_index < 0:
        continue
      if tensor_index not in second:
        second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                (tensor_output_position + 1) * width_mult)
      edges.append({
          "target": TensorName(tensor_index),
          "source": OpName(op_index)
      })

    nodes.append({
        "id": OpName(op_index),
        "name": opcode_mapper(op["opcode_index"]),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g["tensors"] or []):
    initial_pos = first.get(tensor_index, second.get(tensor_index, (0, 0)))
    # Tensors are dicts here, so read "shape" by key; getattr() on a dict
    # always falls back to its default and would label every tensor "[]".
    shape = tensor.get("shape")
    if shape is None:
      shape = []
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % (shape, tensor_index),
        "group": 1,
        "x": initial_pos[1],
        "y": initial_pos[0]
    })
  graph_str = json.dumps({"nodes": nodes, "edges": edges})
  # The JSON is embedded in a <script> element. Escape "</" so a name that
  # contains "</script>" cannot end the element early; "<\/" is an equivalent
  # string escape in both JSON and JavaScript.
  graph_str = graph_str.replace("</", "<\\/")

  return _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
345
346
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: Rows to print, as a list of dicts (None is treated as no rows).
    keys_to_print: List of (key, display_fn) pairs, one per column. `key` is
      looked up in each row; a row lacking it contributes None. `display_fn`
      maps that value to what is shown in the cell, i.e. the cell shows
      `display_fn(row[key])`; when `display_fn` is None the value is shown
      as-is.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for h, _ in keys_to_print:
    parts.append("<th>%s</th>" % h)
  parts.append("</tr>\n")
  for idx, row in enumerate(items or []):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for h, mapper in keys_to_print:
      val = row[h] if h in row else None
      if mapper is not None:
        val = mapper(val)
      parts.append("<td>%s</td>\n" % val)
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)
383
384
def CamelCaseToSnakeCase(camel_case_input):
  """Returns `camel_case_input` rewritten from CamelCase into snake_case."""
  # Pass 1 inserts "_" before each capitalized word ("OperatorCodes" ->
  # "Operator_Codes", "HTTPServer" -> "HTTP_Server"); pass 2 splits any
  # remaining lowercase/digit-to-uppercase boundary.
  with_word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  with_all_breaks = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
  return with_all_breaks.lower()
389
390
def FlatbufferToDict(fb, preserve_as_numpy):
  """Converts a hierarchy of FB objects into a nested dict.

  We avoid transforming big parts of the flat buffer into python arrays. This
  speeds conversion from ten minutes to a few seconds on big graphs.

  Args:
    fb: a flat buffer structure. (i.e. ModelT)
    preserve_as_numpy: true if all downstream np.arrays should be preserved.
      false if all downstream np.array should become python arrays
  Returns:
    A dictionary representing the flatbuffer rather than a flatbuffer object.
    Objects become dicts keyed by snake_case public attribute names; other
    sized containers (including bytes) become lists.
  """
  if isinstance(fb, (int, float, str)):
    return fb
  if hasattr(fb, "__dict__"):
    converted = {}
    for attr_name in dir(fb):
      attr_value = fb.__getattribute__(attr_name)
      if callable(attr_value) or attr_name[0] == "_":
        continue
      # Everything under "buffers" keeps its numpy arrays: converting the
      # raw buffer data to python lists is what makes big models slow.
      keep_numpy = True if attr_name == "buffers" else preserve_as_numpy
      converted[CamelCaseToSnakeCase(attr_name)] = FlatbufferToDict(
          attr_value, keep_numpy)
    return converted
  if isinstance(fb, np.ndarray):
    return fb if preserve_as_numpy else fb.tolist()
  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(item, preserve_as_numpy) for item in fb]
  return fb
421
422
def CreateDictFromFlatbuffer(buffer_data):
  """Parses a serialized TFLite model into nested python dicts and lists.

  Args:
    buffer_data: The model flatbuffer contents (e.g. a bytearray read from a
      .tflite file).

  Returns:
    The model as produced by FlatbufferToDict with preserve_as_numpy=False,
    i.e. numpy arrays are converted to python lists except those under
    "buffers".
  """
  root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  return FlatbufferToDict(schema_fb.ModelT.InitFromObj(root),
                          preserve_as_numpy=False)
427
428
def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object. A path ending
      in .tflite or .bin is parsed as a flatbuffer; a path ending in .json is
      loaded as a JSON dump of the model. A non-path input is parsed directly
      as a flatbuffer.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input path does not exist or has an unsupported
      extension.
  """
  # Load the model into nested dicts/lists.
  if input_is_filepath:
    if not os.path.exists(tflite_input):
      raise RuntimeError("Invalid filename %r" % tflite_input)
    if tflite_input.endswith((".tflite", ".bin")):
      with open(tflite_input, "rb") as file_handle:
        file_data = bytearray(file_handle.read())
      data = CreateDictFromFlatbuffer(file_data)
    elif tflite_input.endswith(".json"):
      # Use a context manager so the file handle is closed deterministically.
      with open(tflite_input, encoding="utf-8") as file_handle:
        data = json.load(file_handle)
    else:
      raise RuntimeError("Input file was not .tflite or .json")
  else:
    data = CreateDictFromFlatbuffer(tflite_input)

  html = _CSS
  html += "<h1>TensorFlow Lite Model</h1>"

  # Store the input as a top-level "filename" field so the summary table
  # needs no special case.
  # NOTE(review): when input_is_filepath is False this is the model object
  # itself, so the table shows its full repr; confirm that is acceptable for
  # large models.
  data["filename"] = tflite_input
  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", None)]

  html += "<table>\n"
  for key, mapping in toplevel_stuff:
    value = data.get(key)
    if mapping:
      value = mapping(value)
    html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, value)
  html += "</table>\n"

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]
  tensor_keys_to_display = [("name", NameListToString),
                            ("type", TensorTypeToName), ("shape", None),
                            ("shape_signature", None), ("buffer", None),
                            ("quantization", None)]

  # An operator code may carry its op in either "builtin_code" or
  # "deprecated_builtin_code"; keep the larger so "builtin_code" holds the
  # effective value. A missing field (presumably omitted from a JSON dump
  # because it equals the schema default) counts as 0.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d.get("builtin_code", 0),
                            d.get("deprecated_builtin_code", 0))

  # Depends only on the (now normalized) operator codes, so build it once.
  opcode_mapper = OpCodeMapper(data)

  for subgraph_idx, g in enumerate(data["subgraphs"]):
    html += "<div class='subgraph'>"
    # Subgraph local specs on what to display
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index", opcode_mapper)]

    html += "<h2>Subgraph %d</h2>\n" % subgraph_idx

    # Inputs and outputs.
    html += "<h3>Inputs/Outputs</h3>\n"
    html += GenerateTableHtml([{
        "inputs": g["inputs"],
        "outputs": g["outputs"]
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                              display_index=False)

    # Print the tensors.
    html += "<h3>Tensors</h3>\n"
    html += GenerateTableHtml(g["tensors"], tensor_keys_to_display)

    # Print the ops.
    if g["operators"]:
      html += "<h3>Ops</h3>\n"
      html += GenerateTableHtml(g["operators"], op_keys_to_display)

    # Visual graph.
    html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
        subgraph_idx,)
    html += GenerateGraph(subgraph_idx, g, opcode_mapper)
    html += "</div>"

  # Buffers: each row shows the length of the buffer's "data" field.
  html += "<h2>Buffers</h2>\n"
  html += GenerateTableHtml(data["buffers"], buffer_keys_to_display)

  # Operator codes
  html += "<h2>Operator Codes</h2>\n"
  html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)

  html += "</body></html>\n"

  return html
532
533
def main(argv):
  """Command-line entry point: renders model argv[1] to HTML file argv[2].

  Args:
    argv: Argument vector; argv[0] is the program name used in the usage line.

  Prints a usage line instead when either argument is missing.
  """
  try:
    tflite_input = argv[1]
    html_output = argv[2]
  except IndexError:
    print("Usage: %s <input tflite> <output html>" % (argv[0]))
  else:
    html = create_html(tflite_input)
    # Write UTF-8 explicitly: tensor/op names may contain non-ASCII text that
    # the platform default encoding cannot represent.
    with open(html_output, "w", encoding="utf-8") as output_file:
      output_file.write(html)
544
545
if __name__ == "__main__":
  # Command-line entry: visualize.py <input tflite> <output html>.
  main(sys.argv)
548