• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/graph_compiler/compile_result_builder.h"
18 #include <algorithm>
19 #include "mindspore/core/ops/structure_ops.h"
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "src/extendrt/graph_compiler/anfnode_tensor_adapter.h"
23 #include "ir/anf.h"
24 #include "ir/func_graph.h"
25 #include "ir/primitive.h"
26 #include "ops/op_name.h"
27 #include "ops/primitive_c.h"
28 #include "src/extendrt/utils/func_graph_utils.h"
29 
30 using AbstractBasePtr = mindspore::abstract::AbstractBasePtr;
31 using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr;
32 using AbstractSequencePtr = mindspore::abstract::AbstractSequencePtr;
33 
34 namespace mindspore {
35 namespace lite {
BuildInputs(const AnfNodePtrList & inputs)36 StatusCode CompileResultBuilder::BuildInputs(const AnfNodePtrList &inputs) {
37   MS_ASSERT(graph_ != nullptr);
38   if (graph_->InputSize() > 0) {
39     MS_LOG(ERROR) << "Please don't call BuildInputs twice.";
40     return kLiteError;
41   }
42   for (auto &input : inputs) {
43     auto results = TensorAdapter::CreateTensorsFromAbstract(input->abstract(), compile_option_->graph_input_format);
44     if (results.empty()) {
45       MS_LOG(ERROR) << "Create tensors from abstract of segments input failed, input : "
46                     << input->fullname_with_scope();
47       return kLiteError;
48     }
49     auto arg_node = std::make_shared<CompileNode>(input->fullname_with_scope(), kernel::PrimitiveType());
50     auto ret = graph_->AppendArgNode(arg_node);
51     if (ret != kSuccess) {
52       MS_LOG(ERROR) << "Append input lite-node to graph failed, input : " << input->fullname_with_scope();
53       return ret;
54     }
55     for (auto &result : results) {
56       auto tensor = result.release();
57       arg_node->AppendOutputTensor(tensor);
58       ret = graph_->AppendInputTensor(tensor);
59       if (ret != kSuccess) {
60         MS_LOG(ERROR) << "Append output tensor to argument node failed, node: " << input->fullname_with_scope();
61         delete (tensor);
62         return ret;
63       }
64     }
65   }
66   return kSuccess;
67 }
68 
BuildNodes(const std::vector<AnfNodePtr> & nodes)69 StatusCode CompileResultBuilder::BuildNodes(const std::vector<AnfNodePtr> &nodes) {
70   MS_ASSERT(graph_ != nullptr);
71   if (graph_->NodeSize() > 0) {
72     MS_LOG(ERROR) << "Please don't call BuildNodes twice.";
73     return kLiteError;
74   }
75 
76   for (auto &node : nodes) {
77     if (!utils::isa<CNodePtr>(node)) {
78       continue;
79     }
80     auto ret = CreateAndAppendNode(utils::cast<CNodePtr>(node));
81     if (ret != kSuccess) {
82       MS_LOG(ERROR) << "Create compile node from cnode failed : " << node;
83       return ret;
84     }
85   }
86   return kSuccess;
87 }
88 
BuildNodes(const GraphSegmentPtr & graph_segment)89 StatusCode CompileResultBuilder::BuildNodes(const GraphSegmentPtr &graph_segment) {
90   return BuildNodes(graph_segment->nodes_);
91 }
92 
BuildOutputs(const AnfNodePtrList & outputs)93 StatusCode CompileResultBuilder::BuildOutputs(const AnfNodePtrList &outputs) {
94   MS_ASSERT(graph_ != nullptr);
95   if (graph_->OutputSize() > 0) {
96     MS_LOG(ERROR) << "Please don't call BuildOutputs twice.";
97     return kLiteError;
98   }
99   for (auto &output : outputs) {
100     auto out_cnode = utils::cast<CNodePtr>(output);
101     if (out_cnode == nullptr) {
102       MS_LOG(ERROR) << "Outputs should be a CNode vector, but got " << output->Type() << " type element.";
103       return kLiteError;
104     }
105     auto compile_node = graph_->GetNode(out_cnode->fullname_with_scope());
106     if (compile_node == nullptr) {
107       continue;
108     }
109     for (auto &tensor : compile_node->GetOutputs()) {
110       auto ret = graph_->AppendOutputTensor(tensor, true);
111       if (ret != kSuccess) {
112         MS_LOG(ERROR) << "Append output tensor to graph failed, output: " << out_cnode->fullname_with_scope();
113         return ret;
114       }
115     }
116   }
117   return kSuccess;
118 }
119 
// Redirect every use of `src_tensor` to `dst_tensor`: all node inputs and all
// graph outputs that referenced `src_tensor` reference `dst_tensor` afterwards.
void CompileResultBuilder::ReplaceTensor(InferTensor *dst_tensor, const InferTensor *src_tensor) {
  // used as inputs of other node
  auto &nodes = graph_->GetMutableNodes();
  for (auto &compile_node : nodes) {
    if (compile_node == nullptr) {
      continue;
    }
    // NOTE(review): assumed ReplaceInputTensor(dst, src) rewrites inputs equal
    // to `src` to `dst`, mirroring the std::replace_if below — confirm in
    // CompileNode.
    compile_node->ReplaceInputTensor(dst_tensor, src_tensor);
  }
  // used as outputs of graph
  auto &outputs = graph_->GetMutableOutputs();
  std::replace_if(
    outputs.begin(), outputs.end(), [&src_tensor](InferTensor *ele) { return ele == src_tensor; }, dst_tensor);
}
135 
// Remove MakeTuple/MakeList nodes from the graph: such a node only packs its
// inputs, so every consumer of output i is rewired to input i and the node
// itself is erased from the node list.
StatusCode CompileResultBuilder::RemoveMakeSeqNode() {
  auto &nodes = graph_->GetMutableNodes();
  for (auto iter = nodes.begin(); iter != nodes.end();) {
    auto &node = *iter;
    if (node->GetType() != kMakeTupleOpName && node->GetType() != kMakeListOpName) {
      iter++;
      continue;
    }
    MS_LOG(INFO) << "Handling make sequence node: " << node->GetName();
    auto tensor_number = node->InputSize();
    // Inputs and outputs must pair up one-to-one for the rewiring below.
    if (tensor_number != node->OutputSize()) {
      MS_LOG(ERROR) << "MakeSequence node should has same number of inputs and outputs, but got " << tensor_number
                    << " inputs and " << node->OutputSize() << " outputs.";
      return kLiteError;
    }
    for (size_t i = 0; i < tensor_number; i++) {
      // Redirect consumers of output i to input i, bypassing this node.
      ReplaceTensor(node->GetInput(i), node->GetOutput(i));
    }
    // erase() yields the next iterator, so no increment on this path.
    iter = nodes.erase(iter);
  }
  return kSuccess;
}
158 
RemoveDependNode()159 StatusCode CompileResultBuilder::RemoveDependNode() {
160   auto &nodes = graph_->GetMutableNodes();
161   for (auto iter = nodes.begin(); iter != nodes.end();) {
162     auto &node = *iter;
163     if (node->GetType() != kDependOpName) {
164       iter++;
165       continue;
166     }
167     MS_LOG(INFO) << "Handling Depend node: " << node->GetName();
168     constexpr int kSize2 = 2;
169     if (node->InputSize() != kSize2) {
170       MS_LOG(ERROR) << "Depend node should has 2 inputs, but got " << node->InputSize();
171       return kLiteError;
172     }
173     if (node->OutputSize() != 1) {
174       MS_LOG(ERROR) << "Depend node should has 1 outputs, but got " << node->OutputSize();
175       return kLiteError;
176     }
177     ReplaceTensor(node->GetInput(0), node->GetOutput(0));
178     iter = nodes.erase(iter);
179   }
180   return kSuccess;
181 }
182 
RemoveSeqGetItemNode()183 StatusCode CompileResultBuilder::RemoveSeqGetItemNode() {
184   auto &nodes = graph_->GetMutableNodes();
185   for (auto iter = nodes.begin(); iter != nodes.end();) {
186     auto &node = *iter;
187     if (node->GetType() != kTupleGetItemOpName && node->GetType() != kListGetItemOpName &&
188         node->GetType() != "array_getitem" && node->GetType() != kSliceGetItemOpName) {
189       iter++;
190       continue;
191     }
192     MS_LOG(DEBUG) << "Handling GetItem node: " << node->GetName();
193     if (node->OutputSize() != 1) {
194       MS_LOG(ERROR) << "GetItem node should has 1 outputs, but got " << node->OutputSize();
195       return kLiteError;
196     }
197     auto index_tensor = node->GetInput(node->GetInputs().size() - 1);
198     if (index_tensor->data() == nullptr) {
199       MS_LOG(ERROR) << "`index_tensor` of GetItem should be a const tensor, but has no data.";
200       return kLiteError;
201     }
202     if (index_tensor->data_type() == kNumberTypeInt32) {
203       auto idx = reinterpret_cast<int32_t *>(index_tensor->data())[0];
204       ReplaceTensor(node->GetInput(idx), node->GetOutput(0));
205     } else if (index_tensor->data_type() == kNumberTypeInt64) {
206       auto idx = reinterpret_cast<int64_t *>(index_tensor->data())[0];
207       ReplaceTensor(node->GetInput(idx), node->GetOutput(0));
208     } else {
209       MS_LOG(ERROR) << "`index_tensor` of GetItem should be a const tensor with int data type, but got "
210                     << index_tensor->data_type();
211       return kLiteError;
212     }
213     iter = nodes.erase(iter);
214   }
215   return kSuccess;
216 }
217 
OptimizeGraph()218 StatusCode CompileResultBuilder::OptimizeGraph() {
219   MS_ASSERT(graph_ != nullptr);
220   auto ret = RemoveDependNode();
221   if (ret != kSuccess) {
222     MS_LOG(ERROR) << "Handle Depend node failed";
223     return ret;
224   }
225   ret = RemoveMakeSeqNode();
226   if (ret != kSuccess) {
227     MS_LOG(ERROR) << "Handle Make Sequence node failed";
228     return ret;
229   }
230   ret = RemoveSeqGetItemNode();
231   if (ret != kSuccess) {
232     MS_LOG(ERROR) << "Handle Sequence-Getitem node failed";
233     return ret;
234   }
235   return kSuccess;
236 }
237 
Build(const GraphSegmentPtr & graph_segment,const AnfNodePtrList & inputs,const AnfNodePtrList & outputs)238 CompileResultPtr CompileResultBuilder::Build(const GraphSegmentPtr &graph_segment, const AnfNodePtrList &inputs,
239                                              const AnfNodePtrList &outputs) {
240   graph_ = std::make_shared<CompileResult>();
241   if (BuildInputs(inputs) != kSuccess) {
242     MS_LOG(ERROR) << "Build graph inputs failed";
243     return nullptr;
244   }
245   if (BuildNodes(graph_segment) != kSuccess) {
246     MS_LOG(ERROR) << "Build graph nodes failed";
247     return nullptr;
248   }
249   if (BuildOutputs(outputs) != kSuccess) {
250     MS_LOG(ERROR) << "Build graph outputs failed";
251     return nullptr;
252   }
253   if (OptimizeGraph() != kSuccess) {
254     MS_LOG(ERROR) << "Optimize graph failed";
255     return nullptr;
256   }
257   graph_->Assemble();
258   return graph_;
259 }
260 
AppendInputCNodeToInputs(const CNodePtr & cnode,const CompileNodePtr & compile_node)261 StatusCode CompileResultBuilder::AppendInputCNodeToInputs(const CNodePtr &cnode, const CompileNodePtr &compile_node) {
262   if (cnode == nullptr) {
263     MS_LOG(ERROR) << "Input cnode is nullptr.";
264     return kLiteInputParamInvalid;
265   }
266   if (compile_node == nullptr) {
267     MS_LOG(ERROR) << "Input compile_node is nullptr.";
268     return kLiteInputParamInvalid;
269   }
270   auto input_node = graph_->GetNode(cnode->fullname_with_scope());
271   if (input_node == nullptr) {
272     input_node = graph_->GetArgNode(cnode->fullname_with_scope());
273   }
274   if (input_node == nullptr) {
275     MS_LOG(ERROR) << "Can not find input lite-node in graph, node: " << cnode->fullname_with_scope();
276     return kLiteError;
277   }
278   for (auto &input_node_output : input_node->GetOutputs()) {
279     auto ret = graph_->AppendNodeInputTensor(compile_node, input_node_output, true);
280     if (ret != kSuccess) {
281       MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName();
282       return ret;
283     }
284   }
285   return kSuccess;
286 }
287 
AppendInputParameterToInputs(const ParameterPtr & param_node,const CompileNodePtr & compile_node)288 StatusCode CompileResultBuilder::AppendInputParameterToInputs(const ParameterPtr &param_node,
289                                                               const CompileNodePtr &compile_node) {
290   if (param_node == nullptr) {
291     MS_LOG(ERROR) << "Input param_node is nullptr.";
292     return kLiteInputParamInvalid;
293   }
294   if (compile_node == nullptr) {
295     MS_LOG(ERROR) << "Input compile_node is nullptr.";
296     return kLiteInputParamInvalid;
297   }
298   auto arg_node = graph_->GetArgNode(param_node->fullname_with_scope());
299   if (arg_node != nullptr) {
300     for (auto &output : arg_node->GetOutputs()) {
301       auto ret = graph_->AppendNodeInputTensor(compile_node, output, true);
302       if (ret != kSuccess) {
303         MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName();
304         return ret;
305       }
306     }
307     return kSuccess;
308   }
309   auto tensor_from_param = TensorAdapter::Convert2Tensor(param_node);
310   if (tensor_from_param == nullptr) {
311     MS_LOG(ERROR) << "Create tensor from Parameter failed.";
312     return kLiteError;
313   }
314   auto format_value = compile_node->GetBaseOperator()->GetAttr(mindspore::ops::kFormat);
315   if (format_value != nullptr) {
316     tensor_from_param->set_format(static_cast<Format>(GetValue<int64_t>(format_value)));
317   } else {
318     tensor_from_param->set_format(compile_option_->graph_format);
319   }
320   auto ret = graph_->AppendNodeInputTensor(compile_node, tensor_from_param);
321   if (ret != kSuccess) {
322     MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName();
323     delete tensor_from_param;
324     return ret;
325   }
326   return kSuccess;
327 }
328 
AppendInputValueNodeToInputs(const ValueNodePtr & value_node,const CompileNodePtr & compile_node)329 StatusCode CompileResultBuilder::AppendInputValueNodeToInputs(const ValueNodePtr &value_node,
330                                                               const CompileNodePtr &compile_node) {
331   if (value_node == nullptr) {
332     MS_LOG(ERROR) << "Input value_node is nullptr.";
333     return kLiteInputParamInvalid;
334   }
335   if (compile_node == nullptr) {
336     MS_LOG(ERROR) << "Input compile_node is nullptr.";
337     return kLiteInputParamInvalid;
338   }
339   if (value_node->value() != nullptr && value_node->value()->isa<Monad>()) {
340     MS_LOG(WARNING) << "Skip Monad value node: " << value_node->fullname_with_scope();
341     return kSuccess;
342   }
343   auto tensor_from_value = TensorAdapter::Convert2Tensor(value_node);
344   if (tensor_from_value == nullptr) {
345     MS_LOG(ERROR) << "Create tensor from ValueNode failed.";
346     return kLiteError;
347   }
348   auto format_value = compile_node->GetBaseOperator()->GetAttr(mindspore::ops::kFormat);
349   if (format_value != nullptr) {
350     tensor_from_value->set_format(static_cast<Format>(GetValue<int64_t>(format_value)));
351   } else {
352     tensor_from_value->set_format(compile_option_->graph_format);
353   }
354   auto ret = graph_->AppendNodeInputTensor(compile_node, tensor_from_value);
355   if (ret != kSuccess) {
356     MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName();
357     delete tensor_from_value;
358     return ret;
359   }
360   return kSuccess;
361 }
362 
CreateAndAppendNode(const CNodePtr & cnode)363 StatusCode CompileResultBuilder::CreateAndAppendNode(const CNodePtr &cnode) {
364   auto compile_node = CompileNode::Create(cnode);
365   if (compile_node == nullptr) {
366     MS_LOG(ERROR) << "Create compile node failed, cnode: " << cnode->fullname_with_scope();
367     return kLiteError;
368   }
369   auto ret = graph_->AppendNode(compile_node);
370   if (ret != kSuccess) {
371     MS_LOG(ERROR) << "Append compile_node to graph failed, node: " << compile_node->GetName();
372     return ret;
373   }
374   // inputs
375   for (size_t i = 1; i < cnode->size(); i++) {
376     auto &input = cnode->input(i);
377     if (utils::isa<CNodePtr>(input)) {
378       ret = this->AppendInputCNodeToInputs(utils::cast<CNodePtr>(input), compile_node);
379     } else if (utils::isa<Parameter>(input)) {
380       ret = this->AppendInputParameterToInputs(utils::cast<ParameterPtr>(input), compile_node);
381     } else if (utils::isa<ValueNode>(input)) {
382       ret = this->AppendInputValueNodeToInputs(utils::cast<ValueNodePtr>(input), compile_node);
383     } else {
384       MS_LOG(ERROR) << "Unsupported input node of cnode: " << input
385                     << ", current cnode: " << cnode->fullname_with_scope();
386       ret = kLiteNotSupport;
387     }
388     if (ret != kSuccess) {
389       MS_LOG(ERROR) << "Create input tensor for cnode failed, cnode: " << cnode->fullname_with_scope();
390       return ret;
391     }
392   }
393   // outputs
394   ret = BuildNodeOutputTensor(cnode, compile_node);
395   if (ret != kSuccess) {
396     MS_LOG(ERROR) << "Create output tensors of cnode failed, cnode: " << cnode;
397     return ret;
398   }
399   return kSuccess;
400 }
401 
BuildNodeOutputTensor(const CNodePtr & cnode,const CompileNodePtr & compile_node)402 StatusCode CompileResultBuilder::BuildNodeOutputTensor(const CNodePtr &cnode, const CompileNodePtr &compile_node) {
403   if (compile_node == nullptr) {
404     MS_LOG(ERROR) << "Input compile_node is nullptr.";
405     return kLiteInputParamInvalid;
406   }
407   if (compile_node->OutputSize() > 0) {
408     MS_LOG(ERROR) << "Build node output twice, node : " << compile_node->GetName();
409     return kLiteError;
410   }
411   auto results = TensorAdapter::Convert2Tensor(cnode);
412   if (results.empty()) {
413     MS_LOG(ERROR) << "Create tensors from cnode failed, cnode : " << cnode->fullname_with_scope();
414     return kLiteError;
415   }
416   size_t index = 0;
417   auto ret = kSuccess;
418   for (; index < results.size(); index++) {
419     auto tensor = results[index];
420     ret = graph_->AppendNodeOutputTensor(compile_node, tensor);
421     if (ret != kSuccess) {
422       MS_LOG(ERROR) << "Append output tensor to node failed, node: " << compile_node->GetName();
423       break;
424     }
425   }
426   // release results if failed
427   for (; index < results.size(); index++) {
428     delete results[index];
429   }
430   return ret;
431 }
432 
BuildNodes(const FuncGraphPtr & func_graph)433 StatusCode CompileResultBuilder::BuildNodes(const FuncGraphPtr &func_graph) {
434   MS_ASSERT(func_graph != nullptr);
435   auto nodes = func_graph->TopoSort(func_graph->get_return());
436   if (nodes.empty()) {
437     MS_LOG(ERROR) << "There are no nodes in the graph";
438     return kLiteError;
439   }
440 
441   return BuildNodes(nodes);
442 }
443 
Build(const FuncGraphPtr & func_graph)444 CompileResultPtr CompileResultBuilder::Build(const FuncGraphPtr &func_graph) {
445   graph_ = std::make_shared<CompileResult>();
446 
447   if (BuildInputs(func_graph->get_inputs()) != kSuccess) {
448     MS_LOG(ERROR) << "Build graph inputs failed";
449     return nullptr;
450   }
451   if (BuildNodes(func_graph) != kSuccess) {
452     MS_LOG(ERROR) << "Build graph nodes failed";
453     return nullptr;
454   }
455 
456   std::vector<AnfWithOutIndex> outputs_with_index;
457   FuncGraphUtils::GetFuncGraphOutputs(func_graph, &outputs_with_index);
458   AnfNodePtrList outputs;
459   outputs.resize(outputs_with_index.size());
460   for (auto &output : outputs_with_index) {
461     if (output.second >= outputs.size()) {
462       MS_LOG(ERROR) << "Build graph nodes failed";
463       return nullptr;
464     }
465     outputs[output.second] = output.first;
466   }
467   if (BuildOutputs(outputs) != kSuccess) {
468     MS_LOG(ERROR) << "Build graph outputs failed";
469     return nullptr;
470   }
471   if (OptimizeGraph() != kSuccess) {
472     MS_LOG(ERROR) << "Optimize graph failed";
473     return nullptr;
474   }
475   graph_->Assemble();
476   return graph_;
477 }
478 }  // namespace lite
479 }  // namespace mindspore
480