• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2019, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Spec Visualizer
18
19Visualize python spec file for test generator.
20
21Modified from TFLite graph visualizer -- instead of flatbuffer, takes spec file as input.
22(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py)
23
24"""
25
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29import argparse
30import fnmatch
31import json
32import math
33import os
34import re
35import sys
36import traceback
37
38# Stuff from test generator
39import test_generator as tg
40from test_generator import ActivationConverter
41from test_generator import BoolScalar
42from test_generator import Configuration
43from test_generator import DataTypeConverter
44from test_generator import DataLayoutConverter
45from test_generator import Example
46from test_generator import Float32Scalar
47from test_generator import Float32Vector
48from test_generator import GetJointStr
49from test_generator import IgnoredOutput
50from test_generator import Input
51from test_generator import Int32Scalar
52from test_generator import Int32Vector
53from test_generator import Internal
54from test_generator import Model
55from test_generator import Operand
56from test_generator import Output
57from test_generator import Parameter
58from test_generator import ParameterAsInputConverter
59from test_generator import RelaxedModeConverter
60from test_generator import SmartOpen
61
# HTML preamble for the visualizer page: opens <html>, emits the <head> with the
# inline stylesheet and the d3 v4 script include, and opens <body>. The tags
# opened here are closed by FinalizeHtml.
_CSS = """
<html>
<head>
<style>
body {font-family: sans-serif; background-color: #ffaa00;}
table {background-color: #eeccaa;}
th {background-color: black; color: white;}
h1 {
  background-color: #ffaa00;
  padding:5px;
  color: black;
}

div {
  border-radius: 5px;
  background-color: #ffeecc;
  padding:5px;
  margin:5px;
}

.tooltip {color: blue;}
.tooltip .tooltipcontent  {
    visibility: hidden;
    color: black;
    background-color: yellow;
    padding: 5px;
    border-radius: 4px;
    position: absolute;
    z-index: 1;
}
.tooltip:hover .tooltipcontent {
    visibility: visible;
}

.edges line {
  stroke: #333333;
}

.nodes text {
  color: black;
  pointer-events: none;
  font-family: sans-serif;
  font-size: 11px;
}
</style>

<script src="https://d3js.org/d3.v4.min.js"></script>

</head>
<body>
"""
114
# d3 script emitted once per subgraph. It is filled with `%` formatting and has
# exactly two `%s` slots, in this order:
#   1. the graph data as a JSON object with "nodes" and "edges" lists;
#   2. the subgraph name, used to select the `<svg id='subgraph_<name>'>`
#      element that the force-directed layout is drawn into.
_D3_HTML_TEMPLATE = """
  <script>
    // Build graph data
    var graph = %s;

    var svg = d3.select("#subgraph_%s");
    var width = svg.attr("width");
    var height = svg.attr("height");
    var color = d3.scaleOrdinal(d3.schemeCategory20);

    var simulation = d3.forceSimulation()
        .force("link", d3.forceLink().id(function(d) {return d.id;}))
        .force("charge", d3.forceManyBody())
        .force("center", d3.forceCenter(0.5 * width, 0.5 * height));


    function buildGraph() {
      var edge = svg.append("g").attr("class", "edges").selectAll("line")
        .data(graph.edges).enter().append("line")
      // Make the node group
      var node = svg.selectAll(".nodes")
        .data(graph.nodes)
        .enter().append("g")
        .attr("class", "nodes")
          .call(d3.drag()
              .on("start", function(d) {
                if(!d3.event.active) simulation.alphaTarget(1.0).restart();
                d.fx = d.x;d.fy = d.y;
              })
              .on("drag", function(d) {
                d.fx = d3.event.x; d.fy = d3.event.y;
              })
              .on("end", function(d) {
                if (!d3.event.active) simulation.alphaTarget(0);
                d.fx = d.fy = null;
              }));
      // Within the group, draw a circle for the node position and text
      // on the side.
      node.append("circle")
          .attr("r", "5px")
          .attr("fill", function(d) { return color(d.group); })
      node.append("text")
          .attr("dx", 8).attr("dy", 5).text(function(d) { return d.name; });
      // Setup force parameters and update position callback
      simulation.nodes(graph.nodes).on("tick", forceSimulationUpdated);
      simulation.force("link").links(graph.edges);

      function forceSimulationUpdated() {
        // Update edges.
        edge.attr("x1", function(d) {return d.source.x;})
            .attr("y1", function(d) {return d.source.y;})
            .attr("x2", function(d) {return d.target.x;})
            .attr("y2", function(d) {return d.target.y;});
        // Update node positions
        node.attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; });
      }
    }
  buildGraph()
