• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
18 #include <vector>
19 #include <deque>
20 #include <set>
21 #include "src/common/common.h"
22 #include "src/common/log_adapter.h"
23 #include "include/errorcode.h"
24 #include "src/tensor.h"
25 #include "src/tensorlist.h"
26 #include "src/common/prim_util.h"
27 #include "src/ops/populate/populate_register.h"
28 #include "src/runtime/infer_manager.h"
29 #include "tools/common/node_util.h"
30 #include "tools/converter/converter_flags.h"
31 #include "src/common/string_util.h"
32 #include "src/common/log_util.h"
33 #include "nnacl/op_base.h"
34 
35 using mindspore::converter::kFmkTypeTf;
36 namespace mindspore {
37 namespace lite {
38 namespace {
// Placeholder for a dynamic dimension: input dims of 0 are rewritten to this.
constexpr int DEFAULT_DIM_VALUE = -1;
// Initial capacity of the FlatBufferBuilder used for per-node primitive conversion.
constexpr size_t kInitialSize = 1024;
// The main graph is always subgraph 0 of the MetaGraphT.
constexpr int kMainGraphIndex = 0;
// Minimum input counts for call / switch nodes.
constexpr int kCallInputMinSize = 1;
constexpr int kSwitchInputMinSize = 3;
// Layout of a serialized tensorlist tensor's data viewed as an int32 array:
//   [kTypeIndex]           element data type
//   [kElementShapeIndex]   element shape rank n
//   [kFirstElementShapeIndex .. n+1]  element shape dims
//   [n + kFirstElementShapeIndex]     number of contained tensors
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kFirstElementShapeIndex = 2;
// Header size (in int32 slots) of the serialized tensorlist: type + rank + count.
constexpr int kTensorListDatasize = 3;
48 
FreeTensors(std::vector<Tensor * > * input_tensors,std::vector<Tensor * > * output_tensors)49 void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
50   if (input_tensors == nullptr) {
51     return;
52   }
53   for (auto &tensor : *input_tensors) {
54     if (tensor == nullptr) {
55       continue;
56     }
57     if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
58       tensor->set_data(nullptr);
59     }
60     delete tensor;
61     tensor = nullptr;
62   }
63   if (output_tensors == nullptr) {
64     return;
65   }
66   for (auto &tensor : *output_tensors) {
67     if (tensor == nullptr) {
68       continue;
69     }
70     if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
71       tensor->set_data(nullptr);
72     }
73     delete tensor;
74     tensor = nullptr;
75   }
76   input_tensors->resize(0);
77   output_tensors->resize(0);
78 }
79 
ConvertTensorList(MetaGraphT * graph,uint32_t index,bool * convert_succ,std::vector<Tensor * > * lite_tensors)80 void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
81   std::unique_ptr<Tensor> lite_tensor = nullptr;
82   auto &tensorT = graph->allTensors.at(index);
83   std::vector<int32_t> tensor_shape{};
84   TypeId type = kTypeUnknown;
85   std::vector<int> element_shape;
86   if (!tensorT->data.empty()) {
87     auto data_len = tensorT->data.size();
88     int *data = reinterpret_cast<int *>(tensorT->data.data());
89     type = TypeId(data[kTypeIndex]);
90     if (data_len < kTensorDataSize ||
91         (data[kElementShapeIndex] != 0 && static_cast<int>((data[kElementShapeIndex] + kTensorListDatasize) *
92                                                            sizeof(int)) != static_cast<int>(tensorT->data.size()))) {
93       MS_LOG(ERROR) << "tensorlist data length illegal, tensorT name: " << tensorT->name;
94       MS_LOG(ERROR) << "(data[1] + 3) * sizeof(int): "
95                     << ((data[kElementShapeIndex] + kTensorListDatasize) * sizeof(int));
96       MS_LOG(ERROR) << "static_cast<int>(tensorT->data.size()): " << static_cast<int>(tensorT->data.size());
97       *convert_succ = false;
98       return;
99     }
100     for (int j = 0; j < data[kElementShapeIndex]; ++j) {
101       element_shape.push_back(data[j + kFirstElementShapeIndex]);
102     }
103     if (INT_ADD_OVERFLOW(data[kElementShapeIndex], kFirstElementShapeIndex)) {
104       MS_LOG(ERROR) << "int add overflow";
105       *convert_succ = false;
106       return;
107     }
108     tensor_shape = {data[data[kElementShapeIndex] + kFirstElementShapeIndex]};
109   }
110   lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
111   if (lite_tensor == nullptr) {
112     MS_LOG(ERROR) << "lite tensorlist is nullptr";
113     *convert_succ = false;
114     return;
115   }
116 
117   auto lite_tensor_list = reinterpret_cast<TensorList *>(lite_tensor.get());
118   std::vector<Tensor *> tensors{};
119   if (!tensor_shape.empty() && tensor_shape.front() == -1) {
120     MS_LOG(INFO) << "tensor_shape is -1, tensor name: " << lite_tensor->tensor_name();
121   }
122   if (!tensor_shape.empty() && tensor_shape.front() != -1) {
123     for (int32_t i = 0; i < tensor_shape.front(); ++i) {
124       auto tensor = new (std::nothrow) Tensor(type, element_shape);
125       tensors.emplace_back(tensor);
126     }
127   }
128 
129   lite_tensor_list->set_tensors_data_type(type);
130   lite_tensor_list->set_element_shape(element_shape);
131   lite_tensor_list->set_tensors(tensors);
132   lite_tensors->emplace_back(lite_tensor.release());
133 }
134 
ConvertString(MetaGraphT * graph,uint32_t index,bool * convert_succ,std::vector<Tensor * > * lite_tensors)135 void ConvertString(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
136   std::unique_ptr<Tensor> lite_tensor = nullptr;
137   auto &tensorT = graph->allTensors.at(index);
138   auto tensor_shape = tensorT->dims;
139   lite_tensor = std::make_unique<Tensor>(
140     TypeId(tensorT->dataType), tensor_shape, static_cast<mindspore::Format>(tensorT->format),
141     TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
142   if (lite_tensor == nullptr) {
143     MS_LOG(ERROR) << "lite tensor is nullptr";
144     *convert_succ = false;
145     return;
146   }
147   auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
148   // when tensorT as param input
149   if (lite_tensor_size == 0) {
150     lite_tensors->emplace_back(lite_tensor.release());
151     return;
152   }
153   auto string_buffer = ParseStringBuffer(tensorT->data.data());
154   auto ret = WriteStringsToTensor(lite_tensor.get(), string_buffer);
155   if (ret != RET_OK) {
156     MS_LOG(ERROR) << "WriteStringsToTensor failed";
157     *convert_succ = false;
158     return;
159   }
160   lite_tensors->emplace_back(lite_tensor.release());
161 }
162 
ConvertOtherTensor(MetaGraphT * graph,uint32_t index,bool * convert_succ,std::vector<Tensor * > * lite_tensors)163 void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
164   std::unique_ptr<Tensor> lite_tensor = nullptr;
165   auto &tensorT = graph->allTensors.at(index);
166   auto tensor_shape = tensorT->dims;
167   lite_tensor = std::make_unique<Tensor>(
168     TypeId(tensorT->dataType), tensor_shape, static_cast<mindspore::Format>(tensorT->format),
169     TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
170   if (lite_tensor == nullptr) {
171     MS_LOG(ERROR) << "lite tensor is nullptr";
172     *convert_succ = false;
173     return;
174   }
175   auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
176   // when tensorT as param input
177   if (lite_tensor_size == 0) {
178     lite_tensors->emplace_back(lite_tensor.release());
179     return;
180   }
181   lite_tensor->set_data(tensorT->data.data());
182   lite_tensors->emplace_back(lite_tensor.release());
183 }
184 
ConvertTensorToLiteTensor(MetaGraphT * graph,const std::vector<uint32_t> & tensor_indexs)185 std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs) {
186   MS_ASSERT(graph != nullptr);
187   std::vector<Tensor *> lite_tensors;
188   bool convert_succ = true;
189   for (size_t i = 0; i < tensor_indexs.size(); i++) {
190     auto &tensorT = graph->allTensors.at(tensor_indexs[i]);
191     switch (tensorT->dataType) {
192       case kObjectTypeTensorType:
193         ConvertTensorList(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
194         break;
195       case kObjectTypeString:
196         ConvertString(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
197         break;
198       default:
199         ConvertOtherTensor(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
200         break;
201     }
202   }
203   if (!convert_succ) {
204     FreeTensors(&lite_tensors, {});
205     return {};
206   }
207   return lite_tensors;
208 }
209 
NodeInferShape(const std::unique_ptr<schema::CNodeT> & node,const std::vector<Tensor * > & inputs,std::vector<Tensor * > * outputs)210 STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
211                       std::vector<Tensor *> *outputs) {
212   flatbuffers::FlatBufferBuilder fbb(kInitialSize);
213   auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
214   if (prim == nullptr) {
215     MS_LOG(ERROR) << "get primitive failed.";
216     fbb.Clear();
217     return RET_ERROR;
218   }
219 
220   auto ret = KernelInferShape(inputs, *outputs, prim, {}, SCHEMA_CUR);
221   if (ret == lite::RET_NOT_SUPPORT) {
222     auto parameter_gen =
223       lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim->value_type()), SCHEMA_CUR);
224     if (parameter_gen == nullptr) {
225       fbb.Clear();
226       MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
227       return RET_ERROR;
228     }
229     auto parameter = parameter_gen(prim);
230     if (parameter == nullptr) {
231       fbb.Clear();
232       MS_LOG(ERROR) << "parameter is nullptr.";
233       return RET_ERROR;
234     }
235     parameter->quant_type_ = static_cast<int>(node->quantType);
236     ret = KernelInferShape(inputs, *outputs, parameter);
237     if (parameter->destroy_func_ != nullptr) {
238       parameter->destroy_func_(parameter);
239     }
240     free(parameter);
241     parameter = nullptr;
242   }
243 
244   fbb.Clear();
245   return ret;
246 }
247 
248 #ifdef Debug
PrintTensorShape(const std::vector<Tensor * > & input_tensors,const std::vector<Tensor * > & output_tensors)249 void PrintTensorShape(const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors) {
250   int i = 0;
251   for (auto input_tensor : input_tensors) {
252     std::ostringstream oss;
253     for (auto &dim : input_tensor->shape()) {
254       oss << " " << dim;
255     }
256     MS_LOG(DEBUG) << "input shape " << i++ << ":" << oss.str();
257   }
258   i = 0;
259   for (auto output_tensor : output_tensors) {
260     std::ostringstream oss;
261     for (auto &dim : output_tensor->shape()) {
262       oss << " " << dim;
263     }
264     MS_LOG(DEBUG) << "output shape" << i++ << ":" << oss.str();
265   }
266 }
267 #endif
268 
SetDataType(MetaGraphT * graph,const std::vector<Tensor * > & output_tensors,std::vector<InferTensor> * tensors,uint32_t i,uint32_t infer_node_index)269 int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors, std::vector<InferTensor> *tensors,
270                 uint32_t i, uint32_t infer_node_index) {
271   auto &node = graph->nodes.at(infer_node_index);
272   auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
273   output_tensor->format = static_cast<schema::Format>(output_tensors[i]->format());
274   output_tensor->dataType = output_tensors[i]->data_type();
275   if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
276     auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
277     int tensor_shape_dims = 0;
278     if (!tensor_list->tensors().empty()) {
279       tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
280     }
281     MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDatasize), static_cast<int>(sizeof(int))),
282                        RET_ERROR, "int mul overflow");
283     auto total_size = (tensor_shape_dims + kTensorListDatasize) * sizeof(int);
284     output_tensor->data.resize(total_size, 0);
285     auto output_tensor_data = reinterpret_cast<int *>(output_tensor->data.data());
286     if (tensor_list->tensors_data_type() == kTypeUnknown) {
287       if (!tensor_list->tensors().empty()) {
288         tensor_list->set_tensors_data_type(tensor_list->tensors().front()->data_type());
289       }
290     }
291     output_tensor_data[kTypeIndex] = tensor_list->tensors_data_type();
292     if (tensor_list->element_shape().empty() && !tensor_list->tensors().empty()) {
293       tensor_list->set_element_shape(tensor_list->tensors().front()->shape());
294     }
295     output_tensor_data[kElementShapeIndex] = static_cast<int>(tensor_list->element_shape().size());
296     for (size_t j = 0; j < tensor_list->element_shape().size(); ++j) {
297       output_tensor_data[j + kFirstElementShapeIndex] = tensor_list->element_shape().at(j);
298     }
299     output_tensor_data[kFirstElementShapeIndex + output_tensor_data[kElementShapeIndex]] =
300       static_cast<int>(tensor_list->tensors().size());
301   } else if (output_tensors[i]->data_type() == kTypeUnknown) {
302     tensors->at(node->outputIndex[i]).is_inferred_ = false;
303     return RET_OK;
304   }
305   tensors->at(node->outputIndex[i]).is_inferred_ = true;
306   return RET_OK;
307 }
308 
PartialGraphIndex(const CNodeT * partial_node)309 int PartialGraphIndex(const CNodeT *partial_node) {
310   return partial_node->primitive->value.AsPartialFusion()->sub_graph_index;
311 }
312 }  // namespace
313 
CopyPartialShapeToSubGraph(const CNodeT * partial_node,MetaGraphT * graph)314 int InferShapePass::CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph) {
315   auto subgraph_index = PartialGraphIndex(partial_node);
316   auto &subgraph = graph->subGraph.at(subgraph_index);
317 
318   if (subgraph->inputIndices.size() != partial_node->inputIndex.size()) {
319     MS_LOG(ERROR) << "partial node " << partial_node->name << " inputs size: " << partial_node->inputIndex.size()
320                   << " vs "
321                   << " subgraph " << subgraph_index << " input size: " << subgraph->inputIndices.size();
322     return RET_PARAM_INVALID;
323   }
324 
325   for (size_t i = 0; i < partial_node->inputIndex.size(); ++i) {
326     auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
327     auto &partial_input = graph->allTensors.at(partial_node->inputIndex[i]);
328     subgraph_input->dataType = partial_input->dataType;
329     subgraph_input->dims = partial_input->dims;
330     subgraph_input->format = partial_input->format;
331     subgraph_input->data.resize(partial_input->data.size(), 0);
332     if (partial_input->data.empty()) {
333       continue;
334     }
335     auto ret = memcpy_s(subgraph_input->data.data(), subgraph_input->data.size(), partial_input->data.data(),
336                         partial_input->data.size());
337     if (ret != EOK) {
338       MS_LOG(ERROR) << "memcpy failed, ret: " << ret;
339       return RET_ERROR;
340     }
341   }
342   return RET_OK;
343 }
344 
RestoreSubGraphInput(const CNodeT * partial_node,MetaGraphT * graph)345 int InferShapePass::RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph) {
346   auto subgraph_index = PartialGraphIndex(partial_node);
347   auto &subgraph = graph->subGraph.at(subgraph_index);
348   for (size_t i = 0; i < subgraph->inputIndices.size(); ++i) {
349     auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
350     if (subgraph_input->dataType != kObjectTypeTensorType) {
351       subgraph_input->data = {};
352     }
353   }
354   return RET_OK;
355 }
356 
InferPartialNode(const CNodeT * partial_node,MetaGraphT * graph)357 int InferShapePass::InferPartialNode(const CNodeT *partial_node, MetaGraphT *graph) {
358   int subgraph_index = PartialGraphIndex(partial_node);
359   int ret = CopyPartialShapeToSubGraph(partial_node, graph);
360   if (ret != RET_OK) {
361     MS_LOG(ERROR) << "CopyPartialShapeToSubGraph failed, ret: " << ret;
362     return ret;
363   }
364 
365   ret = InferSubgraph(subgraph_index, graph);
366   if (ret != RET_OK) {
367     // not return ret here to infer the following part of graph
368     MS_LOG(WARNING) << "InferSubgraph index: " << subgraph_index << " failed, ret: " << ret;
369   }
370 
371   ret = RestoreSubGraphInput(partial_node, graph);
372   if (ret != RET_OK) {
373     MS_LOG(ERROR) << "RestoreSubGraphInput failed, ret: " << ret;
374   }
375   return ret;
376 }
377 
InitInferTensor(MetaGraphT * graph)378 void InferShapePass::InitInferTensor(MetaGraphT *graph) {
379   tensors_.resize(graph->allTensors.size());
380   for (size_t i = 0; i < graph->nodes.size(); i++) {
381     auto &node = graph->nodes.at(i);
382     auto node_input_indexes = node->inputIndex;
383     //  init in_nodes index
384     for (unsigned int node_input_indexe : node_input_indexes) {
385       tensors_[node_input_indexe].next_nodes_.push_back(i);
386     }
387     auto node_output_indexes = node->outputIndex;
388     for (unsigned int node_output_indexe : node_output_indexes) {
389       tensors_[node_output_indexe].prev_nodes_.push_back(i);
390     }
391   }
392 
393   for (auto input_idx : graph->inputIndex) {
394     auto input_tensor = graph->allTensors[input_idx].get();
395     for (auto &dim : input_tensor->dims) {
396       if (dim == 0) {
397         MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
398         dim = DEFAULT_DIM_VALUE;
399       }
400     }
401   }
402 }
403 
InferSwitchNode(const std::unique_ptr<CNodeT> & switch_node,MetaGraphT * graph)404 int InferShapePass::InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph) {
405   if (switch_node->inputIndex.size() < kSwitchInputMinSize) {
406     MS_LOG(ERROR) << "switch node input size: " << switch_node->inputIndex.size() << " is less than three.";
407     return RET_PARAM_INVALID;
408   }
409 
410   static std::set<CNodeT *> partial_cnode_inferred{};
411   std::deque<CNodeT *> to_process{};
412   auto true_branch_output_index = switch_node->inputIndex.at(kSwitchTrueIndex);
413   auto false_branch_output_index = switch_node->inputIndex.at(kSwitchFalseIndex);
414   for (auto &node : graph->nodes) {
415     if (node->primitive->value.type != PrimitiveType_PartialFusion) {
416       continue;
417     }
418     if (IsContain(node->outputIndex, true_branch_output_index) &&
419         partial_cnode_inferred.find(node.get()) == partial_cnode_inferred.end()) {
420       to_process.push_back(node.get());
421       partial_cnode_inferred.insert(node.get());
422       break;
423     }
424   }
425   for (auto &node : graph->nodes) {
426     if (node->primitive->value.type != PrimitiveType_PartialFusion) {
427       continue;
428     }
429     if (IsContain(node->outputIndex, false_branch_output_index) &&
430         partial_cnode_inferred.find(node.get()) == partial_cnode_inferred.end()) {
431       to_process.push_back(node.get());
432       partial_cnode_inferred.insert(node.get());
433       break;
434     }
435   }
436 
437   while (!to_process.empty()) {
438     auto node = to_process.front();
439     to_process.pop_front();
440     int ret = InferPartialNode(node, graph);
441     if (ret != RET_OK) {
442       MS_LOG(WARNING) << "not support partial infer.";
443       return ret;
444     }
445   }
446 
447   return RET_OK;
448 }
449 
InferCallNode(const std::unique_ptr<CNodeT> & call_node,MetaGraphT * graph)450 int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph) {
451   if (call_node->inputIndex.size() < kCallInputMinSize) {
452     MS_LOG(ERROR) << "call node input size: " << call_node->inputIndex.size() << " is less than one.";
453     return RET_PARAM_INVALID;
454   }
455   auto call_first_input_index = call_node->inputIndex.front();
456   bool find_partial = false;
457   bool find_switch = false;
458   for (auto &node : graph->nodes) {
459     if (IsContain(node->outputIndex, call_first_input_index) &&
460         node->primitive->value.type == PrimitiveType_PartialFusion) {
461       find_partial = true;
462       int ret = InferPartialNode(node.get(), graph);
463       if (ret != RET_OK) {
464         MS_LOG(WARNING) << "not support partial infer.";
465         return ret;
466       }
467       break;
468     }
469     if (IsContain(node->outputIndex, call_first_input_index) && node->primitive->value.type == PrimitiveType_Switch) {
470       find_switch = true;
471       int ret = InferSwitchNode(node, graph);
472       if (ret != RET_OK) {
473         MS_LOG(WARNING) << "not support partial infer.";
474         return ret;
475       }
476       break;
477     }
478   }
479   if (!find_partial && !find_switch) {
480     MS_LOG(ERROR) << "not able to call partial or call switch.";
481     return RET_ERROR;
482   }
483   return RET_OK;
484 }
485 
InferSubgraph(const int & subgraph_index,MetaGraphT * graph)486 int InferShapePass::InferSubgraph(const int &subgraph_index, MetaGraphT *graph) {
487   std::vector<uint32_t> infer_node_indexes{};
488   int ret = InitSearchTensor(subgraph_index, graph, &infer_node_indexes);
489   if (ret != RET_OK) {
490     MS_LOG(ERROR) << "InitSearchTensor failed.";
491     return ret;
492   }
493   if (infer_node_indexes.empty()) {
494     MS_LOG(DEBUG) << "no need to infer.";
495     return RET_OK;
496   }
497 
498   while (!infer_node_indexes.empty()) {
499     auto infer_node_index = infer_node_indexes.front();
500     auto &node = graph->nodes.at(infer_node_index);
501     auto node_type = node->primitive->value.type;
502     if (node_type == PrimitiveType_Call) {
503       ret = InferCallNode(node, graph);
504       if (ret != RET_OK) {
505         MS_LOG(ERROR) << "infer call node failed.";
506         return ret;
507       }
508     }
509 
510     infer_node_indexes.erase(infer_node_indexes.begin());
511     auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex);
512     auto output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex);
513     if (output_tensors.empty() || output_tensors.size() != node->outputIndex.size() || input_tensors.empty() ||
514         input_tensors.size() != node->inputIndex.size()) {
515       MS_LOG(ERROR) << "convert lite tensor error";
516       FreeTensors(&input_tensors, &output_tensors);
517       return RET_INFER_ERR;
518     }
519     auto status = NodeInferShape(node, input_tensors, &output_tensors);
520     MS_LOG(DEBUG) << "cur node:" << node->name;
521     if (status == RET_OK || status == RET_INFER_INVALID) {
522 #ifdef Debug
523       PrintTensorShape(input_tensors, output_tensors);
524 #endif
525       // copy output shape to tensorT
526       for (size_t i = 0; i < output_tensors.size(); i++) {
527         auto output_dims = output_tensors[i]->shape();
528         auto &output_tensorT = graph->allTensors.at(node->outputIndex[i]);
529         output_tensorT->dims.swap(output_dims);
530         SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
531       }
532     } else {
533       MS_LOG(WARNING) << "InferShape failed, name: " << node->name
534                       << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
535       FreeTensors(&input_tensors, &output_tensors);
536       return RET_INFER_ERR;
537     }
538     FreeTensors(&input_tensors, &output_tensors);
539     AddOutputNodes(graph, &infer_node_indexes, infer_node_index);
540   }
541   return RET_OK;
542 }
543 
Run(MetaGraphT * graph)544 STATUS InferShapePass::Run(MetaGraphT *graph) {
545   CHECK_NULL_RETURN(graph);
546   InitInferTensor(graph);
547 
548   int ret = InferSubgraph(kMainGraphIndex, graph);
549   if (ret != RET_OK) {
550     MS_LOG(ERROR) << "InferSubgraph index: " << kMainGraphIndex << " failed, ret: " << ret;
551     return ret;
552   }
553 
554   ResetIncorrectTensorShape(graph);
555   return RET_OK;
556 }
557 
InitSearchTensor(const int & subgraph_index,MetaGraphT * graph,std::vector<uint32_t> * infer_node_indexes)558 int InferShapePass::InitSearchTensor(const int &subgraph_index, MetaGraphT *graph,
559                                      std::vector<uint32_t> *infer_node_indexes) {
560   if (static_cast<size_t>(subgraph_index) >= graph->subGraph.size()) {
561     MS_LOG(ERROR) << "subgraph_index: " << subgraph_index
562                   << " is larger than graph->subGraph.size(): " << graph->subGraph.size();
563     return RET_ERROR;
564   }
565   auto &subgraph = graph->subGraph.at(subgraph_index);
566   for (uint32_t i = 0; i < tensors_.size(); i++) {
567     if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty()) {
568       tensors_[i].is_inferred_ = true;
569     }
570   }
571   for (size_t i = 0; i < subgraph->nodeIndices.size(); i++) {
572     auto &node = graph->nodes.at(subgraph->nodeIndices.at(i));
573     if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
574                     [&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
575       infer_node_indexes->push_back(subgraph->nodeIndices.at(i));
576     }
577   }
578   return RET_OK;
579 }
580 
AddOutputNodes(MetaGraphT * graph,std::vector<uint32_t> * infer_node_indexes,uint32_t infer_node_index)581 void InferShapePass::AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
582                                     uint32_t infer_node_index) {
583   auto &node = graph->nodes.at(infer_node_index);
584   for (size_t i = 0; i < node->outputIndex.size(); i++) {
585     auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
586     for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
587       auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
588       if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
589                       [&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
590         AddNextInferShapeNode(graph, infer_node_indexes, next_nodes_indexes, j);
591       }
592     }
593   }
594 }
595 
AddNextInferShapeNode(MetaGraphT * graph,std::vector<uint32_t> * infer_node_indexes,std::vector<uint32_t> next_nodes_indexes,size_t index)596 void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
597                                            std::vector<uint32_t> next_nodes_indexes, size_t index) {
598   auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
599   if (find(infer_node_indexes->begin(), infer_node_indexes->end(), next_nodes_indexes[index]) ==
600       infer_node_indexes->end()) {
601     if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
602                     [&](uint32_t i) { return tensors_[i].is_inferred_; })) {
603       infer_node_indexes->push_back(next_nodes_indexes[index]);
604     }
605   }
606 }
607 
ResetIncorrectTensorShape(MetaGraphT * graph)608 void InferShapePass::ResetIncorrectTensorShape(MetaGraphT *graph) {
609   MS_ASSERT(graph != nullptr);
610   for (auto &node : graph->nodes) {
611     auto out_tensors_index = node->outputIndex;
612     for (auto index : out_tensors_index) {
613       auto &tensor = graph->allTensors.at(index);
614       auto shape = tensor->dims;
615       if (shape == std::vector{-1}) {
616         tensor->dims = {};
617       }
618     }
619   }
620 }
621 }  // namespace lite
622 }  // namespace mindspore
623