/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include <algorithm>
#include <deque>
#include <memory>
#include <sstream>
#include <vector>
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#include "src/common/prim_util.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/litert/infer_manager.h"
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "src/common/string_utils.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"

using mindspore::converter::kFmkTypeTf;
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
constexpr size_t kInitialSize = 1024;
constexpr int kMainGraphIndex = 0;
constexpr int kCallInputMinSize = 1;
constexpr int kSwitchInputMinSize = 3;
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kFirstElementShapeIndex = 2;
constexpr int kTensorListDataSize = 3;
}  // namespace
namespace mindspore {
namespace lite {
namespace {
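// Releases the lite tensors created for one node's infer-shape run. Data
// pointers of non-string, non-tensorlist tensors are borrowed from the
// MetaGraphT buffers, so they are detached before deletion. output_tensors
// may be null.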
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
  if (input_tensors == nullptr) {
    return;
  }
  for (auto &tensor : *input_tensors) {
    if (tensor == nullptr) {
      continue;
    }
    if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
      tensor->set_data(nullptr);
    }
    delete tensor;
    tensor = nullptr;
  }
  // clear now so the vector never holds dangling pointers, even if output_tensors is null
  input_tensors->resize(0);
  if (output_tensors == nullptr) {
    return;
  }
  for (auto &tensor : *output_tensors) {
    if (tensor == nullptr) {
      continue;
    }
    if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
      tensor->set_data(nullptr);
    }
    delete tensor;
    tensor = nullptr;
  }
  output_tensors->resize(0);
}

namespace {
constexpr int kBytesPerInt = 4;
}  // namespace

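// Rebuilds a lite TensorList from the serialized int buffer of a TensorT. The
// buffer layout (mirrored by the writer in SetDataType below) is:
// [0] element dtype, [1] element-shape length, then the element shape, the
// tensor count, and one (rank, dims...) group per contained tensor.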
void ConvertTensorList(const MetaGraphT *graph, uint32_t index, bool *convert_succ,
                       std::vector<Tensor *> *lite_tensors) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr";
    return;
  }
  std::unique_ptr<Tensor> lite_tensor = nullptr;
  auto &tensorT = graph->allTensors.at(index);
  std::vector<int32_t> tensor_shape{};
  TypeId type = kTypeUnknown;
  std::vector<int> element_shape;
  if (tensorT->data.size() >= static_cast<size_t>(kBytesPerInt)) {
    int *data = reinterpret_cast<int *>(tensorT->data.data());
    type = TypeId(data[kTypeIndex]);
    auto basic_data_size = tensorT->data.size() / sizeof(int);
    if (basic_data_size < static_cast<size_t>(kTensorListDataSize)) {
      MS_LOG(ERROR) << "tensorlist data length illegal, which should be at least 3, now is " << basic_data_size;
      *convert_succ = false;
      return;
    }
    if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDataSize)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDataSize)) > basic_data_size) {
      MS_LOG(ERROR) << "tensorlist data length illegal. current tensorlist data length should be at least "
                    << (data[kElementShapeIndex] + kTensorListDataSize) << ", but now is " << basic_data_size;
      *convert_succ = false;
      return;
    }
    auto element_num = data[data[kElementShapeIndex] + kFirstElementShapeIndex];
    if (element_num > 0 && INT_ADD_OVERFLOW(element_num, 1)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    auto shape_once = data[kElementShapeIndex] + 1;
    auto shape_group_num = element_num < 0 ? 1 : element_num + 1;
    if (INT_MUL_OVERFLOW(shape_once, shape_group_num)) {
      MS_LOG(ERROR) << "int mul overflow.";
      *convert_succ = false;
      return;
    }
    tensor_shape = {element_num};
    auto shape_info_size = shape_once * shape_group_num;
    if (INT_ADD_OVERFLOW(shape_info_size, kFirstElementShapeIndex)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    int real_data_size = shape_info_size + kFirstElementShapeIndex;
    if (real_data_size <= 0 || static_cast<uint32_t>(real_data_size) != basic_data_size) {
      MS_LOG(ERROR) << "current tensorlist data length should be " << real_data_size << ", but now is "
                    << basic_data_size;
      *convert_succ = false;
      return;
    }
    for (int j = 0; j < data[kElementShapeIndex]; ++j) {
      element_shape.push_back(data[j + kFirstElementShapeIndex]);
    }
  }
  lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensorlist is nullptr";
    *convert_succ = false;
    return;
  }

  auto lite_tensor_list = reinterpret_cast<TensorList *>(lite_tensor.get());
  std::vector<Tensor *> tensors{};
  if (!tensor_shape.empty() && tensor_shape.front() != -1) {
    for (int32_t i = 0; i < tensor_shape.front(); ++i) {
      auto tensor = new (std::nothrow) Tensor(type, element_shape);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << "new Tensor failed";
        *convert_succ = false;
        FreeTensors(&tensors, nullptr);
        return;
      }
      tensors.emplace_back(tensor);
    }
  }

  lite_tensor_list->set_tensors_data_type(type);
  lite_tensor_list->set_element_shape(element_shape);
  lite_tensor_list->set_tensors(tensors);
  lite_tensors->emplace_back(lite_tensor.release());
}

namespace {
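// Creates an empty runtime Tensor mirroring a TensorT's dtype, shape, format,
// and category; the caller decides whether and how to attach data.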
std::unique_ptr<Tensor> CreateRuntimeTensor(const std::unique_ptr<TensorT> &src_tensor) {
  if (src_tensor == nullptr) {
    MS_LOG(ERROR) << "src tensor is nullptr";
    return nullptr;
  }
  std::unique_ptr<Tensor> runtime_tensor = nullptr;
  auto tensor_shape = src_tensor->dims;
  runtime_tensor = std::make_unique<Tensor>(TypeId(src_tensor->dataType), tensor_shape,
                                            static_cast<mindspore::Format>(src_tensor->format),
                                            TensorCategory(src_tensor->nodeType, src_tensor->dims.size(),
                                                           TypeId(src_tensor->dataType), src_tensor->data.size()));
  if (runtime_tensor == nullptr) {
    MS_LOG(ERROR) << "Create runtime tensor failed";
    return nullptr;
  }
  return runtime_tensor;
}
}  // namespace

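// Converts a string TensorT into a lite tensor. Constant string data is parsed
// from the flatbuffer string buffer and written into the new tensor.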
void ConvertString(const MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
  auto &tensorT = graph->allTensors.at(index);
  auto runtime_tensor = CreateRuntimeTensor(tensorT);
  if (runtime_tensor == nullptr) {
    *convert_succ = false;
    return;
  }
  // when tensorT is a param input, it carries no data
  if (tensorT->data.empty()) {
    lite_tensors->emplace_back(runtime_tensor.release());
    return;
  }
  auto string_buffer = ParseStringBuffer(tensorT->data.data());
  auto ret = WriteStringsToTensor(runtime_tensor.get(), string_buffer);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "WriteStringsToTensor failed";
    *convert_succ = false;
    return;
  }
  lite_tensors->emplace_back(runtime_tensor.release());
}

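// Converts any remaining TensorT into a lite tensor; constant data is attached
// by pointer rather than copied, which is why FreeTensors detaches it before
// deleting the tensor.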
void ConvertOtherTensor(const MetaGraphT *graph, uint32_t index, bool *convert_succ,
                        std::vector<Tensor *> *lite_tensors) {
  CHECK_NULL_RETURN_VOID(graph);
  auto &tensorT = graph->allTensors.at(index);
  auto runtime_tensor = CreateRuntimeTensor(tensorT);
  if (runtime_tensor == nullptr) {
    *convert_succ = false;
    return;
  }
  // when tensorT is a param input, it carries no data
  if (tensorT->data.empty()) {
    lite_tensors->emplace_back(runtime_tensor.release());
    return;
  }
  runtime_tensor->set_data(tensorT->data.data());
  lite_tensors->emplace_back(runtime_tensor.release());
}

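// Converts the MetaGraphT tensors at the given indices into runtime lite
// tensors, dispatching on dtype; returns an empty vector on any failure.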
std::vector<Tensor *> ConvertTensorToLiteTensor(const MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs) {
  MS_ASSERT(graph != nullptr);
  std::vector<Tensor *> lite_tensors;
  bool convert_succ = true;
  for (unsigned int tensor_index : tensor_indexs) {
    auto &tensorT = graph->allTensors.at(tensor_index);
    switch (tensorT->dataType) {
      case kObjectTypeTensorType:
        ConvertTensorList(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
      case kObjectTypeString:
        MS_CHECK_TRUE_MSG(tensorT->dims.size() <= 1, {}, "String type tensor dims should be less than or equal to 1.");
        ConvertString(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
      default:
        ConvertOtherTensor(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
    }
  }
  if (!convert_succ) {
    FreeTensors(&lite_tensors, nullptr);
    return {};
  }
  return lite_tensors;
}

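// Runs the runtime infer-shape kernel for one node: converts its PrimitiveT to
// a flatbuffer primitive, then falls back to the populated OpParameter path
// when primitive-based inference reports RET_NOT_SUPPORT.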
STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
                      std::vector<Tensor *> *outputs) {
  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "get primitive failed.";
    fbb.Clear();
    return RET_ERROR;
  }

  auto ret = KernelInferShape(inputs, *outputs, prim, {}, static_cast<int>(SCHEMA_CUR));
  if (ret == lite::RET_NOT_SUPPORT) {
    auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
      static_cast<int>(prim->value_type()), static_cast<int>(SCHEMA_CUR));
    if (parameter_gen == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
      return RET_ERROR;
    }
    auto parameter = parameter_gen(prim);
    if (parameter == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "parameter is nullptr.";
      return RET_ERROR;
    }
    parameter->quant_type_ = static_cast<int>(node->quantType);
    ret = KernelInferShape(inputs, *outputs, parameter);
    if (parameter->destroy_func_ != nullptr) {
      parameter->destroy_func_(parameter);
    }
    free(parameter);
    parameter = nullptr;
  }

  fbb.Clear();
  return ret;
}

#ifdef Debug
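// Debug helper: logs the shapes of all input and output tensors of a node.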
void PrintTensorShape(const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors) {
  int i = 0;
  for (auto input_tensor : input_tensors) {
    std::ostringstream oss;
    for (auto &dim : input_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "input shape " << i++ << ":" << oss.str();
  }
  i = 0;
  for (auto output_tensor : output_tensors) {
    std::ostringstream oss;
    for (auto &dim : output_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "output shape " << i++ << ":" << oss.str();
  }
}
#endif

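// Writes the inferred dtype (and, for tensor lists, the re-serialized list
// header described above ConvertTensorList) back into the i-th output TensorT,
// and marks the tensor as inferred unless its dtype is still unknown.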
int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
                const std::unique_ptr<mindspore::schema::CNodeT> &node, std::vector<InferTensor> *tensors, size_t i) {
  auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
  output_tensor->format = static_cast<schema::Format>(output_tensors[i]->format());
  output_tensor->dataType = output_tensors[i]->data_type();
  if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
    auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
    MSLITE_CHECK_PTR(tensor_list);
    int tensor_shape_dims = 0;
    if (!tensor_list->tensors().empty()) {
      tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
    }
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDataSize), static_cast<int>(sizeof(int))),
                       RET_ERROR, "int mul overflow");
    if (tensor_list->tensors_data_type() == kTypeUnknown) {
      if (!tensor_list->tensors().empty()) {
        tensor_list->set_tensors_data_type(tensor_list->tensors().front()->data_type());
      }
    }
    std::vector<int> basic_data;
    basic_data.push_back(tensor_list->tensors_data_type());
    if (tensor_list->element_shape().empty() && !tensor_list->tensors().empty()) {
      tensor_list->set_element_shape(tensor_list->tensors().front()->shape());
    }
    basic_data.push_back(tensor_list->element_shape().size());
    for (size_t j = 0; j < tensor_list->element_shape().size(); ++j) {
      basic_data.push_back(tensor_list->element_shape().at(j));
    }
    basic_data.push_back(tensor_list->tensors().size());
    for (size_t index = 0; index < tensor_list->tensors().size(); ++index) {
      auto tensor_shape = tensor_list->GetTensor(static_cast<int>(index))->shape();
      basic_data.push_back(tensor_shape.size());
      for (size_t j = 0; j < tensor_shape.size(); ++j) {
        basic_data.push_back(tensor_shape[j]);
      }
    }
    output_tensor->data.resize(basic_data.size() * sizeof(int));
    if (memcpy_s(output_tensor->data.data(), output_tensor->data.size(), basic_data.data(),
                 basic_data.size() * sizeof(int)) != EOK) {
      MS_LOG(ERROR) << "memcpy data failed.";
      return RET_ERROR;
    }
  } else if (output_tensors[i]->data_type() == kTypeUnknown) {
    tensors->at(node->outputIndex[i]).is_inferred_ = false;
    return RET_OK;
  }
  tensors->at(node->outputIndex[i]).is_inferred_ = true;
  return RET_OK;
}

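// Copies inferred shape and type info from the runtime output tensors back
// into the corresponding TensorT entries of the MetaGraphT.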
int CopyOutputInfoToTensorT(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
                            const std::unique_ptr<mindspore::schema::CNodeT> &node, std::vector<InferTensor> *tensors) {
  for (uint32_t i = 0; i < output_tensors.size(); i++) {
    auto output_dims = output_tensors[i]->shape();
    auto &output_tensorT = graph->allTensors.at(node->outputIndex[i]);
    MSLITE_CHECK_PTR(output_tensorT);
    output_tensorT->dims.swap(output_dims);
    if (SetDataType(graph, output_tensors, node, tensors, i) != RET_OK) {
      MS_LOG(ERROR) << "SetDataType failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

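// Returns the subgraph index referenced by a PartialFusion node.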
int64_t PartialGraphIndex(const CNodeT *partial_node) {
  MSLITE_CHECK_PTR(partial_node);
  return partial_node->primitive->value.AsPartialFusion()->sub_graph_index;
}
}  // namespace

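// Propagates the partial node's input shapes, dtypes, formats, and constant
// data into the input tensors of the subgraph it calls.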
int InferShapePass::CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);

  if (subgraph->inputIndices.size() != partial_node->inputIndex.size()) {
    MS_LOG(ERROR) << "partial node " << partial_node->name << " inputs size: " << partial_node->inputIndex.size()
                  << " vs subgraph " << subgraph_index << " input size: " << subgraph->inputIndices.size();
    return RET_PARAM_INVALID;
  }

  for (size_t i = 0; i < partial_node->inputIndex.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    MSLITE_CHECK_PTR(subgraph_input);
    auto &partial_input = graph->allTensors.at(partial_node->inputIndex[i]);
    MSLITE_CHECK_PTR(partial_input);
    subgraph_input->dataType = partial_input->dataType;
    subgraph_input->dims = partial_input->dims;
    subgraph_input->format = partial_input->format;
    subgraph_input->data.resize(partial_input->data.size(), 0);
    if (partial_input->data.empty()) {
      continue;
    }
    auto ret = memcpy_s(subgraph_input->data.data(), subgraph_input->data.size(), partial_input->data.data(),
                        partial_input->data.size());
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy failed, ret: " << ret;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

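// Clears the constant data copied into subgraph inputs by
// CopyPartialShapeToSubGraph, so the subgraph inputs behave as variables again.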
void InferShapePass::RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (size_t i = 0; i < subgraph->inputIndices.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    if (subgraph_input->dataType != kObjectTypeTensorType) {
      subgraph_input->data = {};
    }
  }
}

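// For a non-tail call, forwards the inferred subgraph output shapes, formats,
// and dtypes to the call node's own output tensors.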
int InferShapePass::SetNonTailCallOutputShape(const std::unique_ptr<CNodeT> &call_node, const CNodeT *partial_node,
                                              MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);
  size_t call_node_output_size = call_node->outputIndex.size();
  size_t subgraph_output_size = subgraph->outputIndices.size();
  if (subgraph_output_size != call_node_output_size) {
    MS_LOG(ERROR) << "call node output size: " << call_node_output_size
                  << " is not equal to corresponding subgraph output size: " << subgraph_output_size;
    return RET_ERROR;
  }
  for (size_t i = 0; i < subgraph_output_size; ++i) {
    auto &subgraph_output_tensor = graph->allTensors.at(subgraph->outputIndices[i]);
    auto &call_output_tensor = graph->allTensors.at(call_node->outputIndex[i]);
    call_output_tensor->format = subgraph_output_tensor->format;
    call_output_tensor->dims = subgraph_output_tensor->dims;
    call_output_tensor->dataType = subgraph_output_tensor->dataType;
  }
  return RET_OK;
}

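// Infers one PartialFusion target: copies caller shapes in, infers the
// subgraph (a failure there is only a warning so the rest of the graph can
// still be inferred), restores the subgraph inputs, and for non-tail calls
// forwards the results to the call node outputs.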
int InferShapePass::InferPartialNode(const bool &is_tail_call, const std::unique_ptr<CNodeT> &call_node,
                                     const CNodeT *partial_node, MetaGraphT *graph) {
  int64_t subgraph_index = PartialGraphIndex(partial_node);
  int ret = CopyPartialShapeToSubGraph(partial_node, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CopyPartialShapeToSubGraph failed, ret: " << ret;
    return ret;
  }

  ret = InferSubgraph(subgraph_index, graph);
  if (ret != RET_OK) {
    // do not return here, so that the following part of the graph can still be inferred
    MS_LOG(WARNING) << "InferSubgraph index: " << subgraph_index << " failed, ret: " << ret;
  }

  RestoreSubGraphInput(partial_node, graph);

  if (!is_tail_call) {
    ret = SetNonTailCallOutputShape(call_node, partial_node, graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "SetNonTailCallOutputShape failed.";
      return ret;
    }
  }
  return ret;
}

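// Builds the tensor-to-node adjacency (producers and consumers) for the whole
// graph and replaces 0-valued input dims with -1 so they are treated as unknown.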
void InferShapePass::InitInferTensor(MetaGraphT *graph) {
  tensors_.resize(graph->allTensors.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    auto &node = graph->nodes.at(i);
    auto node_input_indexes = node->inputIndex;
    // record this node as a consumer of each of its input tensors
    for (unsigned int node_input_index : node_input_indexes) {
      tensors_[node_input_index].next_nodes_.push_back(i);
    }
    auto node_output_indexes = node->outputIndex;
    for (unsigned int node_output_index : node_output_indexes) {
      tensors_[node_output_index].prev_nodes_.push_back(i);
    }
  }

  for (auto input_idx : graph->inputIndex) {
    auto input_tensor = graph->allTensors[input_idx].get();
    CHECK_NULL_RETURN_VOID(input_tensor);
    for (auto &dim : input_tensor->dims) {
      if (dim == 0) {
        MS_LOG(WARNING) << "One dimension of the input shape is 0, which will be set to -1 as a default value.";
        dim = DEFAULT_DIM_VALUE;
      }
    }
  }
}

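// Collects every PartialFusion node feeding the switch / switch_layer inputs
// (all inputs except the condition at index 0) and infers each of them once.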
int InferShapePass::InferSwitchOrSwitchLayerNode(const bool &is_tail_call, const std::unique_ptr<CNodeT> &call_node,
                                                 const std::unique_ptr<CNodeT> &aim_node, MetaGraphT *graph) {
  if (aim_node->inputIndex.size() < kSwitchInputMinSize) {
    MS_LOG(ERROR) << "switch or switch_layer node input size: " << aim_node->inputIndex.size() << " is less than 3.";
    return RET_PARAM_INVALID;
  }

  size_t aim_node_input_size = aim_node->inputIndex.size();
  std::vector<uint32_t> all_partial_index{};
  for (size_t i = 1; i < aim_node_input_size; ++i) {
    all_partial_index.push_back(aim_node->inputIndex.at(i));
  }

  std::vector<CNodeT *> all_partial_nodes{};
  for (auto &partial_index : all_partial_index) {
    for (auto &node : graph->nodes) {
      MSLITE_CHECK_PTR(node);
      if (node->primitive->value.type != PrimitiveType_PartialFusion) {
        continue;
      }
      if (IsContain(node->outputIndex, partial_index)) {
        all_partial_nodes.push_back(node.get());
        break;
      }
    }
  }

  std::deque<CNodeT *> to_process{};
  for (auto &partial_node : all_partial_nodes) {
    if (partial_cnode_inferred_.find(partial_node) == partial_cnode_inferred_.end()) {
      to_process.push_back(partial_node);
      (void)partial_cnode_inferred_.insert(partial_node);
    }
  }

  while (!to_process.empty()) {
    auto node = to_process.front();
    to_process.pop_front();
    int ret = InferPartialNode(is_tail_call, call_node, node, graph);
    if (ret != RET_OK) {
      MS_LOG(WARNING) << "partial infer is not supported.";
      return ret;
    }
  }

  return RET_OK;
}

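// Resolves the first input of a call node to the partial / switch /
// switch_layer node that produces it and dispatches the matching inference.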
int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph) {
  MSLITE_CHECK_PTR(call_node);
  if (call_node->inputIndex.size() < kCallInputMinSize) {
    MS_LOG(ERROR) << "call node input size: " << call_node->inputIndex.size() << " is less than one.";
    return RET_PARAM_INVALID;
  }
  auto call_first_input_index = call_node->inputIndex.front();
  bool is_tail_call = call_node->primitive->value.AsCall()->is_tail_call;
  for (auto &node : graph->nodes) {
    if (!IsContain(node->outputIndex, call_first_input_index)) {
      continue;
    }
    switch (node->primitive->value.type) {
      case PrimitiveType_PartialFusion:
        return InferPartialNode(is_tail_call, call_node, node.get(), graph);
      case PrimitiveType_Switch:
      case PrimitiveType_SwitchLayer:
        return InferSwitchOrSwitchLayerNode(is_tail_call, call_node, node, graph);
      default:
        MS_LOG(ERROR) << "the call node's first input is neither a partial node nor a switch node.";
        return RET_ERROR;
    }
  }
  return RET_OK;
}

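// Worklist-based shape inference over one subgraph: starts from nodes whose
// inputs are all inferred, infers each node with runtime lite tensors, writes
// the results back into the MetaGraphT, and enqueues newly ready consumers.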
int InferShapePass::InferSubgraph(const int64_t &subgraph_index, MetaGraphT *graph) {
  std::vector<uint32_t> infer_node_indexes{};
  int ret = InitSearchTensor(subgraph_index, graph, &infer_node_indexes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InitSearchTensor failed.";
    return ret;
  }
  if (infer_node_indexes.empty()) {
    MS_LOG(DEBUG) << "no need to infer.";
    return RET_OK;
  }

  while (!infer_node_indexes.empty()) {
    auto infer_node_index = infer_node_indexes.front();
    auto &node = graph->nodes.at(infer_node_index);
    MSLITE_CHECK_PTR(node);
    infer_node_indexes.erase(infer_node_indexes.begin());
    if (node->primitive == nullptr) {
      MS_LOG(WARNING) << "node primitive is nullptr!";
      continue;
    }
    auto node_type = node->primitive->value.type;
    if (node_type == PrimitiveType_Call) {
      ret = InferCallNode(node, graph);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "infer call node failed.";
        return ret;
      }
    }

    auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex);
    auto output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex);
    if (output_tensors.empty() || output_tensors.size() != node->outputIndex.size() || input_tensors.empty() ||
        input_tensors.size() != node->inputIndex.size()) {
      MS_LOG(ERROR) << "convert lite tensor error";
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    auto status = NodeInferShape(node, input_tensors, &output_tensors);
    MS_LOG(DEBUG) << "cur node:" << node->name;
    if (status == RET_OK || status == RET_INFER_INVALID) {
#ifdef Debug
      PrintTensorShape(input_tensors, output_tensors);
#endif
      ret = CopyOutputInfoToTensorT(graph, output_tensors, node, &tensors_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CopyOutputInfoToTensorT failed: " << ret;
        FreeTensors(&input_tensors, &output_tensors);
        return RET_INFER_ERR;
      }
    } else {
      MS_LOG(WARNING) << "InferShape failed, name: " << node->name
                      << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    FreeTensors(&input_tensors, &output_tensors);
    AddOutputNodes(graph, &infer_node_indexes, infer_node_index);
  }
  return RET_OK;
}

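// Entry point of the pass: builds tensor adjacency, infers the main graph
// (which recursively infers called subgraphs via InferCallNode), then clears
// shapes that are still the {-1} placeholder.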
STATUS InferShapePass::Run(MetaGraphT *graph) {
  CHECK_NULL_RETURN(graph);
  InitInferTensor(graph);

  int ret = InferSubgraph(kMainGraphIndex, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InferSubgraph index: " << kMainGraphIndex << " failed, ret: " << ret;
    return ret;
  }

  ResetIncorrectTensorShape(graph);
  return RET_OK;
}

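// Marks subgraph inputs, constant tensors, and zero-element value nodes as
// already inferred, and seeds the worklist with the subgraph's nodes whose
// inputs are all inferred.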
int InferShapePass::InitSearchTensor(const int64_t &subgraph_index, MetaGraphT *graph,
                                     std::vector<uint32_t> *infer_node_indexes) {
  if (static_cast<size_t>(subgraph_index) >= graph->subGraph.size()) {
    MS_LOG(ERROR) << "subgraph_index: " << subgraph_index
                  << " is out of range, graph->subGraph.size(): " << graph->subGraph.size();
    return RET_ERROR;
  }
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (uint32_t i = 0; i < tensors_.size(); i++) {
    if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty() ||
        (graph->allTensors.at(i)->nodeType == NodeType_ValueNode && graph->allTensors.at(i)->dims.size() == 1 &&
         graph->allTensors.at(i)->dims[0] == 0)) {
      tensors_[i].is_inferred_ = true;
    }
  }
  for (size_t i = 0; i < subgraph->nodeIndices.size(); i++) {
    auto &node = graph->nodes.at(subgraph->nodeIndices.at(i));
    if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                    [&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
      infer_node_indexes->push_back(subgraph->nodeIndices.at(i));
    }
  }
  return RET_OK;
}

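// After a node is inferred, visits the consumers of its outputs and queues any
// consumer whose own outputs are not yet fully inferred.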
void InferShapePass::AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                    uint32_t infer_node_index) {
  auto &node = graph->nodes.at(infer_node_index);
  for (size_t i = 0; i < node->outputIndex.size(); i++) {
    auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
    for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
      auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
      if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
                      [&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
        AddNextInferShapeNode(graph, infer_node_indexes, next_nodes_indexes, j);
      }
    }
  }
}

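// Queues a candidate node once all of its inputs are inferred and it is not
// already in the worklist.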
void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                           std::vector<uint32_t> next_nodes_indexes, size_t index) {
  auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
  if (std::find(infer_node_indexes->begin(), infer_node_indexes->end(), next_nodes_indexes[index]) ==
      infer_node_indexes->end()) {
    if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
                    [&](uint32_t i) { return tensors_[i].is_inferred_; })) {
      infer_node_indexes->push_back(next_nodes_indexes[index]);
    }
  }
}

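// A node output whose shape is exactly {-1} is a placeholder produced during
// inference, not a real shape; reset it to an empty (unknown) shape.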
void InferShapePass::ResetIncorrectTensorShape(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    auto out_tensors_index = node->outputIndex;
    for (auto index : out_tensors_index) {
      auto &tensor = graph->allTensors.at(index);
      auto shape = tensor->dims;
      if (shape == std::vector<int32_t>{-1}) {
        tensor->dims = {};
      }
    }
  }
}
}  // namespace lite
}  // namespace mindspore