/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include <vector>
#include <deque>
#include <set>
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#include "src/common/prim_util.h"
#include "src/ops/populate/populate_register.h"
#include "src/runtime/infer_manager.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_flags.h"
#include "src/common/string_util.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"

using mindspore::converter::kFmkTypeTf;
namespace mindspore {
namespace lite {
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
constexpr size_t kInitialSize = 1024;
constexpr int kMainGraphIndex = 0;
constexpr int kCallInputMinSize = 1;
constexpr int kSwitchInputMinSize = 3;
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kFirstElementShapeIndex = 2;
constexpr int kTensorListDatasize = 3;

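// Free the temporary lite tensors created by ConvertTensorToLiteTensor. For plain tensors the data
// pointer is borrowed from the MetaGraphT tensor, so it is detached before deletion to avoid a
// double free; string and tensorlist tensors own their buffers and are deleted as-is.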
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
  if (input_tensors != nullptr) {
    for (auto &tensor : *input_tensors) {
      if (tensor == nullptr) {
        continue;
      }
      if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
        tensor->set_data(nullptr);
      }
      delete tensor;
      tensor = nullptr;
    }
    input_tensors->resize(0);
  }
  if (output_tensors != nullptr) {
    for (auto &tensor : *output_tensors) {
      if (tensor == nullptr) {
        continue;
      }
      if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
        tensor->set_data(nullptr);
      }
      delete tensor;
      tensor = nullptr;
    }
    output_tensors->resize(0);
  }
}

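// Convert a serialized TensorList tensor into a lite TensorList. The int32 payload in
// tensorT->data is laid out as:
//   data[0]              element data type
//   data[1]              element shape size (n)
//   data[2 .. 2 + n - 1] element shape
//   data[2 + n]          number of tensors in the list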
void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
  std::unique_ptr<Tensor> lite_tensor = nullptr;
  auto &tensorT = graph->allTensors.at(index);
  std::vector<int32_t> tensor_shape{};
  TypeId type = kTypeUnknown;
  std::vector<int> element_shape;
  if (!tensorT->data.empty()) {
    auto data_len = tensorT->data.size();
    // the payload must hold at least [type, element shape size, tensor count]
    if (data_len < kTensorListDatasize * sizeof(int)) {
      MS_LOG(ERROR) << "tensorlist data length illegal, tensorT name: " << tensorT->name;
      MS_LOG(ERROR) << "tensorT->data.size(): " << data_len << " is less than "
                    << kTensorListDatasize * sizeof(int);
      *convert_succ = false;
      return;
    }
    int *data = reinterpret_cast<int *>(tensorT->data.data());
    type = TypeId(data[kTypeIndex]);
    if (data[kElementShapeIndex] != 0 &&
        static_cast<int>((data[kElementShapeIndex] + kTensorListDatasize) * sizeof(int)) !=
          static_cast<int>(data_len)) {
      MS_LOG(ERROR) << "tensorlist data length illegal, tensorT name: " << tensorT->name;
      MS_LOG(ERROR) << "(element shape size + " << kTensorListDatasize
                    << ") * sizeof(int): " << ((data[kElementShapeIndex] + kTensorListDatasize) * sizeof(int));
      MS_LOG(ERROR) << "tensorT->data.size(): " << data_len;
      *convert_succ = false;
      return;
    }
    if (INT_ADD_OVERFLOW(data[kElementShapeIndex], kFirstElementShapeIndex)) {
      MS_LOG(ERROR) << "int add overflow";
      *convert_succ = false;
      return;
    }
    for (int j = 0; j < data[kElementShapeIndex]; ++j) {
      element_shape.push_back(data[j + kFirstElementShapeIndex]);
    }
    tensor_shape = {data[data[kElementShapeIndex] + kFirstElementShapeIndex]};
  }
  lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensorlist is nullptr";
    *convert_succ = false;
    return;
  }

  auto lite_tensor_list = reinterpret_cast<TensorList *>(lite_tensor.get());
  std::vector<Tensor *> tensors{};
  if (!tensor_shape.empty() && tensor_shape.front() == -1) {
    MS_LOG(INFO) << "tensor_shape is -1, tensor name: " << lite_tensor->tensor_name();
  }
  if (!tensor_shape.empty() && tensor_shape.front() != -1) {
    for (int32_t i = 0; i < tensor_shape.front(); ++i) {
      auto tensor = new (std::nothrow) Tensor(type, element_shape);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << "new Tensor failed";
        *convert_succ = false;
        FreeTensors(&tensors, nullptr);
        return;
      }
      tensors.emplace_back(tensor);
    }
  }

  lite_tensor_list->set_tensors_data_type(type);
  lite_tensor_list->set_element_shape(element_shape);
  lite_tensor_list->set_tensors(tensors);
  lite_tensors->emplace_back(lite_tensor.release());
}

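// Convert a string tensor into a lite tensor. Non-empty payloads are parsed as a string buffer
// and rewritten into the lite tensor's own storage.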
void ConvertString(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
  std::unique_ptr<Tensor> lite_tensor = nullptr;
  auto &tensorT = graph->allTensors.at(index);
  auto tensor_shape = tensorT->dims;
  lite_tensor = std::make_unique<Tensor>(
    TypeId(tensorT->dataType), tensor_shape, static_cast<mindspore::Format>(tensorT->format),
    TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensor is nullptr";
    *convert_succ = false;
    return;
  }
  auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
  // when tensorT is a parameter input, it carries no data
  if (lite_tensor_size == 0) {
    lite_tensors->emplace_back(lite_tensor.release());
    return;
  }
  auto string_buffer = ParseStringBuffer(tensorT->data.data());
  auto ret = WriteStringsToTensor(lite_tensor.get(), string_buffer);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "WriteStringsToTensor failed";
    *convert_succ = false;
    return;
  }
  lite_tensors->emplace_back(lite_tensor.release());
}

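// Convert a plain tensor into a lite tensor. Constant data, when present, is shared with the
// MetaGraphT tensor rather than copied; FreeTensors detaches it again before deletion.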
void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
  std::unique_ptr<Tensor> lite_tensor = nullptr;
  auto &tensorT = graph->allTensors.at(index);
  auto tensor_shape = tensorT->dims;
  lite_tensor = std::make_unique<Tensor>(
    TypeId(tensorT->dataType), tensor_shape, static_cast<mindspore::Format>(tensorT->format),
    TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensor is nullptr";
    *convert_succ = false;
    return;
  }
  auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
  // when tensorT is a parameter input, it carries no data
  if (lite_tensor_size == 0) {
    lite_tensors->emplace_back(lite_tensor.release());
    return;
  }
  lite_tensor->set_data(tensorT->data.data());
  lite_tensors->emplace_back(lite_tensor.release());
}

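// Convert the MetaGraphT tensors at the given indices into lite tensors, dispatching on data type.
// Returns an empty vector (after freeing any partial results) if any conversion fails.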
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs) {
  MS_ASSERT(graph != nullptr);
  std::vector<Tensor *> lite_tensors;
  bool convert_succ = true;
  for (size_t i = 0; i < tensor_indexs.size(); i++) {
    auto &tensorT = graph->allTensors.at(tensor_indexs[i]);
    switch (tensorT->dataType) {
      case kObjectTypeTensorType:
        ConvertTensorList(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
        break;
      case kObjectTypeString:
        ConvertString(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
        break;
      default:
        ConvertOtherTensor(graph, tensor_indexs[i], &convert_succ, &lite_tensors);
        break;
    }
  }
  if (!convert_succ) {
    FreeTensors(&lite_tensors, nullptr);
    return {};
  }
  return lite_tensors;
}

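// Run shape inference for a single node. The node's primitive is first handed to the registered
// kernel infer function; if that path reports RET_NOT_SUPPORT, an OpParameter is populated and
// the nnacl infer path is used instead.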
STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
                      std::vector<Tensor *> *outputs) {
  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "get primitive failed.";
    fbb.Clear();
    return RET_ERROR;
  }

  auto ret = KernelInferShape(inputs, *outputs, prim, {}, SCHEMA_CUR);
  if (ret == lite::RET_NOT_SUPPORT) {
    auto parameter_gen =
      lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim->value_type()), SCHEMA_CUR);
    if (parameter_gen == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
      return RET_ERROR;
    }
    auto parameter = parameter_gen(prim);
    if (parameter == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "parameter is nullptr.";
      return RET_ERROR;
    }
    parameter->quant_type_ = static_cast<int>(node->quantType);
    ret = KernelInferShape(inputs, *outputs, parameter);
    if (parameter->destroy_func_ != nullptr) {
      parameter->destroy_func_(parameter);
    }
    free(parameter);
    parameter = nullptr;
  }

  fbb.Clear();
  return ret;
}

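// Debug-only helper: dump the input and output shapes of the node being inferred.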
#ifdef Debug
void PrintTensorShape(const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors) {
  int i = 0;
  for (auto input_tensor : input_tensors) {
    std::ostringstream oss;
    for (auto &dim : input_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "input shape " << i++ << ":" << oss.str();
  }
  i = 0;
  for (auto output_tensor : output_tensors) {
    std::ostringstream oss;
    for (auto &dim : output_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "output shape " << i++ << ":" << oss.str();
  }
}
#endif

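// Write the inferred data type and format of the i-th output back to the MetaGraphT tensor. For
// tensorlist outputs the metadata (element type, element shape, tensor count) is re-serialized
// into the tensor's int32 payload; outputs whose type is still unknown are marked as not inferred.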
int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors, std::vector<InferTensor> *tensors,
                uint32_t i, uint32_t infer_node_index) {
  auto &node = graph->nodes.at(infer_node_index);
  auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
  output_tensor->format = static_cast<schema::Format>(output_tensors[i]->format());
  output_tensor->dataType = output_tensors[i]->data_type();
  if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
    auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
    int tensor_shape_dims = 0;
    if (!tensor_list->tensors().empty()) {
      tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
    }
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDatasize), static_cast<int>(sizeof(int))),
                       RET_ERROR, "int mul overflow");
    auto total_size = (tensor_shape_dims + kTensorListDatasize) * sizeof(int);
    output_tensor->data.resize(total_size, 0);
    auto output_tensor_data = reinterpret_cast<int *>(output_tensor->data.data());
    if (tensor_list->tensors_data_type() == kTypeUnknown) {
      if (!tensor_list->tensors().empty()) {
        tensor_list->set_tensors_data_type(tensor_list->tensors().front()->data_type());
      }
    }
    output_tensor_data[kTypeIndex] = tensor_list->tensors_data_type();
    if (tensor_list->element_shape().empty() && !tensor_list->tensors().empty()) {
      tensor_list->set_element_shape(tensor_list->tensors().front()->shape());
    }
    output_tensor_data[kElementShapeIndex] = static_cast<int>(tensor_list->element_shape().size());
    for (size_t j = 0; j < tensor_list->element_shape().size(); ++j) {
      output_tensor_data[j + kFirstElementShapeIndex] = tensor_list->element_shape().at(j);
    }
    output_tensor_data[kFirstElementShapeIndex + output_tensor_data[kElementShapeIndex]] =
      static_cast<int>(tensor_list->tensors().size());
  } else if (output_tensors[i]->data_type() == kTypeUnknown) {
    tensors->at(node->outputIndex[i]).is_inferred_ = false;
    return RET_OK;
  }
  tensors->at(node->outputIndex[i]).is_inferred_ = true;
  return RET_OK;
}

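// Return the subgraph index carried by a PartialFusion node's primitive.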
int PartialGraphIndex(const CNodeT *partial_node) {
  return partial_node->primitive->value.AsPartialFusion()->sub_graph_index;
}
}  // namespace

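// Propagate the partial node's input shapes, types, formats and data into the corresponding
// subgraph input tensors so the subgraph can be inferred with concrete inputs.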
int InferShapePass::CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);

  if (subgraph->inputIndices.size() != partial_node->inputIndex.size()) {
    MS_LOG(ERROR) << "partial node " << partial_node->name << " inputs size: " << partial_node->inputIndex.size()
                  << " vs subgraph " << subgraph_index << " input size: " << subgraph->inputIndices.size();
    return RET_PARAM_INVALID;
  }

  for (size_t i = 0; i < partial_node->inputIndex.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    auto &partial_input = graph->allTensors.at(partial_node->inputIndex[i]);
    subgraph_input->dataType = partial_input->dataType;
    subgraph_input->dims = partial_input->dims;
    subgraph_input->format = partial_input->format;
    subgraph_input->data.resize(partial_input->data.size(), 0);
    if (partial_input->data.empty()) {
      continue;
    }
    auto ret = memcpy_s(subgraph_input->data.data(), subgraph_input->data.size(), partial_input->data.data(),
                        partial_input->data.size());
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy failed, ret: " << ret;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

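// Drop the input data copied into the subgraph by CopyPartialShapeToSubGraph. Tensorlist inputs
// keep their payload, since it encodes shape metadata rather than constant data.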
int InferShapePass::RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (size_t i = 0; i < subgraph->inputIndices.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    if (subgraph_input->dataType != kObjectTypeTensorType) {
      subgraph_input->data = {};
    }
  }
  return RET_OK;
}

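// Infer a partial node by copying its input shapes into the target subgraph, inferring that
// subgraph, and restoring the subgraph inputs afterwards. A failed subgraph infer is only a
// warning so that the rest of the graph can still be processed.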
int InferShapePass::InferPartialNode(const CNodeT *partial_node, MetaGraphT *graph) {
  int subgraph_index = PartialGraphIndex(partial_node);
  int ret = CopyPartialShapeToSubGraph(partial_node, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CopyPartialShapeToSubGraph failed, ret: " << ret;
    return ret;
  }

  ret = InferSubgraph(subgraph_index, graph);
  if (ret != RET_OK) {
    // do not return ret here, so the rest of the graph can still be inferred
    MS_LOG(WARNING) << "InferSubgraph index: " << subgraph_index << " failed, ret: " << ret;
  }

  ret = RestoreSubGraphInput(partial_node, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "RestoreSubGraphInput failed, ret: " << ret;
  }
  return ret;
}

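// Build the tensor-to-node maps used by the worklist traversal and normalize graph input dims:
// a 0 in an input shape is treated as unknown and replaced with -1.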
void InferShapePass::InitInferTensor(MetaGraphT *graph) {
  tensors_.resize(graph->allTensors.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    auto &node = graph->nodes.at(i);
    auto node_input_indexes = node->inputIndex;
    // record which nodes consume and produce each tensor
    for (unsigned int node_input_index : node_input_indexes) {
      tensors_[node_input_index].next_nodes_.push_back(i);
    }
    auto node_output_indexes = node->outputIndex;
    for (unsigned int node_output_index : node_output_indexes) {
      tensors_[node_output_index].prev_nodes_.push_back(i);
    }
  }

  for (auto input_idx : graph->inputIndex) {
    auto input_tensor = graph->allTensors[input_idx].get();
    for (auto &dim : input_tensor->dims) {
      if (dim == 0) {
        MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
        dim = DEFAULT_DIM_VALUE;
      }
    }
  }
}

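// Infer the subgraphs reachable from a switch node. The partial nodes producing the switch's
// true-branch and false-branch inputs are located and inferred once each; partial_cnode_inferred
// guards against re-inferring the same partial node.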
int InferShapePass::InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph) {
  if (switch_node->inputIndex.size() < kSwitchInputMinSize) {
    MS_LOG(ERROR) << "switch node input size: " << switch_node->inputIndex.size() << " is less than three.";
    return RET_PARAM_INVALID;
  }

  static std::set<CNodeT *> partial_cnode_inferred{};
  std::deque<CNodeT *> to_process{};
  auto true_branch_output_index = switch_node->inputIndex.at(kSwitchTrueIndex);
  auto false_branch_output_index = switch_node->inputIndex.at(kSwitchFalseIndex);
  for (auto &node : graph->nodes) {
    if (node->primitive->value.type != PrimitiveType_PartialFusion) {
      continue;
    }
    if (IsContain(node->outputIndex, true_branch_output_index) &&
        partial_cnode_inferred.find(node.get()) == partial_cnode_inferred.end()) {
      to_process.push_back(node.get());
      partial_cnode_inferred.insert(node.get());
      break;
    }
  }
  for (auto &node : graph->nodes) {
    if (node->primitive->value.type != PrimitiveType_PartialFusion) {
      continue;
    }
    if (IsContain(node->outputIndex, false_branch_output_index) &&
        partial_cnode_inferred.find(node.get()) == partial_cnode_inferred.end()) {
      to_process.push_back(node.get());
      partial_cnode_inferred.insert(node.get());
      break;
    }
  }

  while (!to_process.empty()) {
    auto node = to_process.front();
    to_process.pop_front();
    int ret = InferPartialNode(node, graph);
    if (ret != RET_OK) {
      MS_LOG(WARNING) << "partial node infer failed.";
      return ret;
    }
  }

  return RET_OK;
}

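// Infer a call node by inferring whatever produces its first input: either a partial node
// directly, or a switch node whose branch subgraphs are inferred in turn.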
int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph) {
  if (call_node->inputIndex.size() < kCallInputMinSize) {
    MS_LOG(ERROR) << "call node input size: " << call_node->inputIndex.size() << " is less than one.";
    return RET_PARAM_INVALID;
  }
  auto call_first_input_index = call_node->inputIndex.front();
  bool find_partial = false;
  bool find_switch = false;
  for (auto &node : graph->nodes) {
    if (IsContain(node->outputIndex, call_first_input_index) &&
        node->primitive->value.type == PrimitiveType_PartialFusion) {
      find_partial = true;
      int ret = InferPartialNode(node.get(), graph);
      if (ret != RET_OK) {
        MS_LOG(WARNING) << "partial node infer failed.";
        return ret;
      }
      break;
    }
    if (IsContain(node->outputIndex, call_first_input_index) && node->primitive->value.type == PrimitiveType_Switch) {
      find_switch = true;
      int ret = InferSwitchNode(node, graph);
      if (ret != RET_OK) {
        MS_LOG(WARNING) << "switch node infer failed.";
        return ret;
      }
      break;
    }
  }
  if (!find_partial && !find_switch) {
    MS_LOG(ERROR) << "call node's first input is not produced by a partial or switch node.";
    return RET_ERROR;
  }
  return RET_OK;
}

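// Worklist-based shape inference over one subgraph: start from the nodes whose inputs are all
// known, infer each node, write the results back to the MetaGraphT tensors, and enqueue the
// consumers that become ready.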
int InferShapePass::InferSubgraph(const int &subgraph_index, MetaGraphT *graph) {
  std::vector<uint32_t> infer_node_indexes{};
  int ret = InitSearchTensor(subgraph_index, graph, &infer_node_indexes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InitSearchTensor failed.";
    return ret;
  }
  if (infer_node_indexes.empty()) {
    MS_LOG(DEBUG) << "no need to infer.";
    return RET_OK;
  }

  while (!infer_node_indexes.empty()) {
    auto infer_node_index = infer_node_indexes.front();
    auto &node = graph->nodes.at(infer_node_index);
    auto node_type = node->primitive->value.type;
    if (node_type == PrimitiveType_Call) {
      ret = InferCallNode(node, graph);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "infer call node failed.";
        return ret;
      }
    }

    infer_node_indexes.erase(infer_node_indexes.begin());
    auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex);
    auto output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex);
    if (output_tensors.empty() || output_tensors.size() != node->outputIndex.size() || input_tensors.empty() ||
        input_tensors.size() != node->inputIndex.size()) {
      MS_LOG(ERROR) << "convert lite tensor error";
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    auto status = NodeInferShape(node, input_tensors, &output_tensors);
    MS_LOG(DEBUG) << "cur node: " << node->name;
    if (status == RET_OK || status == RET_INFER_INVALID) {
#ifdef Debug
      PrintTensorShape(input_tensors, output_tensors);
#endif
      // copy the inferred output shapes and types back to the MetaGraphT tensors
      for (size_t i = 0; i < output_tensors.size(); i++) {
        auto output_dims = output_tensors[i]->shape();
        auto &output_tensorT = graph->allTensors.at(node->outputIndex[i]);
        output_tensorT->dims.swap(output_dims);
        if (SetDataType(graph, output_tensors, &tensors_, i, infer_node_index) != RET_OK) {
          MS_LOG(ERROR) << "SetDataType failed.";
          FreeTensors(&input_tensors, &output_tensors);
          return RET_INFER_ERR;
        }
      }
    } else {
      MS_LOG(WARNING) << "InferShape failed, name: " << node->name
                      << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    FreeTensors(&input_tensors, &output_tensors);
    AddOutputNodes(graph, &infer_node_indexes, infer_node_index);
  }
  return RET_OK;
}

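// Entry point of the pass: infer the main graph; partial/call/switch handling recurses into the
// other subgraphs as they are reached.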
STATUS InferShapePass::Run(MetaGraphT *graph) {
  CHECK_NULL_RETURN(graph);
  InitInferTensor(graph);

  int ret = InferSubgraph(kMainGraphIndex, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InferSubgraph index: " << kMainGraphIndex << " failed, ret: " << ret;
    return ret;
  }

  ResetIncorrectTensorShape(graph);
  return RET_OK;
}

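// Mark subgraph inputs and constant tensors as already inferred, then seed the worklist with the
// nodes of this subgraph whose inputs are all inferred.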
int InferShapePass::InitSearchTensor(const int &subgraph_index, MetaGraphT *graph,
                                     std::vector<uint32_t> *infer_node_indexes) {
  if (static_cast<size_t>(subgraph_index) >= graph->subGraph.size()) {
    MS_LOG(ERROR) << "subgraph_index: " << subgraph_index
                  << " is out of range; graph->subGraph.size(): " << graph->subGraph.size();
    return RET_ERROR;
  }
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (uint32_t i = 0; i < tensors_.size(); i++) {
    if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty()) {
      tensors_[i].is_inferred_ = true;
    }
  }
  for (size_t i = 0; i < subgraph->nodeIndices.size(); i++) {
    auto &node = graph->nodes.at(subgraph->nodeIndices.at(i));
    if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                    [&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
      infer_node_indexes->push_back(subgraph->nodeIndices.at(i));
    }
  }
  return RET_OK;
}

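// After a node is inferred, visit the consumers of its outputs; any consumer that still has an
// uninferred output is a candidate for scheduling via AddNextInferShapeNode.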
void InferShapePass::AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                    uint32_t infer_node_index) {
  auto &node = graph->nodes.at(infer_node_index);
  for (size_t i = 0; i < node->outputIndex.size(); i++) {
    auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
    for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
      auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
      if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
                      [&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
        AddNextInferShapeNode(graph, infer_node_indexes, next_nodes_indexes, j);
      }
    }
  }
}

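// Enqueue the candidate node if it is not already queued and all of its inputs are inferred.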
void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                           std::vector<uint32_t> next_nodes_indexes, size_t index) {
  auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
  if (std::find(infer_node_indexes->begin(), infer_node_indexes->end(), next_nodes_indexes[index]) ==
      infer_node_indexes->end()) {
    if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
                    [&](uint32_t i) { return tensors_[i].is_inferred_; })) {
      infer_node_indexes->push_back(next_nodes_indexes[index]);
    }
  }
}

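// Shapes left as {-1} when inference could not determine them are reset to empty, so downstream
// passes treat them as unknown rather than rank-1.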
void InferShapePass::ResetIncorrectTensorShape(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    auto out_tensors_index = node->outputIndex;
    for (auto index : out_tensors_index) {
      auto &tensor = graph->allTensors.at(index);
      auto shape = tensor->dims;
      if (shape == std::vector<int32_t>{-1}) {
        tensor->dims = {};
      }
    }
  }
}
}  // namespace lite
}  // namespace mindspore