• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/common/draw/graphviz_graph_builder.h"
18 #include <set>
19 #include <vector>
20 #include "src/common/draw/adapter_graph.h"
21 #include "ir/dtype.h"
22 
23 namespace mindspore::lite {
24 namespace {
// Rewrites `str` in place, turning every '/' and every '-' into '_'.
// `str` must not be null.
inline void StrReplace(std::string *str) {
  for (auto &ch : *str) {
    if (ch == '/' || ch == '-') {
      ch = '_';
    }
  }
}
29 
// Drops everything up to and including the last '/' from `str`, in place.
// A string that contains no '/' is left unchanged. `str` must not be null.
inline void ShortName(std::string *str) {
  const auto last_slash = str->rfind('/');
  if (last_slash != std::string::npos) {
    str->erase(0, last_slash + 1);
  }
}
37 
GetNodeId(const AdapterNode & node)38 inline std::string GetNodeId(const AdapterNode &node) {
39   auto name = node.GetName();
40   StrReplace(&name);
41   return name;
42 }
43 
GetNodeLabel(const AdapterNode & node)44 inline std::string GetNodeLabel(const AdapterNode &node) {
45   auto name = node.GetName();
46   ShortName(&name);
47   StrReplace(&name);
48   return name;
49 }
50 
GetTensorId(const lite::Tensor & tensor)51 inline std::string GetTensorId(const lite::Tensor &tensor) {
52   auto name = tensor.tensor_name();
53   StrReplace(&name);
54   return name;
55 }
56 
GetTensorInfo(const lite::Tensor & tensor)57 inline std::string GetTensorInfo(const lite::Tensor &tensor) {
58   auto tensor_info = FormatEnumToString(tensor.format());
59   tensor_info += ", ";
60   tensor_info += TypeIdToString(tensor.data_type());
61   tensor_info += ", ";
62   tensor_info += lite::ShapeVectorToStr(tensor.shape());
63   return tensor_info;
64 }
65 }  // namespace
66 
Build(const std::shared_ptr<AdapterGraph> & graph)67 std::shared_ptr<GVGraph> GVGraphBuilder::Build(const std::shared_ptr<AdapterGraph> &graph) {
68   gv_graph_ = std::make_shared<GVGraph>(graph->GetName());
69   // graph inputs
70   for (auto in_tensor : graph->GetInputs()) {
71     this->AppendGraphInputNode(*in_tensor);
72   }
73   // nodes
74   for (const auto *node : graph->GetNodes()) {
75     auto node_id = GetNodeId(*node);
76     auto node_label = GetNodeLabel(*node);
77     for (size_t i = 0; i < node->InputSize(); i++) {
78       auto in_tensor = node->GetInput(i);
79       if (GetBelongingGVNode(in_tensor).first == nullptr) {
80         if (!in_tensor->IsConst()) {
81           MS_LOG(WARNING) << "The " << i << "th input of " << node->GetName()
82                           << " is neither a const tensor nor an output of other node. Treat it as a weight node.";
83         }
84         auto tensor_id = node_id + "_in_" + std::to_string(i);
85         auto tensor_label = node_label + "_in_" + std::to_string(i);
86         AppendWeightNode(*in_tensor, tensor_id, tensor_label);
87       }
88     }
89     auto ret = this->AppendComputeNode(*node);
90     if (ret != RET_OK) {
91       MS_LOG(ERROR) << "Create and append gv_node for " << node->GetName() << " failed.";
92       return nullptr;
93     }
94   }
95   // graph outputs
96   auto ret = this->AppendGraphOutputNode(graph->GetOutputs());
97   if (ret != RET_OK) {
98     MS_LOG(ERROR) << "Create and append graph return node failed";
99     return nullptr;
100   }
101   return this->gv_graph_;
102 }
103 
AppendGraphInputNode(const lite::Tensor & tensor)104 void GVGraphBuilder::AppendGraphInputNode(const lite::Tensor &tensor) {
105   auto tensor_id = GetTensorId(tensor);
106   auto gv_node = lite::GVNode::CreateInput(tensor_id, {tensor_id}, {GetTensorInfo(tensor)});
107   MS_ASSERT(gv_node != nullptr);
108   gv_graph_->AppendNode(gv_node);
109   gv_node_out_tensor_map_[&tensor] = std::make_pair(gv_node, 0);
110 }
111 
112 namespace {
// Renders the first `size` elements of `buffer` as a ", "-separated list.
//
// Buffers of up to 8 elements are printed in full. Longer buffers are abbreviated to the first three elements,
// "...", and the last three elements, e.g. "1, 2, 3..., 8, 9, 10".
// Returns an empty string when `buffer` is null or `size` is 0.
template <typename T>
std::string BufferToString(const T *buffer, size_t size) {
  constexpr size_t print_pre_number = 3;
  constexpr size_t print_post_number = 3;
  constexpr size_t print_period_number = 2;
  // Explicit guard rather than an assertion, so a null buffer is never indexed even if assertions are disabled.
  if (buffer == nullptr || size == 0) {
    return "";
  }

  std::ostringstream oss;
  oss << buffer[0];
  if (size <= print_pre_number + print_post_number + print_period_number) {
    for (size_t i = 1; i < size; i++) {
      oss << ", " << buffer[i];
    }
    return oss.str();
  }

  // Head of the buffer.
  for (size_t i = 1; i < print_pre_number; i++) {
    oss << ", " << buffer[i];
  }
  oss << "...";
  // Tail of the buffer: restart at `size - print_post_number` so the LAST elements are shown, not the ones that
  // directly follow the head.
  for (size_t i = size - print_post_number; i < size; i++) {
    oss << ", " << buffer[i];
  }
  return oss.str();
}
146 
TensorDataString(const lite::Tensor & tensor)147 std::string TensorDataString(const lite::Tensor &tensor) {
148   if (tensor.shape().size() != 1 || tensor.shape()[0] <= 0 || tensor.data() == nullptr) {
149     return "";
150   }
151   auto data_size = static_cast<size_t>(tensor.shape()[0]);
152 
153   std::ostringstream oss;
154   oss << "\n[";
155   if (tensor.data_type() == kNumberTypeInt || tensor.data_type() == kNumberTypeInt32) {
156     auto data = reinterpret_cast<int *>(tensor.data());
157     oss << BufferToString(data, data_size);
158   } else if (tensor.data_type() == kNumberTypeInt64) {
159     auto data = reinterpret_cast<int64_t *>(tensor.data());
160     oss << BufferToString(data, data_size);
161   } else {
162     return "";
163   }
164   oss << "]";
165   return oss.str();
166 }
167 }  // namespace
168 
AppendWeightNode(const lite::Tensor & tensor,const std::string & id,const std::string & label)169 void GVGraphBuilder::AppendWeightNode(const lite::Tensor &tensor, const std::string &id, const std::string &label) {
170   auto gv_node = lite::GVNode::CreateWeight(id, label + TensorDataString(tensor), {id}, {GetTensorInfo(tensor)});
171   MS_ASSERT(gv_node != nullptr);
172   gv_graph_->AppendNode(gv_node);
173   AppendOutTensorMap(&tensor, gv_node, 0);
174 }
175 
AppendComputeNode(const AdapterNode & node)176 int GVGraphBuilder::AppendComputeNode(const AdapterNode &node) {
177   auto gv_node = CreateComputeNode(node);
178   if (gv_node == nullptr) {
179     MS_LOG(ERROR) << "Create gv_node for " << node.GetName() << " failed.";
180     return RET_ERROR;
181   }
182   gv_graph_->AppendNode(gv_node);
183   for (size_t i = 0; i < node.OutputSize(); i++) {
184     AppendOutTensorMap(node.GetOutput(i), gv_node, i);
185   }
186   auto ret = LinkNodes(node, *gv_node);
187   if (ret != RET_OK) {
188     MS_LOG(ERROR) << "Link inputs for " << node.GetName() << " failed.";
189     return RET_ERROR;
190   }
191   return RET_OK;
192 }
193 
AppendGraphOutputNode(const std::vector<lite::Tensor * > & out_tensors)194 int GVGraphBuilder::AppendGraphOutputNode(const std::vector<lite::Tensor *> &out_tensors) {
195   auto out_tensor_size = out_tensors.size();
196   auto gv_node = lite::GVNode::CreateOutput("return", out_tensor_size);
197   MS_ASSERT(gv_node != nullptr);
198   gv_graph_->AppendNode(gv_node);
199   for (size_t i = 0; i < out_tensors.size(); i++) {
200     auto out_tensor = out_tensors[i];
201     auto pair = this->GetBelongingGVNode(out_tensor);
202     if (pair.first == nullptr) {
203       MS_LOG(ERROR) << "Can not find graph output tensor source: " << out_tensor->tensor_name();
204       return RET_ERROR;
205     }
206     auto link_ret = gv_graph_->Link(pair.first->name(), pair.second, gv_node->name(), i);
207     if (link_ret != RET_OK) {
208       MS_LOG(ERROR) << "Link " << i << "th input tensor of return failed.";
209       return RET_ERROR;
210     }
211   }
212   return RET_OK;
213 }
214 
CreateComputeNode(const AdapterNode & node)215 GVNode *GVGraphBuilder::CreateComputeNode(const AdapterNode &node) {
216   auto node_id = GetNodeId(node);
217   auto node_label = GetNodeLabel(node);
218   std::vector<std::string> output_names;
219   std::vector<std::string> output_infos;
220   for (auto out_tensor : node.GetOutputs()) {
221     output_names.emplace_back(GetTensorId(*out_tensor));
222     output_infos.emplace_back(GetTensorInfo(*out_tensor));
223   }
224   auto *gv_node =
225     lite::GVNode::CreateCNode(node_id, node_label, node.InputSize(), output_names, output_infos, node.IsHighlight());
226   MS_ASSERT(gv_node != nullptr);
227   return gv_node;
228 }
229 
AppendOutTensorMap(const lite::Tensor * tensor,lite::GVNode * node,size_t out_index)230 void GVGraphBuilder::AppendOutTensorMap(const lite::Tensor *tensor, lite::GVNode *node, size_t out_index) {
231   gv_node_out_tensor_map_[tensor] = std::make_pair(node, out_index);
232 }
233 
GetBelongingGVNode(const lite::Tensor * tensor) const234 std::pair<lite::GVNode *, size_t> GVGraphBuilder::GetBelongingGVNode(const lite::Tensor *tensor) const {
235   auto iter = gv_node_out_tensor_map_.find(tensor);
236   if (iter == gv_node_out_tensor_map_.end()) {
237     return {};
238   } else {
239     return iter->second;
240   }
241 }
LinkNodes(const AdapterNode & node,const GVNode & gv_node)242 int GVGraphBuilder::LinkNodes(const AdapterNode &node, const GVNode &gv_node) {
243   for (size_t i = 0; i < node.InputSize(); i++) {
244     auto in_tensor = node.GetInput(i);
245     auto pair = this->GetBelongingGVNode(in_tensor);
246     if (pair.first == nullptr) {
247       MS_LOG(ERROR) << "Can not find input tensor source: " << in_tensor->tensor_name();
248       return RET_ERROR;
249     }
250     auto link_ret = gv_graph_->Link(pair.first->name(), pair.second, gv_node.name(), i);
251     if (link_ret != RET_OK) {
252       MS_LOG(ERROR) << "Link " << i << "th input tensor of " << node.GetName() << " failed.";
253       return RET_ERROR;
254     }
255   }
256   return RET_OK;
257 }
258 }  // namespace mindspore::lite
259