/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include <vector>
#include <deque>
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#include "src/common/prim_util.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/litert/infer_manager.h"
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "src/common/string_utils.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"

using mindspore::converter::kFmkTypeTf;
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
constexpr size_t kInitialSize = 1024;
constexpr int kMainGraphIndex = 0;
constexpr int kCallInputMinSize = 1;
constexpr int kSwitchInputMinSize = 3;
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kFirstElementShapeIndex = 2;
constexpr int kTensorListDataSize = 3;
}  // namespace
namespace mindspore {
namespace lite {
namespace {
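// Release temporary runtime tensors created by ConvertTensorToLiteTensor. Plain tensors share
// their data buffer with the MetaGraphT, so the data pointer is detached before delete; string
// and tensorlist tensors own their buffers and are freed normally.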
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
  if (input_tensors != nullptr) {
    for (auto &tensor : *input_tensors) {
      if (tensor == nullptr) {
        continue;
      }
      if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
        tensor->set_data(nullptr);
      }
      delete tensor;
      tensor = nullptr;
    }
    input_tensors->resize(0);
  }
  if (output_tensors != nullptr) {
    for (auto &tensor : *output_tensors) {
      if (tensor == nullptr) {
        continue;
      }
      if (tensor->data_type() != kObjectTypeString && tensor->data_type() != kObjectTypeTensorType) {
        tensor->set_data(nullptr);
      }
      delete tensor;
      tensor = nullptr;
    }
    output_tensors->resize(0);
  }
}

namespace {
constexpr int kBytesPerInt = 4;
}  // namespace

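// A tensorlist tensor is serialized into TensorT::data as an int32 array laid out as:
//   data[0]              element data type
//   data[1]              rank of the element shape
//   data[2 .. 1+rank]    element shape dims
//   data[2+rank]         number of elements
//   then, per element:   its rank followed by its shape dims
// ConvertTensorList validates this layout and rebuilds a runtime TensorList from it.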
void ConvertTensorList(const MetaGraphT *graph, uint32_t index, bool *convert_succ,
                       std::vector<Tensor *> *lite_tensors) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr";
    return;
  }
  std::unique_ptr<Tensor> lite_tensor = nullptr;
  auto &tensorT = graph->allTensors.at(index);
  std::vector<int32_t> tensor_shape{};
  TypeId type = kTypeUnknown;
  std::vector<int> element_shape;
  if (tensorT->data.size() >= kBytesPerInt) {
    int *data = reinterpret_cast<int *>(tensorT->data.data());
    type = TypeId(data[kTypeIndex]);
    auto basic_data_size = tensorT->data.size() / sizeof(int);
    if (basic_data_size < static_cast<size_t>(kTensorListDataSize)) {
      MS_LOG(ERROR) << "tensorlist data length illegal, which should be at least 3, now is " << basic_data_size;
      *convert_succ = false;
      return;
    }
    if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDataSize)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDataSize)) > basic_data_size) {
      MS_LOG(ERROR) << "tensorlist data length illegal. current tensorlist data length should be at least "
                    << (data[kElementShapeIndex] + kTensorListDataSize) << ", but now is " << basic_data_size;
      *convert_succ = false;
      return;
    }
    auto element_num = data[data[kElementShapeIndex] + kFirstElementShapeIndex];
    if (element_num > 0 && INT_ADD_OVERFLOW(element_num, 1)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    auto shape_once = data[kElementShapeIndex] + 1;
    auto shape_group_num = element_num < 0 ? 1 : element_num + 1;
    if (INT_MUL_OVERFLOW(shape_once, shape_group_num)) {
      MS_LOG(ERROR) << "int mul overflow.";
      *convert_succ = false;
      return;
    }
    tensor_shape = {element_num};
    auto shape_info_size = shape_once * shape_group_num;
    if (INT_ADD_OVERFLOW(shape_info_size, kFirstElementShapeIndex)) {
      MS_LOG(ERROR) << "int add overflow.";
      *convert_succ = false;
      return;
    }
    int real_data_size = shape_info_size + kFirstElementShapeIndex;
    if (real_data_size <= 0 || static_cast<uint32_t>(real_data_size) != basic_data_size) {
      MS_LOG(ERROR) << "current tensorlist data length should be " << real_data_size << ", but now is "
                    << basic_data_size;
      *convert_succ = false;
      return;
    }
    for (int j = 0; j < data[kElementShapeIndex]; ++j) {
      element_shape.push_back(data[j + kFirstElementShapeIndex]);
    }
  }
  lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensorlist is nullptr";
    *convert_succ = false;
    return;
  }

  auto lite_tensor_list = reinterpret_cast<TensorList *>(lite_tensor.get());
  std::vector<Tensor *> tensors{};
  if (!tensor_shape.empty() && tensor_shape.front() != -1) {
    for (int32_t i = 0; i < tensor_shape.front(); ++i) {
      auto tensor = new (std::nothrow) Tensor(type, element_shape);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << "new Tensor failed";
        *convert_succ = false;
        FreeTensors(&tensors, nullptr);
        return;
      }
      tensors.emplace_back(tensor);
    }
  }

  lite_tensor_list->set_tensors_data_type(type);
  lite_tensor_list->set_element_shape(element_shape);
  lite_tensor_list->set_tensors(tensors);
  lite_tensors->emplace_back(lite_tensor.release());
}

namespace {
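// Wrap a TensorT description (type, dims, format, category) into a runtime Tensor without
// copying any data.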
std::unique_ptr<Tensor> CreateRuntimeTensor(const std::unique_ptr<TensorT> &src_tensor) {
  if (src_tensor == nullptr) {
    MS_LOG(ERROR) << "src tensor is nullptr";
    return nullptr;
  }
  std::unique_ptr<Tensor> runtime_tensor = nullptr;
  auto tensor_shape = src_tensor->dims;
  runtime_tensor = std::make_unique<Tensor>(TypeId(src_tensor->dataType), tensor_shape,
                                            static_cast<mindspore::Format>(src_tensor->format),
                                            TensorCategory(src_tensor->nodeType, src_tensor->dims.size(),
                                                           TypeId(src_tensor->dataType), src_tensor->data.size()));
  if (runtime_tensor == nullptr) {
    MS_LOG(ERROR) << "Create runtime tensor failed";
    return nullptr;
  }
  return runtime_tensor;
}
}  // namespace

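// Build a runtime tensor for a string tensor; constant string data is re-packed through the
// string-tensor helpers instead of being shared by pointer.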
void ConvertString(const MetaGraphT *graph, uint32_t index, bool *convert_succ, std::vector<Tensor *> *lite_tensors) {
  auto &tensorT = graph->allTensors.at(index);
  auto runtime_tensor = CreateRuntimeTensor(tensorT);
  if (runtime_tensor == nullptr) {
    *convert_succ = false;
    return;
  }
  // when tensorT is a param input, it carries no data
  if (tensorT->data.empty()) {
    lite_tensors->emplace_back(runtime_tensor.release());
    return;
  }
  auto string_buffer = ParseStringBuffer(tensorT->data.data());
  auto ret = WriteStringsToTensor(runtime_tensor.get(), string_buffer);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "WriteStringsToTensor failed";
    *convert_succ = false;
    return;
  }
  lite_tensors->emplace_back(runtime_tensor.release());
}

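// Build a runtime tensor for an ordinary tensor; constant data is shared with the MetaGraphT
// buffer (not copied), which is why FreeTensors detaches the data pointer before deleting.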
void ConvertOtherTensor(const MetaGraphT *graph, uint32_t index, bool *convert_succ,
                        std::vector<Tensor *> *lite_tensors) {
  CHECK_NULL_RETURN_VOID(graph);
  auto &tensorT = graph->allTensors.at(index);
  auto runtime_tensor = CreateRuntimeTensor(tensorT);
  if (runtime_tensor == nullptr) {
    *convert_succ = false;
    return;
  }
  // when tensorT is a param input, it carries no data
  if (tensorT->data.empty()) {
    lite_tensors->emplace_back(runtime_tensor.release());
    return;
  }
  runtime_tensor->set_data(tensorT->data.data());
  lite_tensors->emplace_back(runtime_tensor.release());
}

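// Convert the MetaGraphT tensors at the given indices into runtime tensors so the kernel
// infer-shape implementations can run on them. Returns an empty vector on failure.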
std::vector<Tensor *> ConvertTensorToLiteTensor(const MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs) {
  MS_ASSERT(graph != nullptr);
  std::vector<Tensor *> lite_tensors;
  bool convert_succ = true;
  for (unsigned int tensor_index : tensor_indexs) {
    auto &tensorT = graph->allTensors.at(tensor_index);
    switch (tensorT->dataType) {
      case kObjectTypeTensorType:
        ConvertTensorList(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
      case kObjectTypeString:
        MS_CHECK_TRUE_MSG(tensorT->dims.size() <= 1, {}, "String type tensor dims should be less than or equal to 1.");
        ConvertString(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
      default:
        ConvertOtherTensor(graph, tensor_index, &convert_succ, &lite_tensors);
        break;
    }
  }
  if (!convert_succ) {
    FreeTensors(&lite_tensors, {});
    return {};
  }
  return lite_tensors;
}

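// Run shape inference for a single node: first try the registered infer function for the
// flatbuffer primitive; if that op is not supported there, fall back to populating an
// OpParameter and running the nnacl infer path.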
STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
                      std::vector<Tensor *> *outputs) {
  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "get primitive failed.";
    fbb.Clear();
    return RET_ERROR;
  }

  auto ret = KernelInferShape(inputs, *outputs, prim, {}, static_cast<int>(SCHEMA_CUR));
  if (ret == lite::RET_NOT_SUPPORT) {
    auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
      static_cast<int>(prim->value_type()), static_cast<int>(SCHEMA_CUR));
    if (parameter_gen == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
      return RET_ERROR;
    }
    auto parameter = parameter_gen(prim);
    if (parameter == nullptr) {
      fbb.Clear();
      MS_LOG(ERROR) << "parameter is nullptr.";
      return RET_ERROR;
    }
    parameter->quant_type_ = static_cast<int>(node->quantType);
    ret = KernelInferShape(inputs, *outputs, parameter);
    if (parameter->destroy_func_ != nullptr) {
      parameter->destroy_func_(parameter);
    }
    free(parameter);
    parameter = nullptr;
  }

  fbb.Clear();
  return ret;
}

#ifdef Debug
void PrintTensorShape(const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors) {
  int i = 0;
  for (auto input_tensor : input_tensors) {
    std::ostringstream oss;
    for (auto &dim : input_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "input shape " << i++ << ":" << oss.str();
  }
  i = 0;
  for (auto output_tensor : output_tensors) {
    std::ostringstream oss;
    for (auto &dim : output_tensor->shape()) {
      oss << " " << dim;
    }
    MS_LOG(DEBUG) << "output shape " << i++ << ":" << oss.str();
  }
}
#endif

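// Write the inferred data type (and, for tensorlists, the re-serialized tensorlist metadata in
// the same layout ConvertTensorList parses) of the i-th output back into the MetaGraphT tensor,
// and mark whether inference of that tensor succeeded.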
int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
                const std::unique_ptr<mindspore::schema::CNodeT> &node, std::vector<InferTensor> *tensors, size_t i) {
  auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
  output_tensor->format = static_cast<schema::Format>(output_tensors[i]->format());
  output_tensor->dataType = output_tensors[i]->data_type();
  if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
    auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
    MSLITE_CHECK_PTR(tensor_list);
    int tensor_shape_dims = 0;
    if (!tensor_list->tensors().empty()) {
      tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
    }
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDataSize), static_cast<int>(sizeof(int))),
                       RET_ERROR, "int mul overflow");
    if (tensor_list->tensors_data_type() == kTypeUnknown) {
      if (!tensor_list->tensors().empty()) {
        tensor_list->set_tensors_data_type(tensor_list->tensors().front()->data_type());
      }
    }
    std::vector<int> basic_data;
    basic_data.push_back(tensor_list->tensors_data_type());
    if (tensor_list->element_shape().empty() && !tensor_list->tensors().empty()) {
      tensor_list->set_element_shape(tensor_list->tensors().front()->shape());
    }
    basic_data.push_back(tensor_list->element_shape().size());
    for (size_t j = 0; j < tensor_list->element_shape().size(); ++j) {
      basic_data.push_back(tensor_list->element_shape().at(j));
    }
    basic_data.push_back(tensor_list->tensors().size());
    for (size_t index = 0; index < tensor_list->tensors().size(); ++index) {
      auto tensor_shape = tensor_list->GetTensor(static_cast<int>(index))->shape();
      basic_data.push_back(tensor_shape.size());
      for (size_t j = 0; j < tensor_shape.size(); ++j) {
        basic_data.push_back(tensor_shape[j]);
      }
    }
    output_tensor->data.resize(basic_data.size() * sizeof(int));
    if (memcpy_s(output_tensor->data.data(), output_tensor->data.size(), basic_data.data(),
                 basic_data.size() * sizeof(int)) != EOK) {
      MS_LOG(ERROR) << "memcpy data failed.";
      return RET_ERROR;
    }
  } else if (output_tensors[i]->data_type() == kTypeUnknown) {
    tensors->at(node->outputIndex[i]).is_inferred_ = false;
    return RET_OK;
  }
  tensors->at(node->outputIndex[i]).is_inferred_ = true;
  return RET_OK;
}

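// Copy the inferred shapes and data types of all runtime output tensors back into the MetaGraphT.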
int CopyOutputInfoToTensorT(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
                            const std::unique_ptr<mindspore::schema::CNodeT> &node, std::vector<InferTensor> *tensors) {
  for (uint32_t i = 0; i < output_tensors.size(); i++) {
    auto output_dims = output_tensors[i]->shape();
    auto &output_tensorT = graph->allTensors.at(node->outputIndex[i]);
    MSLITE_CHECK_PTR(output_tensorT);
    output_tensorT->dims.swap(output_dims);
    if (SetDataType(graph, output_tensors, node, tensors, i) != RET_OK) {
      MS_LOG(ERROR) << "SetDataType failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int64_t PartialGraphIndex(const CNodeT *partial_node) {
  MSLITE_CHECK_PTR(partial_node);
  return partial_node->primitive->value.AsPartialFusion()->sub_graph_index;
}
}  // namespace

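// Propagate the shapes, types and constant data of a partial node's actual arguments to the
// formal input tensors of the subgraph it points to.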
int InferShapePass::CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);

  if (subgraph->inputIndices.size() != partial_node->inputIndex.size()) {
    MS_LOG(ERROR) << "partial node " << partial_node->name << " inputs size: " << partial_node->inputIndex.size()
                  << " vs subgraph " << subgraph_index << " input size: " << subgraph->inputIndices.size();
    return RET_PARAM_INVALID;
  }

  for (size_t i = 0; i < partial_node->inputIndex.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    MSLITE_CHECK_PTR(subgraph_input);
    auto &partial_input = graph->allTensors.at(partial_node->inputIndex[i]);
    MSLITE_CHECK_PTR(partial_input);
    subgraph_input->dataType = partial_input->dataType;
    subgraph_input->dims = partial_input->dims;
    subgraph_input->format = partial_input->format;
    subgraph_input->data.resize(partial_input->data.size(), 0);
    if (partial_input->data.empty()) {
      continue;
    }
    auto ret = memcpy_s(subgraph_input->data.data(), subgraph_input->data.size(), partial_input->data.data(),
                        partial_input->data.size());
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy failed, ret: " << ret;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

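// Drop the constant data copied into the subgraph inputs by CopyPartialShapeToSubGraph, so the
// subgraph inputs go back to being plain (non-const) placeholders after inference.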
void InferShapePass::RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (size_t i = 0; i < subgraph->inputIndices.size(); ++i) {
    auto &subgraph_input = graph->allTensors.at(subgraph->inputIndices[i]);
    if (subgraph_input->dataType != kObjectTypeTensorType) {
      subgraph_input->data = {};
    }
  }
}

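// For a non-tail call, the call node has its own output tensors; copy the inferred subgraph
// output format/shape/type onto them.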
int InferShapePass::SetNonTailCallOutputShape(const std::unique_ptr<CNodeT> &call_node, const CNodeT *partial_node,
                                              MetaGraphT *graph) {
  auto subgraph_index = PartialGraphIndex(partial_node);
  auto &subgraph = graph->subGraph.at(subgraph_index);
  size_t call_node_output_size = call_node->outputIndex.size();
  size_t subgraph_output_size = subgraph->outputIndices.size();
  if (subgraph_output_size != call_node_output_size) {
    MS_LOG(ERROR) << "call node output size: " << call_node_output_size
                  << " is not equal to corresponding subgraph output size: " << subgraph_output_size;
    return RET_ERROR;
  }
  for (size_t i = 0; i < subgraph_output_size; ++i) {
    auto &subgraph_output_tensor = graph->allTensors.at(subgraph->outputIndices[i]);
    auto &call_output_tensor = graph->allTensors.at(call_node->outputIndex[i]);
    call_output_tensor->format = subgraph_output_tensor->format;
    call_output_tensor->dims = subgraph_output_tensor->dims;
    call_output_tensor->dataType = subgraph_output_tensor->dataType;
  }
  return RET_OK;
}

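// Infer a partial node by copying its argument shapes into the target subgraph, inferring that
// subgraph, then restoring the subgraph inputs; for non-tail calls the call node outputs are
// also updated from the subgraph outputs.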
int InferShapePass::InferPartialNode(const bool &is_tail_call, const std::unique_ptr<CNodeT> &call_node,
                                     const CNodeT *partial_node, MetaGraphT *graph) {
  int64_t subgraph_index = PartialGraphIndex(partial_node);
  int ret = CopyPartialShapeToSubGraph(partial_node, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CopyPartialShapeToSubGraph failed, ret: " << ret;
    return ret;
  }

  ret = InferSubgraph(subgraph_index, graph);
  if (ret != RET_OK) {
    // do not return here, so the rest of the graph can still be inferred
    MS_LOG(WARNING) << "InferSubgraph index: " << subgraph_index << " failed, ret: " << ret;
  }

  RestoreSubGraphInput(partial_node, graph);

  if (!is_tail_call) {
    ret = SetNonTailCallOutputShape(call_node, partial_node, graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "SetNonTailCallOutputShape failed.";
      return ret;
    }
  }
  return ret;
}

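// Build the tensor-to-node adjacency used by the worklist: for every tensor, record which nodes
// consume it (next_nodes_) and which produce it (prev_nodes_), and normalize 0-dims of graph
// inputs to -1 (unknown).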
void InferShapePass::InitInferTensor(MetaGraphT *graph) {
  tensors_.resize(graph->allTensors.size());
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    auto &node = graph->nodes.at(i);
    auto node_input_indexes = node->inputIndex;
    // init in_nodes index
    for (unsigned int node_input_index : node_input_indexes) {
      tensors_[node_input_index].next_nodes_.push_back(i);
    }
    auto node_output_indexes = node->outputIndex;
    for (unsigned int node_output_index : node_output_indexes) {
      tensors_[node_output_index].prev_nodes_.push_back(i);
    }
  }

  for (auto input_idx : graph->inputIndex) {
    auto input_tensor = graph->allTensors[input_idx].get();
    CHECK_NULL_RETURN_VOID(input_tensor);
    for (auto &dim : input_tensor->dims) {
      if (dim == 0) {
        MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
        dim = DEFAULT_DIM_VALUE;
      }
    }
  }
}

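// A switch/switch_layer node takes a condition (or branch index) as its first input and partial
// nodes as the remaining inputs; collect those partial nodes and infer each of them once.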
int InferShapePass::InferSwitchOrSwitchLayerNode(const bool &is_tail_call, const std::unique_ptr<CNodeT> &call_node,
                                                 const std::unique_ptr<CNodeT> &aim_node, MetaGraphT *graph) {
  if (aim_node->inputIndex.size() < kSwitchInputMinSize) {
    MS_LOG(ERROR) << "switch or switch_layer node input size: " << aim_node->inputIndex.size() << " is less than 3.";
    return RET_PARAM_INVALID;
  }

  size_t aim_node_input_size = aim_node->inputIndex.size();
  std::vector<uint32_t> all_partial_index{};
  for (size_t i = 1; i < aim_node_input_size; ++i) {
    all_partial_index.push_back(aim_node->inputIndex.at(i));
  }

  std::vector<CNodeT *> all_partial_nodes{};
  for (auto &partial_index : all_partial_index) {
    for (auto &node : graph->nodes) {
      MSLITE_CHECK_PTR(node);
      if (node->primitive->value.type != PrimitiveType_PartialFusion) {
        continue;
      }
      if (IsContain(node->outputIndex, partial_index)) {
        all_partial_nodes.push_back(node.get());
        break;
      }
    }
  }

  std::deque<CNodeT *> to_process{};
  for (auto &partial_node : all_partial_nodes) {
    if (partial_cnode_inferred_.find(partial_node) == partial_cnode_inferred_.end()) {
      to_process.push_back(partial_node);
      (void)partial_cnode_inferred_.insert(partial_node);
    }
  }

  while (!to_process.empty()) {
    auto node = to_process.front();
    to_process.pop_front();
    int ret = InferPartialNode(is_tail_call, call_node, node, graph);
    if (ret != RET_OK) {
      MS_LOG(WARNING) << "not support partial infer.";
      return ret;
    }
  }

  return RET_OK;
}

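// A call node's first input is produced by either a partial node or a switch/switch_layer node;
// dispatch inference accordingly.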
int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph) {
  MSLITE_CHECK_PTR(call_node);
  if (call_node->inputIndex.size() < kCallInputMinSize) {
    MS_LOG(ERROR) << "call node input size: " << call_node->inputIndex.size() << " is less than one.";
    return RET_PARAM_INVALID;
  }
  auto call_first_input_index = call_node->inputIndex.front();
  bool is_tail_call = call_node->primitive->value.AsCall()->is_tail_call;
  for (auto &node : graph->nodes) {
    if (!IsContain(node->outputIndex, call_first_input_index)) {
      continue;
    }
    switch (node->primitive->value.type) {
      case PrimitiveType_PartialFusion:
        return InferPartialNode(is_tail_call, call_node, node.get(), graph);
      case PrimitiveType_Switch:
      case PrimitiveType_SwitchLayer:
        return InferSwitchOrSwitchLayerNode(is_tail_call, call_node, node, graph);
      default:
        MS_LOG(ERROR) << "call node's first input must come from a partial, switch or switch_layer node.";
        return RET_ERROR;
    }
  }
  return RET_OK;
}

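// Worklist-based shape inference over one subgraph: start from nodes whose inputs are all known,
// infer each node on temporary runtime tensors, write the results back into the MetaGraphT, and
// enqueue downstream nodes that become ready.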
int InferShapePass::InferSubgraph(const int64_t &subgraph_index, MetaGraphT *graph) {
  std::vector<uint32_t> infer_node_indexes{};
  int ret = InitSearchTensor(subgraph_index, graph, &infer_node_indexes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InitSearchTensor failed.";
    return ret;
  }
  if (infer_node_indexes.empty()) {
    MS_LOG(DEBUG) << "no need to infer.";
    return RET_OK;
  }

  while (!infer_node_indexes.empty()) {
    auto infer_node_index = infer_node_indexes.front();
    auto &node = graph->nodes.at(infer_node_index);
    MSLITE_CHECK_PTR(node);
    infer_node_indexes.erase(infer_node_indexes.begin());
    if (node->primitive == nullptr) {
      MS_LOG(WARNING) << "node primitive is nullptr!";
      continue;
    }
    auto node_type = node->primitive->value.type;
    if (node_type == PrimitiveType_Call) {
      ret = InferCallNode(node, graph);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "infer call node failed.";
        return ret;
      }
    }

    auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex);
    auto output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex);
    if (output_tensors.empty() || output_tensors.size() != node->outputIndex.size() || input_tensors.empty() ||
        input_tensors.size() != node->inputIndex.size()) {
      MS_LOG(ERROR) << "convert lite tensor error";
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    auto status = NodeInferShape(node, input_tensors, &output_tensors);
    MS_LOG(DEBUG) << "cur node:" << node->name;
    if (status == RET_OK || status == RET_INFER_INVALID) {
#ifdef Debug
      PrintTensorShape(input_tensors, output_tensors);
#endif
      ret = CopyOutputInfoToTensorT(graph, output_tensors, node, &tensors_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CopyOutputInfoToTensorT failed: " << ret;
        FreeTensors(&input_tensors, &output_tensors);
        return RET_INFER_ERR;
      }
    } else {
      MS_LOG(WARNING) << "InferShape failed, name: " << node->name
                      << ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
      FreeTensors(&input_tensors, &output_tensors);
      return RET_INFER_ERR;
    }
    FreeTensors(&input_tensors, &output_tensors);
    AddOutputNodes(graph, &infer_node_indexes, infer_node_index);
  }
  return RET_OK;
}

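// Entry point of the pass: infer the main graph (nested subgraphs are reached through call /
// partial nodes), then clear placeholder shapes that could not be inferred.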
STATUS InferShapePass::Run(MetaGraphT *graph) {
  CHECK_NULL_RETURN(graph);
  InitInferTensor(graph);

  int ret = InferSubgraph(kMainGraphIndex, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InferSubgraph index: " << kMainGraphIndex << " failed, ret: " << ret;
    return ret;
  }

  ResetIncorrectTensorShape(graph);
  return RET_OK;
}

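// Mark tensors that are already known (subgraph inputs, const tensors with data, and empty
// placeholder value nodes) as inferred, and seed the worklist with the nodes whose inputs are
// all inferred.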
int InferShapePass::InitSearchTensor(const int64_t &subgraph_index, MetaGraphT *graph,
                                     std::vector<uint32_t> *infer_node_indexes) {
  if (static_cast<size_t>(subgraph_index) >= graph->subGraph.size()) {
    MS_LOG(ERROR) << "subgraph_index: " << subgraph_index
                  << " is out of range, graph->subGraph.size(): " << graph->subGraph.size();
    return RET_ERROR;
  }
  auto &subgraph = graph->subGraph.at(subgraph_index);
  for (uint32_t i = 0; i < tensors_.size(); i++) {
    if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty() ||
        (graph->allTensors.at(i)->nodeType == NodeType_ValueNode && graph->allTensors.at(i)->dims.size() == 1 &&
         graph->allTensors.at(i)->dims[0] == 0)) {
      tensors_[i].is_inferred_ = true;
    }
  }
  for (size_t i = 0; i < subgraph->nodeIndices.size(); i++) {
    auto &node = graph->nodes.at(subgraph->nodeIndices.at(i));
    if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                    [&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
      infer_node_indexes->push_back(subgraph->nodeIndices.at(i));
    }
  }
  return RET_OK;
}

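// After a node is inferred, look at the consumers of its outputs and enqueue any consumer that
// still has uninferred outputs and whose inputs are now all inferred.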
void InferShapePass::AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                    uint32_t infer_node_index) {
  auto &node = graph->nodes.at(infer_node_index);
  for (size_t i = 0; i < node->outputIndex.size(); i++) {
    auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
    for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
      auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
      if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
                      [&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
        AddNextInferShapeNode(graph, infer_node_indexes, next_nodes_indexes, j);
      }
    }
  }
}

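// Enqueue the candidate node if it is not already queued and all of its inputs are inferred.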
void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes,
                                           std::vector<uint32_t> next_nodes_indexes, size_t index) {
  auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
  if (std::find(infer_node_indexes->begin(), infer_node_indexes->end(), next_nodes_indexes[index]) ==
      infer_node_indexes->end()) {
    if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
                    [&](uint32_t i) { return tensors_[i].is_inferred_; })) {
      infer_node_indexes->push_back(next_nodes_indexes[index]);
    }
  }
}

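// Shapes left as {-1} are placeholders for dimensions that could not be inferred; reset them to
// empty so later passes treat the shape as unknown rather than rank-1.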
void InferShapePass::ResetIncorrectTensorShape(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    auto out_tensors_index = node->outputIndex;
    for (auto index : out_tensors_index) {
      auto &tensor = graph->allTensors.at(index);
      auto shape = tensor->dims;
      if (shape == std::vector{-1}) {
        tensor->dims = {};
      }
    }
  }
}
}  // namespace lite
}  // namespace mindspore