• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "src/common/draw/graphviz_graph.h"
#include <algorithm>
#include <numeric>
#include <set>
#include <sstream>
#include <vector>
22 
23 namespace mindspore::lite {
From() const24 std::string Edge::From() const { return from_->name(); }
name() const25 std::string Edge::name() const { return this->name_; }
26 
AppendOutput(const GVNode * to,size_t port)27 void Edge::AppendOutput(const GVNode *to, size_t port) {
28   tos_.emplace_back(to);
29   to_ports_.emplace_back(port);
30 }
31 
Code() const32 std::string Edge::Code() const {
33   std::ostringstream oss;
34   if (from_->type() == kNodeTypeCNode) {
35     oss << from_->prefix() << from_->name() << ":O" << from_port_ << " -> ";
36   } else {
37     oss << from_->prefix() << from_->name() << " -> ";
38   }
39   auto from_str = oss.str();
40   oss.str("");
41   for (size_t i = 0; i < tos_.size(); i++) {
42     auto to = tos_[i];
43     if (to->type() == kNodeTypeCNode) {
44       oss << from_str << to->prefix() << to->name() << ":I" << to_ports_[i] << " [label=\"" << info_ << "\"];";
45     } else {
46       oss << from_str << to->prefix() << to->name() << " [label=\"" << info_ << "\"];";
47     }
48   }
49   return oss.str();
50 }
51 
CreateCNode(const std::string & id,const std::string & label,size_t input_size,const std::vector<std::string> & output_names,const std::vector<std::string> & output_infos,bool highlight)52 GVNode *GVNode::CreateCNode(const std::string &id, const std::string &label, size_t input_size,
53                             const std::vector<std::string> &output_names, const std::vector<std::string> &output_infos,
54                             bool highlight) {
55   auto node = new GVNode(id, label, kNodeTypeCNode, input_size, output_names.size(), highlight);
56   node->prefix_ = "Node_";
57   node->shape_ = "plaintext";
58   node->color_ = "cornsilk";
59   node->Init(output_names, output_infos);
60   return node;
61 }
62 
CreateInput(const std::string & id,const std::vector<std::string> & output_names,const std::vector<std::string> & output_infos,bool highlight)63 GVNode *GVNode::CreateInput(const std::string &id, const std::vector<std::string> &output_names,
64                             const std::vector<std::string> &output_infos, bool highlight) {
65   auto node = new GVNode(id, id, kNodeTypeInput, 0, output_names.size(), highlight);
66   node->prefix_ = "Input_";
67   node->shape_ = "egg";
68   node->Init(output_names, output_infos);
69   return node;
70 }
71 
CreateOutput(const std::string & id,size_t input_size,bool highlight)72 GVNode *GVNode::CreateOutput(const std::string &id, size_t input_size, bool highlight) {
73   auto node = new GVNode(id, id, kNodeTypeOutput, input_size, 0, highlight);
74   node->prefix_ = "Output_";
75   node->shape_ = "egg";
76   node->Init({}, {});
77   return node;
78 }
79 
CreateWeight(const std::string & id,const std::string & label,const std::vector<std::string> & output_names,const std::vector<std::string> & output_infos,bool highlight)80 GVNode *GVNode::CreateWeight(const std::string &id, const std::string &label,
81                              const std::vector<std::string> &output_names, const std::vector<std::string> &output_infos,
82                              bool highlight) {
83   auto node = new GVNode(id, label, kNodeTypeWeight, 0, output_names.size(), highlight);
84   node->prefix_ = "Weight_";
85   node->shape_ = "octagon";
86   node->color_ = "paleturquoise";
87   node->Init(output_names, output_infos);
88   return node;
89 }
90 
~GVNode()91 GVNode::~GVNode() {
92   for (auto output : outputs_) {
93     delete output;
94   }
95   outputs_.clear();
96 }
97 
Init(const std::vector<std::string> & output_names,const std::vector<std::string> & output_infos)98 void GVNode::Init(const std::vector<std::string> &output_names, const std::vector<std::string> &output_infos) {
99   inputs_.reserve(input_size_);
100   outputs_.reserve(output_size_);
101   MS_ASSERT(output_names.size() == output_size_);
102   for (size_t i = 0; i < output_size_; i++) {
103     auto edge = new Edge(output_names[i], this, i, output_infos[i]);
104     this->outputs_.emplace_back(edge);
105   }
106 }
107 
FindCols() const108 size_t GVNode::FindCols() const {
109   auto max = std::max(input_size_, output_size_);
110   auto min = std::min(input_size_, output_size_);
111   if (min == 0 || max == 0) {
112     return 1;
113   }
114   size_t ret = max;
115   while (ret <= input_size_ * output_size_) {
116     if (ret % min == 0) {
117       break;
118     }
119     ret++;
120   }
121   while (ret <= input_size_ * output_size_) {
122     if (ret % max == 0) {
123       break;
124     }
125     ret += min;
126   }
127   return ret;
128 }
129 
Code() const130 std::string GVNode::Code() const {
131   std::ostringstream oss;
132   if (type_ == kNodeTypeCNode) {
133     auto bgcolor = highlight_ ? "red" : color_;
134     oss << "\t"
135         << "\t"
136         << "\t"
137         << "\t";
138     auto indent = oss.str();
139     oss.str("");
140     auto cols = FindCols();
141     oss << "<<table port='core'>" << std::endl;
142     oss << indent << "<tr>";
143     auto input_cols = input_size_ == 0 ? 0 : cols / input_size_;
144     for (size_t i = 0; i < input_size_; i++) {
145       oss << "<td align='center' colspan='" << input_cols << "' port='I" << i << "'>I" << i << "</td>";
146     }
147     oss << "</tr>" << std::endl;
148     oss << indent << "<tr><td align='center' colspan='" << cols << "' bgcolor='" << bgcolor << "'>" << label_
149         << "</td></tr>" << std::endl;
150     oss << indent << "<tr>";
151     auto output_cols = output_size_ == 0 ? 0 : cols / output_size_;
152     for (size_t i = 0; i < output_size_; i++) {
153       oss << "<td align='center' colspan='" << output_cols << "' port='O" << i << "'>O" << i << "</td>";
154     }
155     oss << "</tr>" << std::endl;
156     oss << indent << "</table>>";
157   } else {
158     oss << "\"" << label_ << "\"";
159   }
160   auto label = oss.str();
161   oss.str("");
162   oss << prefix_ << id_ << " [shape=" << shape_;
163   oss << ", label=" << label;
164   if (type_ != kNodeTypeCNode) {
165     oss << ", style=filled, fillcolor=" << color_;
166   }
167   oss << "];";
168   return oss.str();
169 }
170 
~GVGraph()171 GVGraph::~GVGraph() {
172   for (auto *node : nodes_) {
173     delete node;
174   }
175   nodes_.clear();
176 }
177 
AppendNode(GVNode * node)178 void GVGraph::AppendNode(GVNode *node) {
179   if (node == nullptr) {
180     return;
181   }
182   nodes_.emplace_back(node);
183   node_map_[node->name()] = node;
184 }
185 
Link(const std::string & from_name,size_t from_port,const std::string & to_name,size_t to_port)186 int GVGraph::Link(const std::string &from_name, size_t from_port, const std::string &to_name, size_t to_port) {
187   auto from = node_map_.find(from_name);
188   if (from == node_map_.end()) {
189     MS_LOG(ERROR) << "Node " << from_name << " is not belong to this graph.";
190     return RET_ERROR;
191   }
192   MS_ASSERT(from->second != nullptr);
193   if (from_port >= from->second->output_size()) {
194     MS_LOG(ERROR) << "`from_port`(" << from_port << ") out of range of node(" << from_name
195                   << ")'s output ports number: " << from->second->output_size();
196     return RET_ERROR;
197   }
198   auto to = node_map_.find(to_name);
199   if (to == node_map_.end()) {
200     MS_LOG(ERROR) << "Node " << to_name << " is not belong to this graph.";
201     return RET_ERROR;
202   }
203   MS_ASSERT(to->second != nullptr);
204   if (to_port >= to->second->input_size()) {
205     MS_LOG(ERROR) << "`to_port`(" << to_port << ") out of range of node(" << to_name
206                   << ")'s input ports number: " << to->second->input_size();
207     return RET_ERROR;
208   }
209   if (to_port < to->second->size()) {
210     MS_LOG(ERROR) << "node(" << to_name << ")'s " << to_port << "th input port already link to "
211                   << to->second->inputs()[to_port]->From();
212     return RET_ERROR;
213   }
214   auto edge = from->second->outputs()[from_port];
215   edge->AppendOutput(to->second, to_port);
216   to->second->AppendInput(edge);
217   return RET_OK;
218 }
219 
Code() const220 std::string GVGraph::Code() const {
221   std::ostringstream oss;
222   oss << "digraph " << name_ << " {" << std::endl;
223   for (auto node : nodes_) {
224     oss << node->Code() << std::endl;
225   }
226   for (auto node : nodes_) {
227     for (auto output : node->outputs()) {
228       oss << output->Code() << std::endl;
229     }
230   }
231   oss << "}";
232   return oss.str();
233 }
234 }  // namespace mindspore::lite
235