1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "debug/draw.h"
18
19 #include <iostream>
20 #include <iterator>
21 #include <vector>
22 #include <string>
23 #include "ir/meta_func_graph.h"
24 #include "ir/param_info.h"
25 #include "ir/primitive.h"
26 #include "ir/graph_utils.h"
27 #include "utils/utils.h"
28 #include "frontend/operator/composite/composite.h"
29 #include "pipeline/jit/parse/resolve.h"
30 #include "ir/tensor.h"
31
32 namespace mindspore {
33 // namespace to support debug utils
34 namespace draw {
35 namespace {
36 // Only for ValueNode
ValueType(const ValueNodePtr & node)37 std::string ValueType(const ValueNodePtr &node) {
38 if (node == nullptr) {
39 return "";
40 }
41 auto v = node->value();
42 MS_EXCEPTION_IF_NULL(v);
43 return v->type_name();
44 }
45
// Returns a copy of `str` in which every ASCII '<' becomes the full-width corner bracket "「" and
// every '>' becomes "」"; all other characters are copied unchanged. Used when writing text into
// Graphviz HTML-like labels, where raw angle brackets would be parsed as markup.
std::string ReplaceSpecialChar(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
      case '<':
        escaped += "「";
        break;
      case '>':
        escaped += "」";
        break;
      default:
        escaped += ch;
        break;
    }
  }
  return escaped;
}
59 } // namespace
60
61 // API of debug utils
DrawNodes(const std::vector<AnfNodePtr> & nodes,OrderedMap<FuncGraphPtr,std::shared_ptr<BaseDigraph>> * sub_graphs,bool is_user)62 void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
63 bool is_user) {
64 if (sub_graphs == nullptr) {
65 return;
66 }
67 for (auto &nd : nodes) {
68 MS_EXCEPTION_IF_NULL(nd);
69 auto sub_graph = nd->func_graph();
70 if (sub_graph != nullptr) {
71 auto gsub = (*sub_graphs)[sub_graph];
72 if (gsub == nullptr) {
73 if (is_user) {
74 gsub = std::make_shared<ModelDigraph>(sub_graph->ToString());
75 } else {
76 gsub = std::make_shared<Digraph>(sub_graph->ToString());
77 }
78 (*sub_graphs)[sub_graph] = gsub;
79 }
80 if (!nd->isa<Parameter>()) {
81 gsub->Node(nd);
82 }
83 }
84 }
85 }
86
DrawValueNodes(const std::vector<AnfNodePtr> & nodes,OrderedMap<FuncGraphPtr,std::shared_ptr<BaseDigraph>> * sub_graphs)87 void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
88 OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
89 if (sub_graphs == nullptr) {
90 return;
91 }
92
93 int dup_idx = 0;
94
95 for (auto &nd : nodes) {
96 for (auto &t : GetInputs(nd)) {
97 MS_EXCEPTION_IF_NULL(t);
98 MS_EXCEPTION_IF_NULL(nd);
99 if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
100 (*sub_graphs)[nd->func_graph()]->Node(t, dup_idx);
101 dup_idx++;
102 } else if (t->isa<Parameter>() && (*sub_graphs).find(t->func_graph()) != (*sub_graphs).end()) {
103 (*sub_graphs)[t->func_graph()]->Node(t, dup_idx);
104 dup_idx++;
105 }
106 }
107 }
108 }
109
DrawEdges(const std::vector<AnfNodePtr> & nodes,const std::shared_ptr<BaseDigraph> & digraph,bool is_user)110 void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
111 if (digraph == nullptr) {
112 return;
113 }
114
115 int dup_idx = 0;
116
117 int offset = 0;
118 if (is_user) {
119 offset = 1;
120 }
121
122 // Draw edge
123 for (auto &nd : nodes) {
124 auto &succs = GetInputs(nd);
125 auto num = succs.size();
126 for (size_t i = 0; i < num; i++) {
127 auto &t = succs.at(i);
128 MS_EXCEPTION_IF_NULL(t);
129 if (t->isa<ValueNode>() || t->isa<Parameter>()) {
130 if ((!is_user) || (i != 0)) {
131 // `SizeToInt(i) - offset` is just for printing as text
132 digraph->Edge(t, nd, SizeToInt(i) - offset, dup_idx);
133 }
134 if (IsValueNode<FuncGraph>(t)) {
135 auto const_graph = GetValueNode<FuncGraphPtr>(t);
136 digraph->Edge(t, const_graph, dup_idx);
137 }
138 dup_idx++;
139 } else {
140 digraph->Edge(t, nd, SizeToInt(i) - offset);
141 }
142 }
143 }
144 }
145
DrawByOpt(const std::string & filename,const FuncGraphPtr & func_graph,bool is_user)146 void DrawByOpt(const std::string &filename, const FuncGraphPtr &func_graph, bool is_user) {
147 if (func_graph == nullptr) {
148 return;
149 }
150 auto ret = func_graph->get_return();
151 auto nodes = DeepScopedGraphSearch(ret);
152
153 std::shared_ptr<BaseDigraph> digraph;
154 OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> sub_graphs;
155 ChangeFileMode(filename, S_IWUSR);
156 if (is_user) {
157 digraph = std::make_shared<ModelDigraph>("mindspore", filename);
158 } else {
159 digraph = std::make_shared<Digraph>("mindspore", filename);
160 }
161
162 MS_EXCEPTION_IF_NULL(digraph);
163 digraph->Start();
164
165 // Draw nodes
166 DrawNodes(nodes, &sub_graphs, is_user);
167
168 // Draw ValueNodes on CNodes
169 DrawValueNodes(nodes, &sub_graphs);
170
171 // Draw subgraph
172 for (const auto &gsub : sub_graphs) {
173 digraph->SubGraph(gsub.first, gsub.second);
174 }
175
176 // Draw edge
177 DrawEdges(nodes, digraph, is_user);
178
179 digraph->End();
180 // set file mode to read only by user
181 ChangeFileMode(filename, S_IRUSR);
182 }
183
184 #ifdef ENABLE_DUMP_IR
Draw(const std::string & filename,const FuncGraphPtr & func_graph)185 void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
186 const std::string dot_suffix = ".dot";
187 const std::string filename_with_suffix =
188 (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
189 const std::string filepath = GetSaveGraphsPathName(Common::AddId(filename_with_suffix, dot_suffix));
190 auto real_filepath = Common::CreatePrefixPath(filepath);
191 if (!real_filepath.has_value()) {
192 MS_LOG(ERROR) << "The export ir path: " << filepath << " is illegal.";
193 return;
194 }
195 DrawByOpt(real_filepath.value(), func_graph, false);
196 }
197
DrawUserFuncGraph(const std::string & filename,const FuncGraphPtr & func_graph)198 void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
199 const std::string dot_suffix = ".dot";
200 const std::string filepath = GetSaveGraphsPathName(Common::AddId(filename, dot_suffix));
201 auto real_filepath = Common::CreatePrefixPath(filepath);
202 if (!real_filepath.has_value()) {
203 MS_LOG(ERROR) << "The export ir path: " << filepath << " is illegal.";
204 return;
205 }
206 DrawByOpt(real_filepath.value(), func_graph, true);
207 }
208 #else
Draw(const std::string &,const FuncGraphPtr &)209 void Draw(const std::string &, const FuncGraphPtr &) {
210 static bool already_printed = false;
211 if (already_printed) {
212 return;
213 }
214 already_printed = true;
215 MS_LOG(WARNING) << "The functionality of dumping function graph IR in graphviz dot format is disabled, "
216 << "please recompile source to enable it. See help of building script.";
217 }
218
DrawUserFuncGraph(const std::string &,const FuncGraphPtr &)219 void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
220 static bool already_printed = false;
221 if (already_printed) {
222 return;
223 }
224 already_printed = true;
225 MS_LOG(WARNING) << "The functionality of dumping function graph IR in graphviz dot format is disabled, "
226 << "please recompile source to enable it. See help of building script.";
227 }
228 #endif
229
Shape(const AnfNodePtr & node)230 std::string Graphviz::Shape(const AnfNodePtr &node) {
231 if (node == nullptr) {
232 return "";
233 }
234
235 if (node->isa<CNode>()) {
236 return "plaintext";
237 }
238
239 if (node->isa<Parameter>()) {
240 return "octagon";
241 }
242
243 if (IsValueNode<FuncGraph>(node)) {
244 return "oval";
245 }
246
247 return "plaintext";
248 }
249
Color(const AnfNodePtr & node)250 std::string Graphviz::Color(const AnfNodePtr &node) {
251 if (node == nullptr) {
252 return "";
253 }
254
255 if (node->isa<CNode>()) {
256 return "cornsilk";
257 }
258
259 if (node->isa<Parameter>()) {
260 return "paleturquoise";
261 }
262
263 if (IsValueNode<FuncGraph>(node)) {
264 return "palegreen";
265 }
266
267 return "lavender";
268 }
269
Start()270 void BaseDigraph::Start() {
271 buffer_ << "digraph " << name_ << " {" << std::endl;
272 buffer_ << "compound=true" << std::endl;
273 }
274
Head(const AnfNodePtr & node,int id)275 void BaseDigraph::Head(const AnfNodePtr &node, int id) {
276 if (node == nullptr) {
277 return;
278 }
279
280 buffer_ << "node" << node << "_" << id;
281 if (node->isa<CNode>() || (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node))) {
282 buffer_ << ":core";
283 }
284 }
285
Tail(const AnfNodePtr & node,int idx,int id)286 void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
287 if (node == nullptr) {
288 return;
289 }
290
291 buffer_ << "node" << node << "_" << id;
292 buffer_ << ":" << idx;
293 }
294
Tail(const FuncGraphPtr & func_graph)295 void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
296 if (func_graph == nullptr) {
297 return;
298 }
299 buffer_ << "node" << func_graph->get_return() << "_" << 0;
300 }
301
Edge(const AnfNodePtr & start,const FuncGraphPtr & end,int id_start)302 void BaseDigraph::Edge(const AnfNodePtr &start, const FuncGraphPtr &end, int id_start) {
303 Head(start, id_start);
304 buffer_ << "->";
305 Tail(end);
306
307 buffer_ << "[lhead=cluster_" << end;
308 buffer_ << ",dir=both,arrowhead=dot,style=filled,color=blue]";
309 buffer_ << std::endl;
310 }
311
End()312 void BaseDigraph::End() {
313 buffer_ << "}" << std::endl;
314
315 if (fout_.is_open()) {
316 fout_ << buffer_.str();
317 }
318 }
319
FuncGraphParameters(const FuncGraphPtr & key)320 void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
321 buffer_ << "parameters_" << key << "[shape=plaintext ";
322 buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
323 buffer_ << "<tr><td>parameters</td></tr>";
324 int count = 0;
325 for (auto ¶meter : key->parameters()) {
326 MS_EXCEPTION_IF_NULL(parameter);
327 buffer_ << "<tr><td>";
328 buffer_ << parameter->ToString();
329 auto param = parameter->cast<ParameterPtr>();
330 if (param && param->has_default()) {
331 auto tensor_v = param->default_param();
332 if (tensor_v && tensor_v->isa<tensor::Tensor>()) {
333 auto tensor = tensor_v->cast<tensor::TensorPtr>();
334 auto &shape = tensor->shape();
335 std::ostringstream shape_str;
336 std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ","));
337 buffer_ << "[" << shape_str.str() << "]";
338 }
339 }
340 buffer_ << "</td></tr>";
341 count++;
342 // Wrap the text.
343 if (count % 10 == 0) {
344 buffer_ << "\n";
345 }
346 }
347 buffer_ << "</table>>,];";
348 }
349
SubGraph(const FuncGraphPtr & key,const std::shared_ptr<BaseDigraph> & gsub)350 void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
351 if (key == nullptr || gsub == nullptr) {
352 return;
353 }
354
355 std::string label = key->debug_info()->get_full_name();
356 if (label.empty()) {
357 label = gsub->name();
358 }
359
360 std::string label_managed = "[managed]";
361 if (key->manager() == nullptr) {
362 label_managed = "[not managed]";
363 }
364 label += label_managed;
365
366 gsub->FuncGraphParameters(key);
367 buffer_ << "subgraph cluster_" << key << "{" << std::endl;
368 buffer_ << "id=cluster_" << key << std::endl;
369 buffer_ << "label=\"" << label << "\"" << std::endl;
370 buffer_ << "fontname=\"Courier New\"" << std::endl;
371 buffer_ << gsub->buffer().str();
372 buffer_ << "}" << std::endl;
373 }
374
~Digraph()375 Digraph::~Digraph() {
376 try {
377 if (fout_.is_open()) {
378 fout_.close();
379 }
380 } catch (const std::exception &e) {
381 MS_LOG(ERROR) << "Exception when closing file " << filename_;
382 }
383 }
384
// Returns a copy of `str` in which every non-overlapping occurrence of `from` is replaced by `to`,
// scanning left to right.
// - Matching resumes right after the inserted text, so an occurrence of `from` inside `to` is not
//   replaced again (e.g. replacing "x" with "yx" in "xx" gives "yxyx").
// - An empty `from` leaves `str` unchanged. An empty pattern matches at every position, and
//   without this guard the loop never terminated.
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
  if (from.empty()) {
    return str;
  }
  size_t start_pos = 0;
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
    (void)str.replace(start_pos, from.length(), to);
    // Skip past the inserted text so occurrences of `from` inside `to` are not rescanned.
    start_pos += to.length();
  }
  return str;
}
394
DrawValueNode(Graphviz * const graph_obj,const ValueNodePtr & node)395 static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
396 MS_EXCEPTION_IF_NULL(graph_obj);
397 graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
398 << "'>";
399 graph_obj->buffer() << "<tr><td bgcolor='white'>" << ValueType(node) << "</td></tr>"
400 << "<tr><td>";
401 if (IsValueNode<MetaFuncGraph>(node)) {
402 graph_obj->buffer() << node->value()->cast<MetaFuncGraphPtr>()->name();
403 } else if (IsValueNode<parse::NameSpace>(node)) {
404 graph_obj->buffer() << node->value()->cast<parse::NameSpacePtr>()->name();
405 } else if (IsValueNode<parse::Symbol>(node)) {
406 graph_obj->buffer() << ReplaceSpecialChar(node->value()->cast<parse::SymbolPtr>()->name());
407 } else {
408 std::ostringstream ss;
409 ss << node->value()->ToString();
410 std::string s = ReplaceAll(ss.str(), ", ", "<br/>");
411 graph_obj->buffer() << s;
412 ValuePtr value = node->value();
413 if (value->isa<Primitive>()) {
414 PrimitivePtr primitive = value->cast<PrimitivePtr>();
415 graph_obj->buffer() << "</td></tr>"
416 << "<tr><td align='left'>";
417 if (!primitive->instance_name().empty()) {
418 graph_obj->buffer() << "instance name:"
419 << " " << primitive->instance_name() << "<br/>";
420 }
421 auto attrs = primitive->attrs();
422 if (attrs.size() > 0) {
423 graph_obj->buffer() << "</td></tr>"
424 << "<tr><td align='left'>";
425 int i = 0;
426 for (const auto &attr : attrs) {
427 if (i != 0) {
428 graph_obj->buffer() << "<br/>";
429 }
430 graph_obj->buffer() << attr.first << " ";
431 if (attr.second == nullptr) {
432 graph_obj->buffer() << " ";
433 } else {
434 graph_obj->buffer() << attr.second->ToString();
435 }
436 i++;
437 }
438 }
439 }
440 }
441 graph_obj->buffer() << "</td></tr>"
442 << "</table>>,";
443 }
444
DrawParallelInfo(Graphviz * const graph_obj,const CNodePtr & node)445 static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
446 if (graph_obj == nullptr || node == nullptr) {
447 return;
448 }
449 auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
450 if (distributed_operation_info != nullptr) {
451 auto strategyPtr = distributed_operation_info->strategy();
452 if (strategyPtr != nullptr) {
453 auto num = node->inputs().size();
454 graph_obj->buffer() << "<tr><td colspan='" << num << "' ";
455 graph_obj->buffer() << "bgcolor='" << graph_obj->Color(node) << "'>";
456 std::vector<ValuePtr> temp = {MakeValue(strategyPtr->GetInputStage()), MakeValue(strategyPtr->GetInputDim())};
457 ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(temp);
458 graph_obj->buffer() << "Strategy " << strategy_tuple->ToString();
459 graph_obj->buffer() << "</td></tr>" << std::endl;
460 }
461 }
462 }
463
DrawCNode(Graphviz * const graph_obj,const CNodePtr & node)464 static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
465 if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
466 return;
467 }
468 auto num = node->size();
469 bool is_modelgraph = false;
470 if (typeid(*graph_obj) == typeid(ModelDigraph)) {
471 is_modelgraph = true;
472 num -= 1;
473 }
474
475 graph_obj->buffer() << "label=<<table port='core'>" << std::endl;
476 // Draw ports for CNode
477 if (num > 0) {
478 graph_obj->buffer() << "<tr>";
479 for (size_t i = 0; i < num; i++) {
480 graph_obj->buffer() << "<td port='" << i << "'>" << i << "</td>";
481 }
482 graph_obj->buffer() << "</tr>" << std::endl;
483 }
484
485 // Draw op name
486 graph_obj->buffer() << "<tr><td";
487 if (num > 0) {
488 graph_obj->buffer() << " colspan='" << num << "'";
489 }
490 graph_obj->buffer() << " bgcolor='" << graph_obj->Color(node) << "'>";
491
492 if (IsValueNode<Primitive>(node->input(0)) && is_modelgraph) {
493 auto primitive = GetValueNode<PrimitivePtr>(node->input(0));
494 graph_obj->buffer() << ReplaceAll(primitive->ToString(), ", ", "<br/>");
495 auto attrs = primitive->attrs();
496 if (attrs.size() > 0) {
497 graph_obj->buffer() << "</td></tr>" << std::endl << "<tr><td";
498 // Draw attrs
499 if (num > 0) {
500 graph_obj->buffer() << " colspan='" << num << "'";
501 }
502 graph_obj->buffer() << ">";
503 int i = 0;
504 for (auto &attr : attrs) {
505 if (i != 0) {
506 graph_obj->buffer() << "<br/>";
507 }
508 graph_obj->buffer() << attr.first << " " << attr.second->ToString();
509 i++;
510 }
511 }
512 graph_obj->buffer() << "CNode";
513 } else {
514 graph_obj->buffer() << "CNode(" << node->ToString() << ")";
515 }
516
517 graph_obj->buffer() << "</td></tr>" << std::endl;
518 DrawParallelInfo(graph_obj, node);
519 graph_obj->buffer() << "</table>>,";
520 }
521
Node(const AnfNodePtr & node,int id)522 void Digraph::Node(const AnfNodePtr &node, int id) {
523 if (node == nullptr) {
524 return;
525 }
526
527 buffer_ << "node" << node << "_" << id;
528 buffer_ << "[";
529
530 // Set fontname
531 buffer_ << "fontname=\"Courier New\",";
532 // Set label and shape
533 buffer_ << "shape=" << Shape(node) << ",";
534 if (node->isa<CNode>()) {
535 DrawCNode(this, node->cast<CNodePtr>());
536 } else if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
537 DrawValueNode(this, node->cast<ValueNodePtr>());
538 } else {
539 buffer_ << "label=\"" << node->ToString();
540 if (IsValueNode<FuncGraph>(node)) {
541 FuncGraphPtr nextNet = GetValueNode<FuncGraphPtr>(node);
542 std::string nextNetName = nextNet->debug_info()->get_full_name();
543 if (!nextNetName.empty()) {
544 buffer_ << "[" << nextNet->debug_info()->get_full_name().c_str() << "]";
545 }
546 }
547 buffer_ << "\","
548 << "style=filled,fillcolor=" << Color(node) << ",";
549 }
550
551 // Set URL for func graph
552 if (IsValueNode<FuncGraph>(node)) {
553 buffer_ << "URL=\"#cluster_" << GetValueNode(node) << "\",";
554 }
555
556 buffer_ << "]" << std::endl;
557 }
558
Edge(const AnfNodePtr & start,const AnfNodePtr & end,int idx,int id_start)559 void Digraph::Edge(const AnfNodePtr &start, const AnfNodePtr &end, int idx, int id_start) {
560 if (start == nullptr || end == nullptr) {
561 return;
562 }
563
564 Head(start, id_start);
565 buffer_ << "->";
566 Tail(end, idx);
567
568 buffer_ << "[arrowhead=vee,";
569
570 // Check how many inputs for end
571 if (end->isa<CNode>()) {
572 auto cnode = end->cast<CNodePtr>();
573 MS_EXCEPTION_IF_NULL(cnode);
574 auto num = cnode->inputs().size();
575 if (idx == 0 && num > 1) {
576 buffer_ << "style=dashed";
577 }
578 }
579 buffer_ << "]" << std::endl;
580 }
581
~ModelDigraph()582 ModelDigraph::~ModelDigraph() {
583 try {
584 if (fout_.is_open()) {
585 fout_.close();
586 }
587 } catch (const std::exception &e) {
588 MS_LOG(ERROR) << "exception when closing file " << filename_;
589 }
590 }
591
Shape(const AnfNodePtr & node)592 std::string ModelDigraph::Shape(const AnfNodePtr &node) {
593 if (node == nullptr) {
594 return "";
595 }
596
597 if (node->isa<CNode>()) {
598 return "plaintext";
599 }
600
601 if (node->isa<Parameter>()) {
602 return "ellipse";
603 }
604
605 if (IsValueNode<FuncGraph>(node)) {
606 return "oval";
607 }
608
609 return "plaintext";
610 }
611
Node(const AnfNodePtr & node,int id)612 void ModelDigraph::Node(const AnfNodePtr &node, int id) {
613 if (node == nullptr) {
614 return;
615 }
616
617 if (IsValueNode<Primitive>(node)) {
618 return;
619 }
620
621 buffer_ << "node" << node << "_" << id;
622 buffer_ << "[";
623
624 // Set fontname
625 buffer_ << "fontname=\"Courier New\",";
626 // Set label and shape
627 buffer_ << "shape=" << Shape(node) << ",";
628 if (node->isa<CNode>()) {
629 DrawCNode(this, node->cast<CNodePtr>());
630 } else if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
631 DrawValueNode(this, node->cast<ValueNodePtr>());
632 } else {
633 buffer_ << "label=\"" << node->ToString() << "\",";
634 buffer_ << "style=filled,fillcolor=" << Color(node) << ",";
635 }
636
637 // Set URL for func graph
638 if (IsValueNode<FuncGraph>(node)) {
639 buffer_ << "URL=\"#cluster_" << GetValueNode(node) << "\",";
640 }
641
642 buffer_ << "]" << std::endl;
643 }
644
Edge(const AnfNodePtr & start,const AnfNodePtr & end,int idx,int id_start)645 void ModelDigraph::Edge(const AnfNodePtr &start, const AnfNodePtr &end, int idx, int id_start) {
646 if (start == nullptr || end == nullptr) {
647 return;
648 }
649
650 Head(start, id_start);
651 buffer_ << "->";
652 Tail(end, idx);
653
654 buffer_ << "[arrowhead=vee,";
655 buffer_ << "]" << std::endl;
656 }
657
658 struct DrawerRegister {
DrawerRegistermindspore::draw::DrawerRegister659 DrawerRegister() {
660 FuncGraph::set_drawer(
661 [](const std::string &filename, const FuncGraphPtr &func_graph) { Draw(filename, func_graph); });
662 }
663 ~DrawerRegister() = default;
664 } drawer_regsiter;
665 } // namespace draw
666 } // namespace mindspore
667