</script>
"""
175
class OpCodeMapper(object):
  """Callable that renders an opcode index as "<name> (opcode=<index>)"."""

  def __init__(self, data):
    # The position of each entry in data["operator_codes"] is its opcode index.
    self.code_to_name = {
        index: entry["builtin_code"]
        for index, entry in enumerate(data["operator_codes"])
    }

  def __call__(self, x):
    # Indices without a registered name are shown as "<UNKNOWN>".
    name = self.code_to_name.get(x, "<UNKNOWN>")
    return "%s (opcode=%d)" % (name, x)
190
191
class DataSizeMapper(object):
  """Callable that formats a buffer as "<len> bytes", or "--" when it is None."""

  def __call__(self, x):
    if x is None:
      return "--"
    return "%d bytes" % len(x)
200
201
class TensorMapper(object):
  """Callable that renders a list of operand indices with a hover tooltip.

  The visible text is repr() of the index list; the hidden tooltip holds one
  line per index with the operand's index, name, type and dimensions.
  """

  def __init__(self, subgraph_data):
    self.data = subgraph_data

  def __call__(self, x):
    def Describe(index):
      # One tooltip line: "<index> <name> <type> <dimensions><br>".
      operand = self.data["operands"][index]
      pieces = [str(index) + " ", operand["name"] + " ", str(operand["type"]) + " "]
      if "dimensions" in operand:
        pieces.append(repr(operand["dimensions"]))
      else:
        pieces.append("[]")
      pieces.append("<br>")
      return "".join(pieces)

    tooltip = "".join(Describe(i) for i in x)
    return ("<span class='tooltip'><span class='tooltipcontent'>" + tooltip +
            "</span>" + repr(x) + "</span>")
221
def GenerateGraph(g):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    g: dict describing one model variation. Reads g["operations"] (each a dict
      with "inputs", "outputs" and "opcode"), g["operands"] (each a dict with a
      "name") and g["name"]. Operand nodes are identified by their name, and
      operation inputs/outputs are matched to them through str().

  Returns:
    _D3_HTML_TEMPLATE filled with the JSON graph data and the graph name.
  """

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  # Initial (x, y) placement of each operand, keyed by operand name and recorded
  # at the first operation that consumes it.
  first = {}
  pixel_mult = 50
  for op_index, op in enumerate(g["operations"]):
    op_id = OpName(op_index)
    for tensor in op["inputs"]:
      tensor_id = str(tensor)
      # Test the same string key that is stored. Checking the raw operand never
      # matched, so every later consumer overwrote the first placement.
      if tensor_id not in first:
        first[tensor_id] = (
            op_index * pixel_mult,
            len(first) * pixel_mult - pixel_mult / 2)
      edges.append({
          "source": tensor_id,
          "target": op_id
      })
    for tensor in op["outputs"]:
      edges.append({
          "target": str(tensor),
          "source": op_id
      })
    nodes.append({
        "id": op_id,
        "name": op["opcode"],
        "group": 2,
        "x": pixel_mult,
        "y": op_index * pixel_mult
    })

  # Operands that no operation consumes keep the original fallback position.
  default_position = (2, len(g["operations"]))
  for tensor_index, tensor in enumerate(g["operands"]):
    # Unpack the stored pair into numeric coordinates. Assigning the whole tuple
    # to "y" serialized it as a JSON array, which is not a usable coordinate.
    initial_x, initial_y = first.get(tensor["name"], default_position)
    nodes.append({
        "id": tensor["name"],
        "name": "%s (%d)" % (tensor["name"], tensor_index),
        "group": 1,
        "x": initial_x,
        "y": initial_y
    })
  graph_str = json.dumps({"nodes": nodes, "edges": edges})

  html = _D3_HTML_TEMPLATE % (graph_str, g["name"])
  return html
