• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/extendrt/graph_compiler/compile_result.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ops/base_operator.h"
#include "utils/hash_map.h"
#include "include/api/status.h"
#include "ir/primitive.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "src/common/file_utils.h"
31 
32 namespace mindspore {
33 namespace lite {
34 namespace {
// One level of indentation used by the dump helpers in this file.
constexpr char tab[] = "  ";

// Returns `indent` copies of `tab` concatenated together. A zero or negative `indent` yields an empty string.
inline std::string GenIndent(int indent) {
  std::string padding;
  int remaining = indent;
  while (remaining > 0) {
    padding += tab;
    --remaining;
  }
  return padding;
}
44 
// Formats `shape` as a bracketed, comma-separated list, e.g. "[1, 2, 3]". An empty shape is rendered as "[]".
inline std::string DumpIntShape(const std::vector<int> &shape) {
  std::ostringstream text;
  text << "[";
  bool is_first = true;
  for (const int dim : shape) {
    // The separator goes in front of every element except the first one.
    if (!is_first) {
      text << ", ";
    }
    text << dim;
    is_first = false;
  }
  text << "]";
  return text.str();
}
57 
DumpTensor(const InferTensor * tensor,int indent=0)58 inline std::string DumpTensor(const InferTensor *tensor, int indent = 0) {
59   std::ostringstream oss;
60   oss << GenIndent(indent) << "Tensor <name:" << tensor->tensor_name() << ", dtype:" << tensor->data_type()
61       << ", format:" << tensor->format() << ", shape:" << DumpIntShape(tensor->shape()) << ">";
62   return oss.str();
63 }
64 }  // namespace
65 
GetKernelAttr() const66 kernel::KernelAttr CompileNode::GetKernelAttr() const {
67   kernel::KernelAttr attr;
68   for (auto &input : inputs_) {
69     (void)attr.AddInputAttr(input->data_type(), FormatEnumToString(input->format()));
70   }
71   for (auto &output : outputs_) {
72     (void)attr.AddOutputAttr(output->data_type(), FormatEnumToString(output->format()));
73   }
74   return attr;
75 }
76 
Create(CNodePtr cnode)77 CompileNodePtr CompileNode::Create(CNodePtr cnode) {
78   if (cnode == nullptr) {
79     return nullptr;
80   }
81   auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
82   if (primitive == nullptr) {
83     MS_LOG(ERROR) << "Node has no primitive, first input of cnode(" << cnode->fullname_with_scope()
84                   << ") is : " << cnode->input(0);
85     return nullptr;
86   }
87   auto node = std::make_shared<CompileNode>(cnode->fullname_with_scope(), kernel::PrimitiveType(primitive->name()));
88   ops::PrimitiveCPtr primc{nullptr};
89   if (utils::isa<ops::PrimitiveCPtr>(primitive)) {
90     primc = utils::cast<ops::PrimitiveCPtr>(primitive);
91   } else {
92     static auto ops_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
93     auto primc_creator_iter = ops_primc_fns.find(node->type_.TypeName());
94     if (primc_creator_iter == ops_primc_fns.end()) {
95       MS_LOG(ERROR) << "Can not find primitive_c create function for: " << node->type_;
96       return nullptr;
97     }
98     primc = primc_creator_iter->second();
99     if (primc == nullptr) {
100       MS_LOG(ERROR) << "Create primitive_c failed, type: " << node->type_;
101       return nullptr;
102     }
103     (void)primc->SetAttrs(primitive->attrs());
104   }
105   static auto baseops_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
106   auto baseops_creator_iter = baseops_fns.find(node->type_.TypeName());
107   if (baseops_creator_iter == baseops_fns.end()) {
108     MS_LOG(ERROR) << "Can not find base-operator create function for: " << node->type_;
109     return nullptr;
110   }
111   auto baseops_creator = baseops_creator_iter->second;
112   node->base_operator_ = baseops_creator(primc);
113   if (node->base_operator_ == nullptr) {
114     MS_LOG(ERROR) << "Create base-operator failed, type: " << node->type_;
115     return nullptr;
116   }
117   node->cnode_ = std::move(cnode);
118   return node;
119 }
120 
AppendInputTensor(InferTensor * tensor)121 void CompileNode::AppendInputTensor(InferTensor *tensor) { this->inputs_.emplace_back(tensor); }
122 
AppendOutputTensor(InferTensor * tensor)123 void CompileNode::AppendOutputTensor(InferTensor *tensor) { this->outputs_.emplace_back(tensor); }
124 
Dump(int indent) const125 std::string CompileNode::Dump(int indent) const {
126   constexpr int kNumberTwo = 2;
127   std::ostringstream oss;
128   oss << GenIndent(indent) << "CompileNode <name:" << name_ << ", type:" << type_ << "> {" << std::endl;
129   oss << GenIndent(indent + 1) << "inputs: [" << std::endl;
130   for (auto &input : inputs_) {
131     oss << DumpTensor(input, indent + kNumberTwo) << std::endl;
132   }
133   oss << GenIndent(indent + 1) << "]" << std::endl;
134   oss << GenIndent(indent + 1) << "outputs: [" << std::endl;
135   for (auto &output : outputs_) {
136     oss << DumpTensor(output, indent + kNumberTwo) << std::endl;
137   }
138   oss << GenIndent(indent + 1) << "]" << std::endl;
139   oss << GenIndent(indent) << "}";
140   return oss.str();
141 }
142 
ReplaceInputTensor(InferTensor * dst,const InferTensor * src)143 void CompileNode::ReplaceInputTensor(InferTensor *dst, const InferTensor *src) {
144   std::replace_if(
145     inputs_.begin(), inputs_.end(), [&src](InferTensor *ele) { return ele == src; }, dst);
146 }
147 
GetNode(const std::string & name)148 CompileNodePtr CompileResult::GetNode(const std::string &name) {
149   auto iter = node_map_.find(name);
150   if (iter == node_map_.end()) {
151     return nullptr;
152   } else {
153     return iter->second;
154   }
155 }
156 
GetArgNode(const std::string & name)157 CompileNodePtr CompileResult::GetArgNode(const std::string &name) {
158   auto iter = arg_node_map_.find(name);
159   if (iter == arg_node_map_.end()) {
160     return nullptr;
161   } else {
162     return iter->second;
163   }
164 }
165 
GetMutableNodes()166 std::vector<CompileNodePtr> &CompileResult::GetMutableNodes() {
167   if (assembled_) {
168     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
169   }
170   return nodes_;
171 }
GetMutableInputs()172 std::vector<InferTensor *> &CompileResult::GetMutableInputs() {
173   if (assembled_) {
174     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
175   }
176   return inputs_;
177 }
178 
GetMutableOutputs()179 std::vector<InferTensor *> &CompileResult::GetMutableOutputs() {
180   if (assembled_) {
181     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
182   }
183   return outputs_;
184 }
185 
AppendNode(CompileNodePtr node)186 StatusCode CompileResult::AppendNode(CompileNodePtr node) {
187   if (assembled_) {
188     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
189   }
190   if (node == nullptr) {
191     MS_LOG(ERROR) << "Input node is nullptr";
192     return kLiteInputParamInvalid;
193   }
194   const std::string &node_name = node->GetName();
195   auto iter = node_map_.find(node_name);
196   if (iter != node_map_.end()) {
197     MS_LOG(ERROR) << "Duplicated node name : " << node_name;
198     return kLiteError;
199   }
200   node_map_[node_name] = node;
201   nodes_.emplace_back(node);
202   return kSuccess;
203 }
204 
AppendArgNode(CompileNodePtr node)205 StatusCode CompileResult::AppendArgNode(CompileNodePtr node) {
206   if (assembled_) {
207     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
208   }
209   if (node == nullptr) {
210     MS_LOG(ERROR) << "Input node is nullptr";
211     return kLiteInputParamInvalid;
212   }
213   const std::string &node_name = node->GetName();
214   auto iter = arg_node_map_.find(node_name);
215   if (iter != arg_node_map_.end()) {
216     MS_LOG(ERROR) << "Duplicated node name : " << node_name;
217     return kLiteError;
218   }
219   arg_node_map_[node_name] = node;
220   arg_nodes_.emplace_back(node);
221   return kSuccess;
222 }
223 
AppendTensor(InferTensor * tensor)224 StatusCode CompileResult::AppendTensor(InferTensor *tensor) {
225   if (assembled_) {
226     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
227   }
228   if (tensor == nullptr) {
229     MS_LOG(ERROR) << "Input tensor is nullptr";
230     return kLiteInputParamInvalid;
231   }
232   tensors_.emplace_back(tensor);
233   return kSuccess;
234 }
235 
AppendInputTensor(InferTensor * tensor,bool is_borrow)236 StatusCode CompileResult::AppendInputTensor(InferTensor *tensor, bool is_borrow) {
237   if (assembled_) {
238     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
239   }
240   if (tensor == nullptr) {
241     MS_LOG(ERROR) << "Input tensor is nullptr";
242     return kLiteInputParamInvalid;
243   }
244   inputs_.emplace_back(tensor);
245   if (!is_borrow) {
246     return AppendTensor(tensor);
247   }
248   return kSuccess;
249 }
250 
AppendOutputTensor(InferTensor * tensor,bool is_borrow)251 StatusCode CompileResult::AppendOutputTensor(InferTensor *tensor, bool is_borrow) {
252   if (assembled_) {
253     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
254   }
255   if (tensor == nullptr) {
256     MS_LOG(ERROR) << "Input tensor is nullptr";
257     return kLiteInputParamInvalid;
258   }
259   if (tensor->tensor_name().empty()) {
260     tensor->set_tensor_name("graph_out_" + std::to_string(this->outputs_.size()));
261   }
262   if (!is_borrow) {
263     return AppendTensor(tensor);
264   }
265   outputs_.emplace_back(tensor);
266   return kSuccess;
267 }
268 
AppendNodeInputTensor(const CompileNodePtr & compile_node,InferTensor * tensor,bool is_borrow)269 StatusCode CompileResult::AppendNodeInputTensor(const CompileNodePtr &compile_node, InferTensor *tensor,
270                                                 bool is_borrow) {
271   if (compile_node == nullptr) {
272     MS_LOG(ERROR) << "Input compile_node is nullptr";
273     return kLiteInputParamInvalid;
274   }
275   return AppendNodeInputTensor(compile_node->GetName(), tensor, is_borrow);
276 }
277 
AppendNodeInputTensor(const std::string & node_name,InferTensor * input_tensor,bool is_borrow)278 StatusCode CompileResult::AppendNodeInputTensor(const std::string &node_name, InferTensor *input_tensor,
279                                                 bool is_borrow) {
280   if (assembled_) {
281     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
282   }
283   if (input_tensor == nullptr) {
284     MS_LOG(ERROR) << "`input_tensor` is nullptr";
285     return kLiteInputParamInvalid;
286   }
287 
288   auto iter = node_map_.find(node_name);
289   if (iter == node_map_.end()) {
290     MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name;
291     return kLiteError;
292   }
293   iter->second->AppendInputTensor(input_tensor);
294   if (!is_borrow) {
295     return AppendTensor(input_tensor);
296   }
297   return kSuccess;
298 }
299 
AppendNodeOutputTensor(const CompileNodePtr & compile_node,InferTensor * tensor,bool is_borrow)300 StatusCode CompileResult::AppendNodeOutputTensor(const CompileNodePtr &compile_node, InferTensor *tensor,
301                                                  bool is_borrow) {
302   if (compile_node == nullptr) {
303     MS_LOG(ERROR) << "Input compile_node is nullptr";
304     return kLiteInputParamInvalid;
305   }
306   return AppendNodeOutputTensor(compile_node->GetName(), tensor, is_borrow);
307 }
308 
AppendNodeOutputTensor(const std::string & node_name,InferTensor * output_tensor,bool is_borrow)309 StatusCode CompileResult::AppendNodeOutputTensor(const std::string &node_name, InferTensor *output_tensor,
310                                                  bool is_borrow) {
311   if (assembled_) {
312     MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
313   }
314   if (output_tensor == nullptr) {
315     MS_LOG(ERROR) << "`output_tensor` is nullptr";
316     return kLiteInputParamInvalid;
317   }
318 
319   auto iter = node_map_.find(node_name);
320   if (iter == node_map_.end()) {
321     MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name;
322     return kLiteError;
323   }
324   iter->second->AppendOutputTensor(output_tensor);
325   if (!is_borrow) {
326     return AppendTensor(output_tensor);
327   }
328   return kSuccess;
329 }
330 
Dump(int indent) const331 std::string CompileResult::Dump(int indent) const {
332   constexpr int kNumTwo = 2;
333   std::ostringstream oss;
334   oss << GenIndent(indent) << "CompileResult {" << std::endl;
335   oss << GenIndent(indent + 1) << "nodes: [" << std::endl;
336   for (auto &node : nodes_) {
337     oss << node->Dump(indent + kNumTwo) << std::endl;
338   }
339   oss << GenIndent(indent + 1) << "]" << std::endl;
340   oss << GenIndent(indent + 1) << "inputs: [" << std::endl;
341   for (auto &input : inputs_) {
342     oss << DumpTensor(input, indent + kNumTwo) << std::endl;
343   }
344   oss << GenIndent(indent + 1) << "]" << std::endl;
345   oss << GenIndent(indent + 1) << "outputs: [" << std::endl;
346   for (auto &output : outputs_) {
347     oss << DumpTensor(output, indent + kNumTwo) << std::endl;
348   }
349   oss << GenIndent(indent + 1) << "]" << std::endl;
350   oss << GenIndent(indent + 1) << "tensors: [" << std::endl;
351   for (auto &tensor : tensors_) {
352     oss << DumpTensor(tensor, indent + kNumTwo) << std::endl;
353   }
354   oss << GenIndent(indent + 1) << "]" << std::endl;
355   oss << GenIndent(indent) << "}" << std::endl;
356   return oss.str();
357 }
358 }  // namespace lite
359 }  // namespace mindspore
360