• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/dump_graphviz.h"
16 
17 #include <cmath>
18 #include <functional>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "re2/re2.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
30 #include "tensorflow/lite/toco/toco_port.h"
31 #include "tensorflow/lite/toco/toco_types.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33 
34 using toco::port::AppendF;
35 using toco::port::StringF;
36 
37 namespace toco {
38 namespace {
39 
// Printf-style templates for the graphviz (dot) text produced by this file.
// They are expanded with StringF()/AppendF().

// Opening of the top-level digraph. Argument: the HTML-like graph label. The
// closing brace is appended separately.
//
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
    nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Note: tooltips (used by the graph, array-node and op-node formats) are only
// supported on SVGs in Chrome.

// Opening of a cluster subgraph. Arguments: name (emitted with a "cluster_"
// prefix, which is what makes graphviz draw it as a boxed cluster), background
// color, HTML-like label. The closing brace is appended separately.
constexpr char kSubgraphFmt[] =
    R"CODE(    subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node. Arguments: node id (array name), HTML-like label, tooltip,
// graphviz shape, fill color, text color. "DD" is appended to the text color
// as an alpha component.
constexpr char kArrayNodeFmt[] =
    R"CODE(        "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node. Arguments: node id, HTML-like label, fill color, text color
// ("DD" alpha appended as above).
constexpr char kOpNodeFmt[] =
    R"CODE(        %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array into an operator input port. Arguments: array name,
// compass-point suffix for the array end, op node id, input port index, pen
// width, layout weight. The edge enters the top (":n") of the port.
constexpr char kInputEdgeFmt[] =
    R"CODE(        "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from an operator output port to an array. Arguments: op node id,
// output port index, array name, compass-point suffix for the array end, pen
// width, layout weight. The edge leaves the bottom (":s") of the port.
constexpr char kOutputEdgeFmt[] =
    R"CODE(        %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge. Arguments: back-edge source array, RNN state array. Drawn in
// green with constraint=false, so it does not affect node ranking.
constexpr char kRNNBackEdgeFmt[] =
    R"CODE(        "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
// Multiplication sign, used between dimensions and stride components.
constexpr char kUnicodeMult[] = "\u00D7";
// Ellipsis (with surrounding spaces), used in abbreviated buffer samples.
constexpr char kUnicodeEllipsis[] = " \u2026 ";
68 
69 class Color {
70  public:
Color()71   Color() {}
Color(uint8 r,uint8 g,uint8 b)72   Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
Color(uint32 word)73   explicit Color(uint32 word)
74       : r_((word & 0x00FF0000) >> 16),
75         g_((word & 0x0000FF00) >> 8),
76         b_((word & 0x000000FF) >> 0) {}
77 
78   // Returns the string serialization of this color in graphviz format,
79   // for use as 'fillcolor' in boxes.
AsHexString() const80   string AsHexString() const { return StringF("#%.2X%.2X%.2X", r_, g_, b_); }
81   // The color to use for this node; will be used as 'fillcolor'
82   // for its box. See Color::AsHexString. A suitable, different
83   // color will be chosen for the 'fontcolor' for the inside text
84   // label, see Color::TextColorString.
85   // Returns the serialization in graphviz format of a suitable color to use
86   // 'fontcolor' in the same boxes. It should black or white, whichever offers
87   // the better contrast from AsHexString().
TextColorString() const88   string TextColorString() const {
89     // https://en.wikipedia.org/wiki/Relative_luminance
90     const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
91     const uint8 l = luminance > 128.f ? 0 : 255;
92     return StringF("#%.2X%.2X%.2X", l, l, l);
93   }
94 
95  private:
96   uint8 r_ = 0, g_ = 0, b_ = 0;
97 };
98 
HashStringToColor(string s)99 Color HashStringToColor(string s) {
100   // Return a unique color for a name.
101   //
102   // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
103   // the string to a uint_32, then twiddles some bits to get a light and subtle
104   // color. This seems to be a good heuristic for keeping enough of the name to
105   // hash to a unique color while still revealing structure through naming
106   // similarities.
107   //
108   // The regular expression "_\d+" matches any underscore followed by numbers,
109   // which we strip out. Examples:
110   //
111   //     "Conv"      -> "Conv"
112   //     "Conv_2"    -> "Conv"
113   //     "Conv_72"   -> "Conv"
114   //     "Pad_1_bias -> "Pad_bias"
115   //     "Conv_abc"  -> "Conv_abc"
116 
117   RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
118   uint32 color_word = std::hash<std::string>{}(s);
119   color_word |= 0x00E0E0E0;
120   return Color(color_word);
121 }
122 
GetArrayColorAndShape(const Model & model,const string & array_name,Color * color,string * shape)123 void GetArrayColorAndShape(const Model& model, const string& array_name,
124                            Color* color, string* shape) {
125   // All colors in this file are from:
126   // https://material.io/guidelines/style/color.html
127   // Arrays involved in RNN back-edges have a different color
128   for (const auto& rnn_state : model.flags.rnn_states()) {
129     // RNN state, fed by a back-edge. Bold color.
130     if (array_name == rnn_state.state_array()) {
131       *color = Color(0x0F, 0x9D, 0x58);
132       *shape = "invhouse";
133       return;
134     }
135     // RNN back-edge source, feeding a RNN state.
136     // Light tone of the same color as RNN states.
137     if (array_name == rnn_state.back_edge_source_array()) {
138       *color = Color(0xB7, 0xE1, 0xCD);
139       *shape = "house";
140       return;
141     }
142   }
143   // Constant parameter arrays have their own bold color
144   if (model.GetArray(array_name).buffer) {
145     *color = Color(0x42, 0x85, 0xF4);
146     *shape = "cylinder";
147     return;
148   }
149   // Remaining arrays are activations.
150   // We use gray colors for them because they are the majority
151   // of arrays so we want to highlight other arrays instead of them.
152   // First, we use a bolder gray for input/output arrays:
153   if (IsInputArray(model, array_name)) {
154     *color = Color(0x9E, 0x9E, 0x9E);
155     *shape = "invhouse";
156     return;
157   }
158   if (IsOutputArray(model, array_name)) {
159     *color = Color(0x9E, 0x9E, 0x9E);
160     *shape = "house";
161     return;
162   }
163   // Remaining arrays are intermediate activation arrays.
164   // Lighter tone of the same grey as for input/output arrays:
165   // We want these to be very discrete.
166   *color = Color(0xF5, 0xF5, 0xF5);
167   *shape = "box";
168 }
169 
// Returns the graphviz compass-point suffix used where edges attach to array
// `array_name`. Input-like arrays (model inputs and RNN states) return ":s",
// the tip of their "invhouse" shape. Output-like arrays (model outputs and RNN
// back-edge sources) return ":n", the tip of their "house" shape. All other
// arrays return "" and let graphviz choose.
string GetArrayCompassPt(const Model& model, const string& array_name) {
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (array_name == rnn_state.state_array()) return ":s";
    if (array_name == rnn_state.back_edge_source_array()) return ":n";
  }
  if (IsInputArray(model, array_name)) return ":s";
  return IsOutputArray(model, array_name) ? ":n" : "";
}
193 
AppendArrayVal(string * string,Array const & array,int index)194 void AppendArrayVal(string* string, Array const& array, int index) {
195   if (array.buffer->type == ArrayDataType::kFloat) {
196     const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
197     if (index >= data.size()) {
198       return;
199     }
200     AppendF(string, "%.3f", data[index]);
201   } else if (array.buffer->type == ArrayDataType::kUint8) {
202     const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
203     if (index >= data.size()) {
204       return;
205     }
206     AppendF(string, "%d", data[index]);
207   } else if (array.buffer->type == ArrayDataType::kInt16) {
208     const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
209     if (index >= data.size()) {
210       return;
211     }
212     AppendF(string, "%d", data[index]);
213   } else if (array.buffer->type == ArrayDataType::kInt32) {
214     const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
215     if (index >= data.size()) {
216       return;
217     }
218     AppendF(string, "%d", data[index]);
219   } else if (array.buffer->type == ArrayDataType::kInt64) {
220     const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
221     if (index >= data.size()) {
222       return;
223     }
224     AppendF(string, "%d", data[index]);
225   } else if (array.buffer->type == ArrayDataType::kBool) {
226     const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
227     if (index >= data.size()) {
228       return;
229     }
230     AppendF(string, "%d", data[index]);
231   }
232 }
233 
// Attribute name -> already-formatted value, ordered by name. Rendered as
// two-column rows inside HTML-like graphviz labels.
using Attributes = std::map<string, string>;

// Renders `attributes` as a sequence of HTML table rows, one per entry in
// key order. The name is right-aligned and followed by ':', and the value is
// left-aligned. Names and values are inserted verbatim (not HTML-escaped).
// Taken by const reference: the map is only read, so copying it was wasted work.
string AttributesToHtml(const Attributes& attributes) {
  string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
247 
GetArrayLabel(const Model & model,const string & array_id)248 string GetArrayLabel(const Model& model, const string& array_id) {
249   string html;
250 
251   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
252   html += "<";
253 
254   // Begin Table
255   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
256   html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
257 
258   auto& array = model.GetArray(array_id);
259   if (array.buffer) {
260     // "cylinder" shapes require some extra head room.
261     html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
262   }
263 
264   // "Primary" name of array (last non-slash delimited group of characters).
265   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
266   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
267   AppendF(&html, R"CODE(%s)CODE",
268           std::vector<string>(absl::StrSplit(array_id, '/')).back());
269   html += R"CODE(</I></FONT>)CODE";
270   html += "</TD></TR>";
271 
272   // Array data type and dimensions
273   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
274   html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
275   // Type
276   html += ArrayDataTypeName(array.data_type);
277   // Shape
278   if (array.has_shape()) {
279     auto& array_shape = array.shape();
280     html += "[";
281     for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
282       AppendF(&html, "%d", array_shape.dims(dim));
283       if (dim + 1 < array_shape.dimensions_count()) {
284         html += kUnicodeMult;
285       }
286     }
287     html += "]";
288   }
289 
290   // Small buffer sample
291   int buffer_size = 0;
292   if (array.buffer) {
293     buffer_size = RequiredBufferSizeForShape(array.shape());
294   }
295   if ((buffer_size > 0) && (buffer_size <= 4)) {
296     html += " = ";
297     if (array.shape().dimensions_count() > 0) {
298       html += "{";
299     }
300     for (int i = 0; i < buffer_size; i++) {
301       AppendArrayVal(&html, array, i);
302       if (i + 1 < buffer_size) {
303         html += ", ";
304       }
305     }
306     if (array.shape().dimensions_count() > 0) {
307       html += "}";
308     }
309   }
310   html += R"CODE(</B></FONT>)CODE";
311   html += "</TD></TR>";
312 
313   // Large buffer samples get their own line
314   if (buffer_size > 4) {
315     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
316     AppendArrayVal(&html, array, 0);
317     html += ", ";
318     AppendArrayVal(&html, array, 1);
319     html += kUnicodeEllipsis;
320     AppendArrayVal(&html, array, buffer_size - 2);
321     html += ", ";
322     AppendArrayVal(&html, array, buffer_size - 1);
323     html += "}</TD></TR>";
324   }
325 
326   // Other array properties
327   Attributes attrs;
328   if (array.minmax) {
329     attrs["minmax"] =
330         StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
331   }
332   if (array.quantization_params) {
333     attrs["quant"] = StringF("%7g\u00B7(x-%d)",  // Unicode "cdot"
334                              array.quantization_params->scale,
335                              array.quantization_params->zero_point);
336   }
337   if (array.alloc) {
338     attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
339   }
340   html += AttributesToHtml(attrs);
341 
342   // output array_id in ultra-small font so it can be searched and copied.
343   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
344   html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
345   AppendF(&html, R"CODE("%s")CODE", array_id);
346   html += R"CODE(</FONT>)CODE";
347   html += "</TD></TR>";
348 
349   // End Table and HTML-like label
350   html += R"CODE(</TABLE></FONT>)CODE";
351   html += ">";
352   return html;
353 }
354 
GetOpAttributes(const Model & model,const Operator & op)355 Attributes GetOpAttributes(const Model& model, const Operator& op) {
356   Attributes attrs;
357   switch (op.fused_activation_function) {
358     case FusedActivationFunctionType::kRelu:
359       attrs["func"] = "ReLU";
360       break;
361     case FusedActivationFunctionType::kRelu6:
362       attrs["func"] = "ReLU6";
363       break;
364     case FusedActivationFunctionType::kRelu1:
365       attrs["func"] = "ReLU1";
366       break;
367     default:
368       break;
369   }
370   // Output state of member vars on derived operators.
371   switch (op.type) {
372     case OperatorType::kConv: {
373       const auto& conv_op = static_cast<const ConvOperator&>(op);
374       string stride;
375       AppendF(&stride, "%d", conv_op.stride_width);
376       stride += kUnicodeMult;
377       AppendF(&stride, "%d", conv_op.stride_height);
378       attrs["stride"] = stride;
379       attrs["padding"] =
380           (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
381       break;
382     }
383     case OperatorType::kDepthwiseConv: {
384       const auto& depthconv_op = static_cast<const ConvOperator&>(op);
385       string stride;
386       AppendF(&stride, "%d", depthconv_op.stride_width);
387       stride += kUnicodeMult;
388       AppendF(&stride, "%d", depthconv_op.stride_height);
389       attrs["stride"] = stride;
390       attrs["padding"] =
391           (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
392       break;
393     }
394     case OperatorType::kFakeQuant: {
395       const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
396       attrs["bits"] = StringF("%d", fakequant_op.num_bits);
397       if (fakequant_op.minmax) {
398         attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
399                                  fakequant_op.minmax->max);
400       } else {
401         attrs["range"] = "[?,?]";
402       }
403       break;
404     }
405     default:
406       break;
407   }
408   int64 math_ops_count;
409   if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
410       (math_ops_count != 0)) {
411     attrs["math"] = FormattedNumber(math_ops_count) + "ops";
412   }
413 
414   return attrs;
415 }
416 
GetOpColor(const Operator & op)417 Color GetOpColor(const Operator& op) {
418   if ((op.type == OperatorType::kDepthwiseConv) ||
419       (op.type == OperatorType::kConv) ||
420       (op.type == OperatorType::kFullyConnected) ||
421       (op.type == OperatorType::kFakeQuant)) {
422     // Give some ops a bolder red
423     return Color(0xC5, 0x39, 0x29);
424   } else {
425     return Color(0xDB, 0x44, 0x37);
426   }
427 }
428 
GetOpLabel(const Model & model,const Operator & op)429 string GetOpLabel(const Model& model, const Operator& op) {
430   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
431   string html;
432   html += "<";
433 
434   // Begin Table
435   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
436   html +=
437       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
438 
439   // Input Ports
440   if (!op.inputs.empty()) {
441     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
442     // Distribute evenly using a sub-table
443     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
444     html += R"CODE(<TR>)CODE";
445     for (int i = 0; i < op.inputs.size(); i++) {
446       html += R"CODE(<TD PORT=")CODE";
447       AppendF(&html, "i%d", i);
448       html += R"CODE(">)CODE";
449       if (op.inputs.size() > 1) {
450         // Only number inputs when op has two or more inputs
451         AppendF(&html, "%d", i);
452       }
453       html += "</TD>";
454     }
455     html += "</TR>";
456     html += R"CODE(</TABLE></TD></TR>)CODE";
457   }
458 
459   // Name
460   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
461   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
462   if (op.type == OperatorType::kUnsupported) {
463     html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
464   } else {
465     html += string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
466   }
467   html += R"CODE(</B></FONT>)CODE";
468   html += "</TD></TR>";
469 
470   // Attributes
471   Attributes attrs = GetOpAttributes(model, op);
472   html += AttributesToHtml(attrs);
473 
474   // Output Ports
475   if (!op.outputs.empty()) {
476     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
477     // Distribute evenly using a sub-table
478     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
479     html += R"CODE(<TR>)CODE";
480     for (int i = 0; i < op.outputs.size(); i++) {
481       html += R"CODE(<TD PORT=")CODE";
482       AppendF(&html, "o%d", i);
483       html += R"CODE(">)CODE";
484       if (op.outputs.size() > 1) {
485         // Only number outputs when op has two or more outputs
486         AppendF(&html, "%d", i);
487       }
488       html += "</TD>";
489     }
490     html += "</TR>";
491     html += R"CODE(</TABLE></TD></TR>)CODE";
492   }
493 
494   // End Table and HTML-like label
495   html += R"CODE(</TABLE></FONT>)CODE";
496   html += ">";
497 
498   return html;
499 }
500 
GetLog2BufferSize(const Model & model,const string & array_id)501 float GetLog2BufferSize(const Model& model, const string& array_id) {
502   auto& array = model.GetArray(array_id);
503   if (array.has_shape()) {
504     int buffer_size = 0;
505     if (IsNonEmpty(array.shape())) {
506       buffer_size = RequiredBufferSizeForShape(array.shape());
507       return std::log2(static_cast<float>(buffer_size));
508     }
509   }
510   return 0.0f;
511 }
512 
GetOpId(int op_index)513 string GetOpId(int op_index) { return StringF("op%05d", op_index); }
514 
DumpOperator(const Model & model,string * output_file,int op_index)515 void DumpOperator(const Model& model, string* output_file, int op_index) {
516   // Dump node for operator.
517   const Operator& op = *model.operators[op_index];
518   Color color = GetOpColor(op);
519   string label = GetOpLabel(model, op);
520   string op_id = GetOpId(op_index);
521   AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
522           color.TextColorString());
523 }
524 
DumpOperatorEdges(const Model & model,string * output_file,int op_index)525 void DumpOperatorEdges(const Model& model, string* output_file, int op_index) {
526   // Inputs
527   const Operator& op = *model.operators[op_index];
528   string op_id = GetOpId(op_index);
529   for (int i = 0; i < op.inputs.size(); i++) {
530     const auto& input = op.inputs[i];
531     if (!model.HasArray(input)) {
532       // Connected arrays should _always_ exist. Except, perhaps, during
533       // development.
534       continue;
535     }
536     float log2_buffer_size = GetLog2BufferSize(model, input);
537     // Draw lines that transport more data thicker (Otherwise, where would the
538     // data fit? right?).
539     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
540     // Keep edges that transport more data shorter than those with less.
541     float weight = std::max(1.0f, log2_buffer_size);
542     if (!IsInputArray(model, input) &&
543         GetOpWithOutput(model, input) == nullptr) {
544       // Give the main line of data flow a straighter path by penalizing edges
545       // to standalone buffers. Weights are generally very large buffers that
546       // would otherwise skew the layout.
547       weight = 1.0f;
548     }
549     string compass_pt = GetArrayCompassPt(model, input);
550     AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
551             weight);
552   }
553   // Outputs
554   for (int i = 0; i < op.outputs.size(); i++) {
555     const auto& output = op.outputs[i];
556     if (!model.HasArray(output)) {
557       continue;
558     }
559     float log2_buffer_size = GetLog2BufferSize(model, output);
560     // See comments above regarding weight and line_width calculations.
561     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
562     float weight = std::max(1.0f, log2_buffer_size);
563     if (!IsArrayConsumed(model, output)) {
564       weight = 1.0f;
565     }
566     string compass_pt = GetArrayCompassPt(model, output);
567     AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
568             line_width, weight);
569   }
570 }
571 
572 struct Node {
Nodetoco::__anon405688330111::Node573   Node() : math_ops(0) {}
574   // Name used as a key in the model's array map
575   string array_id;
576 
577   // Estimated number of math ops incurred by this node (the sum of the op
578   // with this array as 1st output, plus all children nodes).
579   int64 math_ops;
580 
581   // A map of child nodes keyed by name.
582   std::map<const string, std::unique_ptr<Node>> children;
583 };
584 
GetSubgraphLabel(Node const & node,const string & subgraph)585 string GetSubgraphLabel(Node const& node, const string& subgraph) {
586   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
587   string html;
588   html += "<";
589 
590   // Begin Table
591   html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
592   html +=
593       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
594 
595   // Name
596   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
597   html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
598   html += subgraph;
599   html += R"CODE(</I></FONT>)CODE";
600   html += "</TD></TR>";
601 
602   // Attributes
603   Attributes attrs;
604   if (node.math_ops > 0) {
605     attrs["math"] = FormattedNumber(node.math_ops) + "ops";
606   }
607   html += AttributesToHtml(attrs);
608 
609   // End Table and HTML-like label
610   html += R"CODE(</TABLE></FONT>)CODE";
611   html += ">";
612 
613   return html;
614 }
615 
DumpSubgraphHeader(string * output_file,Node const & node,const string & node_name)616 void DumpSubgraphHeader(string* output_file, Node const& node,
617                         const string& node_name) {
618   Color color = HashStringToColor(node_name);
619   string label = GetSubgraphLabel(node, node_name);
620   AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
621 }
622 
DumpArray(const Model & model,string * output_file,const string & array_id)623 void DumpArray(const Model& model, string* output_file,
624                const string& array_id) {
625   Color color;
626   string shape;
627   GetArrayColorAndShape(model, array_id, &color, &shape);
628   string label = GetArrayLabel(model, array_id);
629   AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
630           color.AsHexString(), color.TextColorString());
631 
632   // Ops are placed in the same subgraph as their first output.
633   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
634     const Operator& op = *model.operators[op_index];
635     if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
636       DumpOperator(model, output_file, op_index);
637     }
638   }
639 }
640 
DumpNode(const Model & model,string * output_file,const string & node_name,Node const & node)641 void DumpNode(const Model& model, string* output_file, const string& node_name,
642               Node const& node) {
643   bool not_root = !node_name.empty();
644   if (not_root) {
645     DumpSubgraphHeader(output_file, node, node_name);
646   }
647 
648   for (const auto& child : node.children) {
649     if (!child.second->array_id.empty()) {
650       // Dump array if this node posesses one.
651       DumpArray(model, output_file, child.second->array_id);
652     }
653     // Note that it is always possible to have children. Unlike a filesystem,
654     // the existence of array "foo/bar" does _not_ prevent other arrays, such as
655     // and "foo/bar/baz", from being nested beneath it.
656     DumpNode(model, output_file, child.first, *child.second);
657   }
658 
659   if (not_root) {
660     // End subgraph
661     AppendF(output_file, "    }\n");
662   }
663 }
664 
GetArithmeticOpsCount(const Model & model,const string & array_id)665 int64 GetArithmeticOpsCount(const Model& model, const string& array_id) {
666   for (const auto& op : model.operators) {
667     if (!op->outputs.empty() && op->outputs[0] == array_id) {
668       int64 count;
669       if (EstimateArithmeticOpsCount(model, *op, &count)) {
670         return count;
671       } else {
672         return 0;
673       }
674     }
675   }
676   return 0;
677 }
678 
InsertNode(const Model & model,const string & array_id,Node * node,std::vector<string> prefixes,int64 * math_ops)679 void InsertNode(const Model& model, const string& array_id, Node* node,
680                 std::vector<string> prefixes, int64* math_ops) {
681   if (prefixes.empty()) {
682     // Base case: store array in this node.
683     node->array_id = array_id;
684     *math_ops = GetArithmeticOpsCount(model, array_id);
685   } else {
686     // Insert into the sub-tree for that prefix.
687     string prefix = prefixes.back();
688     prefixes.pop_back();
689     if (node->children.count(prefix) == 0) {
690       // Create a new node if this prefix is unseen.
691       node->children[prefix] = absl::make_unique<Node>();
692     }
693     InsertNode(model, array_id, node->children[prefix].get(), prefixes,
694                math_ops);
695   }
696   // Sum estimated math ops into all nodes.
697   node->math_ops += *math_ops;
698 }
699 
BuildArrayTree(const Model & model,Node * tree)700 void BuildArrayTree(const Model& model, Node* tree) {
701   // Delimit array names by path "/", then place into a tree based on this path.
702   for (const auto& array_id : model.GetArrayMap()) {
703     std::vector<string> prefixes = absl::StrSplit(array_id.first, '/');
704     std::reverse(prefixes.begin(), prefixes.end());
705     int64 math_ops;  // Temporary storage for math ops used during recursion.
706     InsertNode(model, array_id.first, tree, prefixes, &math_ops);
707   }
708 }
709 
// Builds the HTML-like label shown at the top of the graph. It contains
// `graph_name` in a large font, followed by model statistics: array count,
// optional-array count (when any), operator count, estimated math ops (when
// positive), and transient data size/alignment (when positive).
string GetGraphLabel(const Model& model, const string& graph_name) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";

  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
  html += graph_name;
  html += R"CODE(</I></B></FONT>)CODE";
  html += "</TD></TR>";

  // Attributes. Container sizes are size_t, and the transient-data fields may
  // be wider than int. Every count is therefore cast to long long and printed
  // with "%lld": passing such values to "%d" in a printf-style formatter is
  // undefined behavior.
  Attributes attrs;
  attrs["arrays"] =
      StringF("%lld", static_cast<long long>(model.GetArrayMap().size()));
  if (!model.optional_arrays.empty()) {
    attrs["optional arrays"] =
        StringF("%lld", static_cast<long long>(model.optional_arrays.size()));
  }
  attrs["operators"] =
      StringF("%lld", static_cast<long long>(model.operators.size()));
  int64 ops_count = 0;
  if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
    attrs["math"] = FormattedNumber(ops_count) + "ops";
  }
  if (model.transient_data_size > 0) {
    attrs["transient data size"] = StringF(
        "%lld KiB", static_cast<long long>(model.transient_data_size / 1024));
  }
  if (model.transient_data_alignment > 0) {
    attrs["transient data alignment"] = StringF(
        "%lld bytes", static_cast<long long>(model.transient_data_alignment));
  }
  html += AttributesToHtml(attrs);

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";

  return html;
}
754 }  // namespace
755 
DumpGraphviz(const Model & model,string * output_file,const string & graph_name)756 void DumpGraphviz(const Model& model, string* output_file,
757                   const string& graph_name) {
758   // Start graphviz format
759   AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
760 
761   // Organize arrays into a tree for subgraphing
762   Node tree;
763   BuildArrayTree(model, &tree);
764   DumpNode(model, output_file, "", tree);
765 
766   // Dump edges outside all subgraphs (otherwise the referred-to nodes are
767   // implicitly included in that subgraph).
768   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
769     DumpOperatorEdges(model, output_file, op_index);
770   }
771 
772   // Dump RNN Backedges
773   for (const auto& rnn_state : model.flags.rnn_states()) {
774     AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
775             rnn_state.state_array());
776   }
777   // End graphviz format
778   AppendF(output_file, "}\n");
779 }
780 }  // namespace toco
781