1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/dump_graphviz.h"
16
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "re2/re2.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
30 #include "tensorflow/lite/toco/toco_port.h"
31 #include "tensorflow/lite/toco/toco_types.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33
34 using toco::port::AppendF;
35 using toco::port::StringF;
36
37 namespace toco {
38 namespace {
39
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
//
// Opens the top-level digraph; DumpGraphviz appends the closing "}".
// %s: HTML-like graph label (see GetGraphLabel).
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Note: tooltips are only supported on SVGs in Chrome.
//
// Opens a cluster subgraph; DumpNode appends the matching closing brace.
// %s: cluster name suffix, %s: background color ("#RRGGBB"), %s: label.
constexpr char kSubgraphFmt[] =
    R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node. %s: node id (array name), %s: label, %s: tooltip (array name),
// %s: shape, %s: fill color, %s: font color ("#RRGGBB"; the format appends
// "DD" to it, giving an "#RRGGBBAA" color).
constexpr char kArrayNodeFmt[] =
    R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node. %s: node id (unquoted, see GetOpId), %s: label,
// %s: fill color, %s: font color (with "DD" appended, as above).
constexpr char kOpNodeFmt[] =
    R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array into an op input port.
// %s: array name, %s: compass point suffix (may be empty), %s: op id,
// %d: input port index, %f: pen width, %f: layout weight.
constexpr char kInputEdgeFmt[] =
    R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from an op output port to an array.
// %s: op id, %d: output port index, %s: array name,
// %s: compass point suffix (may be empty), %f: pen width, %f: layout weight.
constexpr char kOutputEdgeFmt[] =
    R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge (drawn green, constraint=false so it does not affect ranking).
// %s: back-edge source array, %s: RNN state array.
constexpr char kRNNBackEdgeFmt[] =
    R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
// Separator between dimensions / stride components ("×").
constexpr char kUnicodeMult[] = "\u00D7";
// Marks elided values in a truncated buffer sample (" … ").
constexpr char kUnicodeEllipsis[] = " \u2026 ";
68
69 class Color {
70 public:
Color()71 Color() {}
Color(uint8 r,uint8 g,uint8 b)72 Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
Color(uint32 word)73 explicit Color(uint32 word)
74 : r_((word & 0x00FF0000) >> 16),
75 g_((word & 0x0000FF00) >> 8),
76 b_((word & 0x000000FF) >> 0) {}
77
78 // Returns the string serialization of this color in graphviz format,
79 // for use as 'fillcolor' in boxes.
AsHexString() const80 string AsHexString() const { return StringF("#%.2X%.2X%.2X", r_, g_, b_); }
81 // The color to use for this node; will be used as 'fillcolor'
82 // for its box. See Color::AsHexString. A suitable, different
83 // color will be chosen for the 'fontcolor' for the inside text
84 // label, see Color::TextColorString.
85 // Returns the serialization in graphviz format of a suitable color to use
86 // 'fontcolor' in the same boxes. It should black or white, whichever offers
87 // the better contrast from AsHexString().
TextColorString() const88 string TextColorString() const {
89 // https://en.wikipedia.org/wiki/Relative_luminance
90 const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
91 const uint8 l = luminance > 128.f ? 0 : 255;
92 return StringF("#%.2X%.2X%.2X", l, l, l);
93 }
94
95 private:
96 uint8 r_ = 0, g_ = 0, b_ = 0;
97 };
98
HashStringToColor(string s)99 Color HashStringToColor(string s) {
100 // Return a unique color for a name.
101 //
102 // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
103 // the string to a uint_32, then twiddles some bits to get a light and subtle
104 // color. This seems to be a good heuristic for keeping enough of the name to
105 // hash to a unique color while still revealing structure through naming
106 // similarities.
107 //
108 // The regular expression "_\d+" matches any underscore followed by numbers,
109 // which we strip out. Examples:
110 //
111 // "Conv" -> "Conv"
112 // "Conv_2" -> "Conv"
113 // "Conv_72" -> "Conv"
114 // "Pad_1_bias -> "Pad_bias"
115 // "Conv_abc" -> "Conv_abc"
116
117 RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
118 uint32 color_word = std::hash<std::string>{}(s);
119 color_word |= 0x00E0E0E0;
120 return Color(color_word);
121 }
122
GetArrayColorAndShape(const Model & model,const string & array_name,Color * color,string * shape)123 void GetArrayColorAndShape(const Model& model, const string& array_name,
124 Color* color, string* shape) {
125 // All colors in this file are from:
126 // https://material.io/guidelines/style/color.html
127 // Arrays involved in RNN back-edges have a different color
128 for (const auto& rnn_state : model.flags.rnn_states()) {
129 // RNN state, fed by a back-edge. Bold color.
130 if (array_name == rnn_state.state_array()) {
131 *color = Color(0x0F, 0x9D, 0x58);
132 *shape = "invhouse";
133 return;
134 }
135 // RNN back-edge source, feeding a RNN state.
136 // Light tone of the same color as RNN states.
137 if (array_name == rnn_state.back_edge_source_array()) {
138 *color = Color(0xB7, 0xE1, 0xCD);
139 *shape = "house";
140 return;
141 }
142 }
143 // Constant parameter arrays have their own bold color
144 if (model.GetArray(array_name).buffer) {
145 *color = Color(0x42, 0x85, 0xF4);
146 *shape = "cylinder";
147 return;
148 }
149 // Remaining arrays are activations.
150 // We use gray colors for them because they are the majority
151 // of arrays so we want to highlight other arrays instead of them.
152 // First, we use a bolder gray for input/output arrays:
153 if (IsInputArray(model, array_name)) {
154 *color = Color(0x9E, 0x9E, 0x9E);
155 *shape = "invhouse";
156 return;
157 }
158 if (IsOutputArray(model, array_name)) {
159 *color = Color(0x9E, 0x9E, 0x9E);
160 *shape = "house";
161 return;
162 }
163 // Remaining arrays are intermediate activation arrays.
164 // Lighter tone of the same grey as for input/output arrays:
165 // We want these to be very discrete.
166 *color = Color(0xF5, 0xF5, 0xF5);
167 *shape = "box";
168 }
169
// Returns the graphviz compass-point suffix used where edges attach to an
// array node. Inputs (and RNN states, which behave like inputs) attach at the
// south tip (":s") of their "invhouse" shape. Outputs (and RNN back-edge
// sources, which behave like outputs) attach at the north tip (":n") of their
// "house" shape. All other arrays use graphviz's default attachment ("").
string GetArrayCompassPt(const Model& model, const string& array_name) {
  const char* const kSouth = ":s";
  const char* const kNorth = ":n";
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (array_name == rnn_state.state_array()) return kSouth;
    if (array_name == rnn_state.back_edge_source_array()) return kNorth;
  }
  if (IsInputArray(model, array_name)) return kSouth;
  return IsOutputArray(model, array_name) ? kNorth : "";
}
193
AppendArrayVal(string * string,Array const & array,int index)194 void AppendArrayVal(string* string, Array const& array, int index) {
195 if (array.buffer->type == ArrayDataType::kFloat) {
196 const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
197 if (index >= data.size()) {
198 return;
199 }
200 AppendF(string, "%.3f", data[index]);
201 } else if (array.buffer->type == ArrayDataType::kUint8) {
202 const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
203 if (index >= data.size()) {
204 return;
205 }
206 AppendF(string, "%d", data[index]);
207 } else if (array.buffer->type == ArrayDataType::kInt16) {
208 const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
209 if (index >= data.size()) {
210 return;
211 }
212 AppendF(string, "%d", data[index]);
213 } else if (array.buffer->type == ArrayDataType::kInt32) {
214 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
215 if (index >= data.size()) {
216 return;
217 }
218 AppendF(string, "%d", data[index]);
219 } else if (array.buffer->type == ArrayDataType::kInt64) {
220 const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
221 if (index >= data.size()) {
222 return;
223 }
224 AppendF(string, "%d", data[index]);
225 } else if (array.buffer->type == ArrayDataType::kBool) {
226 const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
227 if (index >= data.size()) {
228 return;
229 }
230 AppendF(string, "%d", data[index]);
231 }
232 }
233
234 typedef std::map<string, string> Attributes;
235
AttributesToHtml(Attributes attributes)236 string AttributesToHtml(Attributes attributes) {
237 string html;
238 for (const auto& attr : attributes) {
239 html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
240 html += attr.first;
241 html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
242 html += attr.second;
243 html += "</TD></TR>";
244 }
245 return html;
246 }
247
// Builds the HTML-like graphviz label for an array node. The label contains:
//   - the last '/'-separated component of the array name;
//   - the data type and shape;
//   - a sample of constant buffer values;
//   - minmax / quantization / allocation info, when present;
//   - the full array id in a tiny font, so it can be searched and copied.
string GetArrayLabel(const Model& model, const string& array_id) {
  string html;

  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
  html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";

  const auto& array = model.GetArray(array_id);
  if (array.buffer) {
    // "cylinder" shapes require some extra head room.
    html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
  }

  // "Primary" name of array (last non-slash delimited group of characters).
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
  AppendF(&html, R"CODE(%s)CODE",
          std::vector<string>(absl::StrSplit(array_id, '/')).back());
  html += R"CODE(</I></FONT>)CODE";
  html += "</TD></TR>";

  // Array data type and dimensions
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
  // Type
  html += ArrayDataTypeName(array.data_type);
  // Shape
  if (array.has_shape()) {
    const auto& array_shape = array.shape();
    html += "[";
    for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
      AppendF(&html, "%d", array_shape.dims(dim));
      if (dim + 1 < array_shape.dimensions_count()) {
        html += kUnicodeMult;
      }
    }
    html += "]";
  }

  // Small buffer sample. Only constant arrays whose shape is present and fully
  // known get one. Like GetLog2BufferSize, check has_shape() and IsNonEmpty()
  // before sizing the buffer. shape() is presumably only valid when
  // has_shape() holds (this file guards it that way elsewhere). Unknown
  // (non-positive) dimensions would make the computed size meaningless.
  int buffer_size = 0;
  if (array.buffer && array.has_shape() && IsNonEmpty(array.shape())) {
    buffer_size = RequiredBufferSizeForShape(array.shape());
  }
  if ((buffer_size > 0) && (buffer_size <= 4)) {
    html += " = ";
    if (array.shape().dimensions_count() > 0) {
      html += "{";
    }
    for (int i = 0; i < buffer_size; i++) {
      AppendArrayVal(&html, array, i);
      if (i + 1 < buffer_size) {
        html += ", ";
      }
    }
    if (array.shape().dimensions_count() > 0) {
      html += "}";
    }
  }
  html += R"CODE(</B></FONT>)CODE";
  html += "</TD></TR>";

  // Large buffer samples get their own line: first two and last two values.
  if (buffer_size > 4) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
    AppendArrayVal(&html, array, 0);
    html += ", ";
    AppendArrayVal(&html, array, 1);
    html += kUnicodeEllipsis;
    AppendArrayVal(&html, array, buffer_size - 2);
    html += ", ";
    AppendArrayVal(&html, array, buffer_size - 1);
    html += "}</TD></TR>";
  }

  // Other array properties
  Attributes attrs;
  if (array.minmax) {
    attrs["minmax"] =
        StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
  }
  if (array.quantization_params) {
    // "%.7g" (precision 7), matching minmax above; the previous "%7g" was a
    // field *width* of 7 that padded short values with spaces.
    attrs["quant"] = StringF("%.7g\u00B7(x-%d)",  // Unicode "cdot"
                             array.quantization_params->scale,
                             array.quantization_params->zero_point);
  }
  if (array.alloc) {
    // Widen to long long so "%lld" matches regardless of the declared integer
    // width of start/end (StringF takes a printf-style format).
    attrs["alloc"] = StringF("[%lld, %lld)",
                             static_cast<long long>(array.alloc->start),
                             static_cast<long long>(array.alloc->end));
  }
  html += AttributesToHtml(attrs);

  // output array_id in ultra-small font so it can be searched and copied.
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
  AppendF(&html, R"CODE("%s")CODE", array_id);
  html += R"CODE(</FONT>)CODE";
  html += "</TD></TR>";

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}
354
GetOpAttributes(const Model & model,const Operator & op)355 Attributes GetOpAttributes(const Model& model, const Operator& op) {
356 Attributes attrs;
357 switch (op.fused_activation_function) {
358 case FusedActivationFunctionType::kRelu:
359 attrs["func"] = "ReLU";
360 break;
361 case FusedActivationFunctionType::kRelu6:
362 attrs["func"] = "ReLU6";
363 break;
364 case FusedActivationFunctionType::kRelu1:
365 attrs["func"] = "ReLU1";
366 break;
367 default:
368 break;
369 }
370 // Output state of member vars on derived operators.
371 switch (op.type) {
372 case OperatorType::kConv: {
373 const auto& conv_op = static_cast<const ConvOperator&>(op);
374 string stride;
375 AppendF(&stride, "%d", conv_op.stride_width);
376 stride += kUnicodeMult;
377 AppendF(&stride, "%d", conv_op.stride_height);
378 attrs["stride"] = stride;
379 attrs["padding"] =
380 (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
381 break;
382 }
383 case OperatorType::kDepthwiseConv: {
384 const auto& depthconv_op = static_cast<const ConvOperator&>(op);
385 string stride;
386 AppendF(&stride, "%d", depthconv_op.stride_width);
387 stride += kUnicodeMult;
388 AppendF(&stride, "%d", depthconv_op.stride_height);
389 attrs["stride"] = stride;
390 attrs["padding"] =
391 (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
392 break;
393 }
394 case OperatorType::kFakeQuant: {
395 const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
396 attrs["bits"] = StringF("%d", fakequant_op.num_bits);
397 if (fakequant_op.minmax) {
398 attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
399 fakequant_op.minmax->max);
400 } else {
401 attrs["range"] = "[?,?]";
402 }
403 break;
404 }
405 default:
406 break;
407 }
408 int64 math_ops_count;
409 if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
410 (math_ops_count != 0)) {
411 attrs["math"] = FormattedNumber(math_ops_count) + "ops";
412 }
413
414 return attrs;
415 }
416
GetOpColor(const Operator & op)417 Color GetOpColor(const Operator& op) {
418 if ((op.type == OperatorType::kDepthwiseConv) ||
419 (op.type == OperatorType::kConv) ||
420 (op.type == OperatorType::kFullyConnected) ||
421 (op.type == OperatorType::kFakeQuant)) {
422 // Give some ops a bolder red
423 return Color(0xC5, 0x39, 0x29);
424 } else {
425 return Color(0xDB, 0x44, 0x37);
426 }
427 }
428
GetOpLabel(const Model & model,const Operator & op)429 string GetOpLabel(const Model& model, const Operator& op) {
430 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
431 string html;
432 html += "<";
433
434 // Begin Table
435 html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
436 html +=
437 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
438
439 // Input Ports
440 if (!op.inputs.empty()) {
441 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
442 // Distribute evenly using a sub-table
443 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
444 html += R"CODE(<TR>)CODE";
445 for (int i = 0; i < op.inputs.size(); i++) {
446 html += R"CODE(<TD PORT=")CODE";
447 AppendF(&html, "i%d", i);
448 html += R"CODE(">)CODE";
449 if (op.inputs.size() > 1) {
450 // Only number inputs when op has two or more inputs
451 AppendF(&html, "%d", i);
452 }
453 html += "</TD>";
454 }
455 html += "</TR>";
456 html += R"CODE(</TABLE></TD></TR>)CODE";
457 }
458
459 // Name
460 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
461 html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
462 if (op.type == OperatorType::kUnsupported) {
463 html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
464 } else {
465 html += string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
466 }
467 html += R"CODE(</B></FONT>)CODE";
468 html += "</TD></TR>";
469
470 // Attributes
471 Attributes attrs = GetOpAttributes(model, op);
472 html += AttributesToHtml(attrs);
473
474 // Output Ports
475 if (!op.outputs.empty()) {
476 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
477 // Distribute evenly using a sub-table
478 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
479 html += R"CODE(<TR>)CODE";
480 for (int i = 0; i < op.outputs.size(); i++) {
481 html += R"CODE(<TD PORT=")CODE";
482 AppendF(&html, "o%d", i);
483 html += R"CODE(">)CODE";
484 if (op.outputs.size() > 1) {
485 // Only number outputs when op has two or more outputs
486 AppendF(&html, "%d", i);
487 }
488 html += "</TD>";
489 }
490 html += "</TR>";
491 html += R"CODE(</TABLE></TD></TR>)CODE";
492 }
493
494 // End Table and HTML-like label
495 html += R"CODE(</TABLE></FONT>)CODE";
496 html += ">";
497
498 return html;
499 }
500
GetLog2BufferSize(const Model & model,const string & array_id)501 float GetLog2BufferSize(const Model& model, const string& array_id) {
502 auto& array = model.GetArray(array_id);
503 if (array.has_shape()) {
504 int buffer_size = 0;
505 if (IsNonEmpty(array.shape())) {
506 buffer_size = RequiredBufferSizeForShape(array.shape());
507 return std::log2(static_cast<float>(buffer_size));
508 }
509 }
510 return 0.0f;
511 }
512
GetOpId(int op_index)513 string GetOpId(int op_index) { return StringF("op%05d", op_index); }
514
DumpOperator(const Model & model,string * output_file,int op_index)515 void DumpOperator(const Model& model, string* output_file, int op_index) {
516 // Dump node for operator.
517 const Operator& op = *model.operators[op_index];
518 Color color = GetOpColor(op);
519 string label = GetOpLabel(model, op);
520 string op_id = GetOpId(op_index);
521 AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
522 color.TextColorString());
523 }
524
DumpOperatorEdges(const Model & model,string * output_file,int op_index)525 void DumpOperatorEdges(const Model& model, string* output_file, int op_index) {
526 // Inputs
527 const Operator& op = *model.operators[op_index];
528 string op_id = GetOpId(op_index);
529 for (int i = 0; i < op.inputs.size(); i++) {
530 const auto& input = op.inputs[i];
531 if (!model.HasArray(input)) {
532 // Connected arrays should _always_ exist. Except, perhaps, during
533 // development.
534 continue;
535 }
536 float log2_buffer_size = GetLog2BufferSize(model, input);
537 // Draw lines that transport more data thicker (Otherwise, where would the
538 // data fit? right?).
539 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
540 // Keep edges that transport more data shorter than those with less.
541 float weight = std::max(1.0f, log2_buffer_size);
542 if (!IsInputArray(model, input) &&
543 GetOpWithOutput(model, input) == nullptr) {
544 // Give the main line of data flow a straighter path by penalizing edges
545 // to standalone buffers. Weights are generally very large buffers that
546 // would otherwise skew the layout.
547 weight = 1.0f;
548 }
549 string compass_pt = GetArrayCompassPt(model, input);
550 AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
551 weight);
552 }
553 // Outputs
554 for (int i = 0; i < op.outputs.size(); i++) {
555 const auto& output = op.outputs[i];
556 if (!model.HasArray(output)) {
557 continue;
558 }
559 float log2_buffer_size = GetLog2BufferSize(model, output);
560 // See comments above regarding weight and line_width calculations.
561 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
562 float weight = std::max(1.0f, log2_buffer_size);
563 if (!IsArrayConsumed(model, output)) {
564 weight = 1.0f;
565 }
566 string compass_pt = GetArrayCompassPt(model, output);
567 AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
568 line_width, weight);
569 }
570 }
571
572 struct Node {
Nodetoco::__anon405688330111::Node573 Node() : math_ops(0) {}
574 // Name used as a key in the model's array map
575 string array_id;
576
577 // Estimated number of math ops incurred by this node (the sum of the op
578 // with this array as 1st output, plus all children nodes).
579 int64 math_ops;
580
581 // A map of child nodes keyed by name.
582 std::map<const string, std::unique_ptr<Node>> children;
583 };
584
GetSubgraphLabel(Node const & node,const string & subgraph)585 string GetSubgraphLabel(Node const& node, const string& subgraph) {
586 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
587 string html;
588 html += "<";
589
590 // Begin Table
591 html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
592 html +=
593 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
594
595 // Name
596 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
597 html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
598 html += subgraph;
599 html += R"CODE(</I></FONT>)CODE";
600 html += "</TD></TR>";
601
602 // Attributes
603 Attributes attrs;
604 if (node.math_ops > 0) {
605 attrs["math"] = FormattedNumber(node.math_ops) + "ops";
606 }
607 html += AttributesToHtml(attrs);
608
609 // End Table and HTML-like label
610 html += R"CODE(</TABLE></FONT>)CODE";
611 html += ">";
612
613 return html;
614 }
615
DumpSubgraphHeader(string * output_file,Node const & node,const string & node_name)616 void DumpSubgraphHeader(string* output_file, Node const& node,
617 const string& node_name) {
618 Color color = HashStringToColor(node_name);
619 string label = GetSubgraphLabel(node, node_name);
620 AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
621 }
622
DumpArray(const Model & model,string * output_file,const string & array_id)623 void DumpArray(const Model& model, string* output_file,
624 const string& array_id) {
625 Color color;
626 string shape;
627 GetArrayColorAndShape(model, array_id, &color, &shape);
628 string label = GetArrayLabel(model, array_id);
629 AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
630 color.AsHexString(), color.TextColorString());
631
632 // Ops are placed in the same subgraph as their first output.
633 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
634 const Operator& op = *model.operators[op_index];
635 if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
636 DumpOperator(model, output_file, op_index);
637 }
638 }
639 }
640
DumpNode(const Model & model,string * output_file,const string & node_name,Node const & node)641 void DumpNode(const Model& model, string* output_file, const string& node_name,
642 Node const& node) {
643 bool not_root = !node_name.empty();
644 if (not_root) {
645 DumpSubgraphHeader(output_file, node, node_name);
646 }
647
648 for (const auto& child : node.children) {
649 if (!child.second->array_id.empty()) {
650 // Dump array if this node posesses one.
651 DumpArray(model, output_file, child.second->array_id);
652 }
653 // Note that it is always possible to have children. Unlike a filesystem,
654 // the existence of array "foo/bar" does _not_ prevent other arrays, such as
655 // and "foo/bar/baz", from being nested beneath it.
656 DumpNode(model, output_file, child.first, *child.second);
657 }
658
659 if (not_root) {
660 // End subgraph
661 AppendF(output_file, " }\n");
662 }
663 }
664
GetArithmeticOpsCount(const Model & model,const string & array_id)665 int64 GetArithmeticOpsCount(const Model& model, const string& array_id) {
666 for (const auto& op : model.operators) {
667 if (!op->outputs.empty() && op->outputs[0] == array_id) {
668 int64 count;
669 if (EstimateArithmeticOpsCount(model, *op, &count)) {
670 return count;
671 } else {
672 return 0;
673 }
674 }
675 }
676 return 0;
677 }
678
InsertNode(const Model & model,const string & array_id,Node * node,std::vector<string> prefixes,int64 * math_ops)679 void InsertNode(const Model& model, const string& array_id, Node* node,
680 std::vector<string> prefixes, int64* math_ops) {
681 if (prefixes.empty()) {
682 // Base case: store array in this node.
683 node->array_id = array_id;
684 *math_ops = GetArithmeticOpsCount(model, array_id);
685 } else {
686 // Insert into the sub-tree for that prefix.
687 string prefix = prefixes.back();
688 prefixes.pop_back();
689 if (node->children.count(prefix) == 0) {
690 // Create a new node if this prefix is unseen.
691 node->children[prefix] = absl::make_unique<Node>();
692 }
693 InsertNode(model, array_id, node->children[prefix].get(), prefixes,
694 math_ops);
695 }
696 // Sum estimated math ops into all nodes.
697 node->math_ops += *math_ops;
698 }
699
BuildArrayTree(const Model & model,Node * tree)700 void BuildArrayTree(const Model& model, Node* tree) {
701 // Delimit array names by path "/", then place into a tree based on this path.
702 for (const auto& array_id : model.GetArrayMap()) {
703 std::vector<string> prefixes = absl::StrSplit(array_id.first, '/');
704 std::reverse(prefixes.begin(), prefixes.end());
705 int64 math_ops; // Temporary storage for math ops used during recursion.
706 InsertNode(model, array_id.first, tree, prefixes, &math_ops);
707 }
708 }
709
// Builds the HTML-like label for the whole graph: `graph_name` in large type,
// followed by model-wide statistics. These are the array and operator counts,
// the optional-array count (if any), the estimated math ops (if positive), and
// the transient data size and alignment (if set).
string GetGraphLabel(const Model& model, const string& graph_name) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";

  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
  html += graph_name;
  html += R"CODE(</I></B></FONT>)CODE";
  html += "</TD></TR>";

  // Attributes. StringF takes a printf-style format, so each argument must
  // match its conversion. Container sizes are size_t ("%zu"). The
  // transient-data fields are widened to long long ("%lld"), which is correct
  // whatever their declared integer type.
  Attributes attrs;
  attrs["arrays"] = StringF("%zu", model.GetArrayMap().size());
  if (!model.optional_arrays.empty()) {
    attrs["optional arrays"] = StringF("%zu", model.optional_arrays.size());
  }
  attrs["operators"] = StringF("%zu", model.operators.size());
  int64 ops_count = 0;
  if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
    attrs["math"] = FormattedNumber(ops_count) + "ops";
  }
  if (model.transient_data_size > 0) {
    attrs["transient data size"] = StringF(
        "%lld KiB", static_cast<long long>(model.transient_data_size / 1024));
  }
  if (model.transient_data_alignment > 0) {
    attrs["transient data alignment"] = StringF(
        "%lld bytes", static_cast<long long>(model.transient_data_alignment));
  }
  html += AttributesToHtml(attrs);

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";

  return html;
}
754 } // namespace
755
DumpGraphviz(const Model & model,string * output_file,const string & graph_name)756 void DumpGraphviz(const Model& model, string* output_file,
757 const string& graph_name) {
758 // Start graphviz format
759 AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
760
761 // Organize arrays into a tree for subgraphing
762 Node tree;
763 BuildArrayTree(model, &tree);
764 DumpNode(model, output_file, "", tree);
765
766 // Dump edges outside all subgraphs (otherwise the referred-to nodes are
767 // implicitly included in that subgraph).
768 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
769 DumpOperatorEdges(model, output_file, op_index);
770 }
771
772 // Dump RNN Backedges
773 for (const auto& rnn_state : model.flags.rnn_states()) {
774 AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
775 rnn_state.state_array());
776 }
777 // End graphviz format
778 AppendF(output_file, "}\n");
779 }
780 } // namespace toco
781