272
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: Items to print, an array of dicts.
    keys_to_print: list of (key, display_fn) pairs. `key` is looked up in each
      item and is also used as the column header; an item that lacks the key
      contributes None. `display_fn` is either None, in which case the value is
      shown as is, or a function whose return value `display_fn(item[key])` is
      placed in the cell. Neither headers nor cell contents are HTML-escaped.
    display_index: add a column which is the index of each row in `items`.
  Returns:
    An html table.
  """
  # Header row. Exactly one <tr> is opened here so that every row is closed.
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for key, _ in keys_to_print:
    parts.append("<th>%s</th>" % (key,))
  parts.append("</tr>\n")

  # One row per item.
  for idx, item in enumerate(items):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for key, mapper in keys_to_print:
      val = item[key] if key in item else None
      if mapper is not None:
        val = mapper(val)
      # Wrap in a one-element tuple so a tuple-valued cell is printed rather
      # than being unpacked as several format arguments.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)
308
309
def CreateHtmlFile(g, fd):
  """Writes the HTML section describing one model variation to `fd`.

  The section holds, in order: the variation name, its configuration options,
  the model inputs/outputs, the operand table, the operation table and the d3
  graph together with the <svg> element it draws into.

  Args:
    g: dict with "name", "options", "inputs", "outputs", "operands" and
      "operations" entries.
    fd: object with a write() method; receives the section in a single call.
  """

  def JoinOperands(operands):
    # Comma-separated str() of every operand in the list.
    return ", ".join(str(operand) for operand in operands)

  # Columns shown in the operation and operand tables.
  operation_columns = [
      ("opcode", None),
      ("inputs", JoinOperands),
      ("outputs", JoinOperands),
  ]
  operand_columns = [
      (column, None)
      for column in ("name", "type", "dimensions", "scale", "zero_point", "lifetime")
  ]
  io_columns = [("inputs", JoinOperands), ("outputs", JoinOperands)]

  sections = [
      "<div class='subgraph'>",
      "<h2>%s</h2>\n" % g["name"],
      # Configurations.
      "<h3>Configurations</h3>\n",
      GenerateTableHtml(
          [g["options"]], [(k, None) for k in g["options"].keys()], display_index=False),
      # Inputs and outputs.
      "<h3>Inputs/Outputs</h3>\n",
      GenerateTableHtml(
          [{"inputs": g["inputs"], "outputs": g["outputs"]}], io_columns,
          display_index=False),
      # Operands.
      "<h3>Operands</h3>\n",
      GenerateTableHtml(g["operands"], operand_columns),
      # Operations.
      "<h3>Operations</h3>\n",
      GenerateTableHtml(g["operations"], operation_columns),
      # Visual graph: 100px per operation, with the width clamped to [200, 1600].
      "<h3>Visual Graph</h3>\n",
      "<svg id='subgraph_%s' width='%d' height='%d'></svg>\n" % (
          g["name"],
          max(min(len(g["operations"]) * 100, 1600), 200),
          len(g["operations"]) * 100),
      GenerateGraph(g),
      "</div>",
  ]

  fd.write("".join(sections))
352
def InitializeHtml(fd):
  """Writes the page preamble (_CSS) followed by an <h1> with the spec name."""
  title = "<h1>%s</h1>"%(tg.FileNames.specName)
  fd.write(_CSS + title)
358
def FinalizeHtml(fd):
  """Writes the closing </body> and </html> tags to `fd`."""
  closing_tags = "</body></html>\n"
  fd.write(closing_tags)
361
def VisualizeModel(example, fd):
    """Writes the HTML section for one example's model to `fd`.

    The example is skipped when a variation pattern is set and the example's
    test name does not match it (fnmatch-style pattern).

    Args:
      example: object providing `testName` and `model`; the model provides
        `isRelaxed`, `GetInputs()`, `GetOutputs()`, `operands` and `operations`.
      fd: output file object, passed through to CreateHtmlFile.
    """
    # `varName` is only bound when this file runs as a script. Treat a missing
    # global as "no filter" instead of failing with NameError on import use.
    pattern = globals().get("varName")
    if pattern is not None and not fnmatch.fnmatch(str(example.testName), pattern):
        print("    Skip variation %s" % example.testName)
        return
    print("    Visualizing variation %s" % example.testName)
    model = example.model
    # Flatten the model into the plain dict consumed by CreateHtmlFile.
    g = {}
    g["options"] = {"relaxed": str(model.isRelaxed), "useSHM": str(tg.Configuration.useSHM())}
    g["name"] = str(example.testName)
    g["inputs"] = model.GetInputs()
    g["outputs"] = model.GetOutputs()
    g["operands"] = [{
            "name": str(op), "type": op.type.type, "dimensions": op.type.GetDimensionsString(),
            "scale": op.type.scale, "zero_point": op.type.zeroPoint, "lifetime": op.lifetime
        } for op in model.operands]
    g["operations"] = [{
            "inputs": op.ins, "outputs": op.outs, "opcode": op.optype
        } for op in model.operations]
    CreateHtmlFile(g, fd)
381
def ParseCmdLine():
    """Parses the command line and hands the spec path to tg.FileNames.

    Returns:
      A tuple (absolute spec path, variation pattern or None, absolute output
      html path).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    optional_flags = (
        ("-v", "--variation", "the target variation name/pattern", None),
        ("-o", "--out", "the output html path", "out.html"),
    )
    for short_flag, long_flag, description, fallback in optional_flags:
        parser.add_argument(short_flag, long_flag, help=description, default=fallback)
    options = parser.parse_args()

    # Only the spec path is supplied; every other argument is passed as "-".
    tg.FileNames.InitializeFileLists(options.spec, "-", "-", "-", "-", "-")
    tg.FileNames.NextFile()

    spec_path = os.path.abspath(options.spec)
    out_path = os.path.abspath(options.out)
    return spec_path, options.variation, out_path
395
if __name__ == '__main__':
    # `varName` is read as a module-level global by VisualizeModel.
    specFile, varName, outFile = ParseCmdLine()
    print("Visualizing from spec: %s" % specFile)
    # Read the spec through a context manager so the file handle is closed.
    with open(specFile, "r") as spec_fd:
        spec_source = spec_fd.read()
    # The spec is Python source executed in this module's namespace, presumably
    # so it can use the test_generator names imported at the top of the file.
    # Compiling with the real path makes tracebacks point at the spec file.
    # NOTE(review): this runs arbitrary code from the spec; use trusted specs only.
    exec(compile(spec_source, specFile, "exec"))
    with SmartOpen(outFile) as fd:
        InitializeHtml(fd)
        # Only the per-example hook is used; it appends one section per example.
        Example.DumpAllExamples(
            DumpModel=None, model_fd=None,
            DumpExample=VisualizeModel, example_fd=fd,
            DumpTest=None, test_fd=None)
        FinalizeHtml(fd)
    print("Output HTML file: %s" % outFile)
408
409