• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/graph/node_infershape.h"
18 #include <memory>
19 #include <vector>
20 #include "tools/common/node_util.h"
21 #include "tools/common/tensor_util.h"
22 #include "src/common/utils.h"
23 #include "src/ops/populate/populate_register.h"
24 #include "src/ops/ops_utils.h"
25 #include "src/runtime/infer_manager.h"
26 #include "src/tensorlist.h"
27 #include "src/registry/kernel_interface_registry.h"
28 #include "nnacl/op_base.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace {
33 constexpr int kInputChannal = 3;
34 constexpr size_t INITIAL_SIZE = 1024;
FreeTensors(std::vector<lite::Tensor * > * tensors)35 void FreeTensors(std::vector<lite::Tensor *> *tensors) {
36   if (tensors == nullptr) {
37     return;
38   }
39   for (auto &v : *tensors) {
40     delete v;
41     v = nullptr;
42   }
43   tensors->resize(0);
44 }
45 
RectifyFormat(const std::vector<lite::Tensor * > & inputs,FmkType fmk_type)46 void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
47   MS_ASSERT(cnode != nullptr);
48   if (fmk_type != converter::kFmkTypeOnnx) {
49     return;
50   }
51   for (auto &input : inputs) {
52     auto shape = input->shape();
53     if (shape.size() == kInputSizeFour && shape[kInputIndexThree] == kInputChannal && shape[1] == -1) {
54       input->set_format(mindspore::NHWC);
55     }
56   }
57 }
58 
NewTensorInfo(lite::Tensor * tensor)59 tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
60   std::vector<int> shape(tensor->shape());
61   std::vector<int64_t> shape_vector(shape.begin(), shape.end());
62   auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
63   if (tensor_info == nullptr) {
64     MS_LOG(ERROR) << "new tensor::Tensor failed";
65     return nullptr;
66   }
67   return tensor_info;
68 }
69 }  // namespace
70 
JudgeOpSupportInfer(const CNodePtr & cnode)71 bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
72   MS_ASSERT(cnode != nullptr);
73   if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
74     return true;
75   }
76   auto prim_t = lite::GetPrimitiveT(cnode->input(0));
77   if (prim_t == nullptr) {
78     return false;
79   }
80   auto parameter_gen =
81     lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim_t->value.type), lite::SCHEMA_CUR);
82   if (parameter_gen == nullptr) {
83     prim_t.reset();
84     return false;
85   }
86   return true;
87 }
88 
InferShape(const CNodePtr & cnode)89 STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
90   MS_ASSERT(cnode != nullptr);
91   auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
92   if (anf_prim == nullptr) {
93     MS_LOG(DEBUG) << "primitive is nullptr";
94     return lite::RET_ERROR;
95   }
96   anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
97   std::vector<lite::Tensor *> inputs;
98   std::vector<lite::Tensor *> outputs;
99   if (GetCNodeInputTensors(cnode, &inputs) != lite::RET_OK) {
100     FreeTensors(&inputs);
101     MS_LOG(ERROR) << "get inputs failed.";
102     return lite::RET_ERROR;
103   }
104   if (GetCNodeOutputTensors(cnode, &outputs) != lite::RET_OK) {
105     FreeTensors(&inputs);
106     FreeTensors(&outputs);
107     MS_LOG(ERROR) << "get outputs failed.";
108     return lite::RET_ERROR;
109   }
110   auto prim_t = lite::GetPrimitiveT(cnode->input(0));
111   if (prim_t == nullptr) {
112     MS_LOG(DEBUG) << "prim_t is nullptr";
113     FreeTensors(&inputs);
114     FreeTensors(&outputs);
115     return lite::RET_ERROR;
116   }
117   flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
118   auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
119   if (prim == nullptr) {
120     MS_LOG(ERROR) << "get primitive failed.";
121     FreeTensors(&inputs);
122     FreeTensors(&outputs);
123     fbb.Clear();
124     return lite::RET_ERROR;
125   }
126   auto ret = KernelInferShape(inputs, outputs, prim, {}, lite::SCHEMA_CUR);
127   if (ret == lite::RET_NOT_SUPPORT) {
128     auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
129       static_cast<int>(prim->value_type()), lite::SCHEMA_CUR);
130     if (parameter_gen == nullptr) {
131       MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
132       FreeTensors(&inputs);
133       FreeTensors(&outputs);
134       fbb.Clear();
135       return lite::RET_ERROR;
136     }
137     auto parameter = parameter_gen(prim);
138     if (parameter == nullptr) {
139       MS_LOG(ERROR) << "parameter is nullptr.";
140       FreeTensors(&inputs);
141       FreeTensors(&outputs);
142       fbb.Clear();
143       return lite::RET_ERROR;
144     }
145     RectifyFormat(inputs, fmk_type_);
146     ret = KernelInferShape(inputs, outputs, parameter);
147     if (parameter->destroy_func_ != nullptr) {
148       parameter->destroy_func_(parameter);
149     }
150     free(parameter);
151     parameter = nullptr;
152   }
153   fbb.Clear();
154   if (ret == lite::RET_OK) {
155     anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
156   }
157   if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
158     auto set_status = SetCNodeAbstract(cnode, outputs, ret);
159     auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
160     MS_CHECK_TRUE_MSG(cnode_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
161     cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
162     if (set_status != lite::RET_OK) {
163       MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
164       FreeTensors(&inputs);
165       FreeTensors(&outputs);
166       return set_status;
167     }
168   } else {
169     MS_LOG(ERROR) << "infer shape failed.";
170   }
171   FreeTensors(&inputs);
172   FreeTensors(&outputs);
173   return ret;
174 }
175 
GetInputShape(const CNodePtr & cnode,size_t index)176 std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {
177   MS_ASSERT(cnode != nullptr);
178   if (index >= cnode->size()) {
179     return {};
180   }
181   lite::DataInfo data_info;
182   int status = lite::RET_OK;
183   CNodePtr base_node = cnode;
184   size_t position = index;
185   if (CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTuple) ||
186       CheckPrimitiveType(cnode->input(index), kPrimMakeTupleV2)) {
187     base_node = cnode->input(index)->cast<CNodePtr>();
188     position = 1;
189   }
190   if (utils::isa<CNode>(base_node->input(position))) {
191     status = lite::FetchDataFromCNode(base_node, position, fmk_type_, train_flag_, &data_info);
192   } else if (utils::isa<Parameter>(base_node->input(position))) {
193     status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info);
194   } else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
195     status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info);
196   } else {
197     MS_LOG(ERROR) << "input node is invalid.";
198     return {};
199   }
200   if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
201     MS_LOG(ERROR) << "fetch data failed.";
202     return {};
203   }
204   return data_info.shape_;
205 }
206 
GetIntVecInput(const CNodePtr & cnode,size_t index)207 std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) {
208   MS_ASSERT(cnode != nullptr);
209   if (index >= cnode->size()) {
210     return {};
211   }
212   auto origin_inputs = cnode->inputs();
213   std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
214   cnode->set_inputs(specify_inputs);
215   std::vector<lite::Tensor *> specify_tensors;
216   if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) {
217     cnode->set_inputs(origin_inputs);
218     return {};
219   }
220   cnode->set_inputs(origin_inputs);
221   std::vector<int> tensor_data;
222   if (specify_tensors.front()->data_type() != kNumberTypeInt32 &&
223       specify_tensors.front()->data_type() != kNumberTypeInt) {
224     FreeTensors(&specify_tensors);
225     return {};
226   }
227   if (specify_tensors.front()->shape().size() != 1) {
228     FreeTensors(&specify_tensors);
229     return {};
230   }
231   MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
232   tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
233   if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
234                tensor_data.size() * sizeof(int)) != EOK) {
235     FreeTensors(&specify_tensors);
236     return {};
237   }
238   return tensor_data;
239 }
240 
GetCNodeInputTensors(const CNodePtr & cnode,std::vector<lite::Tensor * > * inputs)241 STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs) {
242   MS_ASSERT(cnode != nullptr);
243   MS_ASSERT(inputs != nullptr);
244   auto origin_inputs = cnode->inputs();
245   lite::RemoveIfDepend(cnode);
246   lite::RemoveIfMakeTuple(cnode);
247   RemoveIfMonad(cnode);
248   std::vector<lite::Tensor *> const_inputs;
249   if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) {
250     MS_LOG(ERROR) << "get const inputs failed.";
251     FreeTensors(&const_inputs);
252     cnode->set_inputs(origin_inputs);
253     return lite::RET_ERROR;
254   }
255   std::vector<lite::Tensor *> var_inputs;
256   if (GetCNodeVarInput(cnode, &var_inputs) != lite::RET_OK) {
257     MS_LOG(ERROR) << "get var inputs failed.";
258     FreeTensors(&var_inputs);
259     cnode->set_inputs(origin_inputs);
260     return lite::RET_ERROR;
261   }
262   size_t const_index = 0;
263   size_t var_index = 0;
264   bool input_valid = true;
265   for (size_t i = 1; i < cnode->size(); ++i) {
266     if (utils::isa<CNodePtr>(cnode->input(i))) {
267       if (var_index >= var_inputs.size()) {
268         MS_LOG(ERROR) << "var inputs size invalid.";
269         input_valid = false;
270         break;
271       }
272       inputs->emplace_back(var_inputs[var_index++]);
273     } else {
274       if (const_index >= const_inputs.size()) {
275         MS_LOG(ERROR) << "const inputs size invalid.";
276         input_valid = false;
277         break;
278       }
279       inputs->emplace_back(const_inputs[const_index++]);
280     }
281   }
282   cnode->set_inputs(origin_inputs);
283   if (!input_valid) {
284     FreeTensors(&const_inputs);
285     FreeTensors(&var_inputs);
286     inputs->resize(0);
287   }
288   return lite::RET_OK;
289 }
290 
GetCNodeConstInput(const CNodePtr & cnode,std::vector<lite::Tensor * > * const_ms_inputs)291 STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) {
292   MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr);
293   std::vector<lite::DataInfo> data_infos;
294   for (size_t i = 1; i < cnode->size(); ++i) {
295     if (utils::isa<CNodePtr>(cnode->input(i))) {
296       continue;
297     }
298     STATUS status;
299     lite::DataInfo data_info;
300     if (utils::isa<ParameterPtr>(cnode->input(i))) {
301       status = lite::FetchDataFromParameterNode(cnode, i, fmk_type_, train_flag_, &data_info);
302     } else {
303       status = lite::FetchDataFromValueNode(cnode, i, fmk_type_, train_flag_, &data_info);
304     }
305     if (status == lite::RET_NO_CHANGE) {
306       continue;
307     }
308     if (status != lite::RET_OK) {
309       MS_LOG(ERROR) << "fetch const input data failed.";
310       return status;
311     }
312     data_infos.emplace_back(data_info);
313   }
314   return ConvertToLiteTensor(data_infos, const_ms_inputs);
315 }
316 
GetCNodeVarInput(const CNodePtr & cnode,std::vector<lite::Tensor * > * var_ms_inputs)317 STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) {
318   MS_ASSERT(cnode != nullptr);
319   MS_ASSERT(var_ms_inputs != nullptr);
320   for (size_t i = 1; i < cnode->size(); ++i) {
321     if (!utils::isa<CNodePtr>(cnode->input(i))) {
322       continue;
323     }
324     lite::DataInfo data_info;
325     if (lite::FetchDataFromCNode(cnode, i, fmk_type_, train_flag_, &data_info) != lite::RET_OK) {
326       MS_LOG(ERROR) << "parse cnode failed.";
327       return lite::RET_ERROR;
328     }
329     lite::Tensor *tensor = nullptr;
330     if (data_info.data_type_ == kObjectTypeTensorType) {
331       tensor = GetCNodeTensorListVarInput(data_info);
332     } else {
333       tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
334       tensor->set_format((Format)(data_info.format_));
335     }
336     if (tensor == nullptr) {
337       MS_LOG(ERROR) << "new a lite tensor failed";
338       return lite::RET_ERROR;
339     }
340     auto input_cnode = cnode->input(i)->cast<CNodePtr>();
341     MS_ASSERT(input_cnode != nullptr);
342     PrimitivePtr input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
343     if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
344       auto item_input_cnode = input_cnode->input(1)->cast<CNodePtr>();
345       MS_ASSERT(item_input_cnode != nullptr);
346       input_prim = GetValueNode<PrimitivePtr>(item_input_cnode->input(0));
347     }
348     MS_ASSERT(input_prim != nullptr);
349     if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
350       tensor->set_shape({-1});
351     }
352     var_ms_inputs->emplace_back(tensor);
353   }
354   return lite::RET_OK;
355 }
356 
// Allocates a lite::TensorList with shape data_info.shape_. If data_info carries encoded data, it is
// decoded into the list (the bytes are reinterpreted as an int array).
// Returns nullptr if allocation or decoding fails; the caller owns the returned list.
lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(const lite::DataInfo &data_info) {
  auto *list = new (std::nothrow) lite::TensorList(data_info.shape_, {});
  if (list == nullptr) {
    MS_LOG(ERROR) << "new a lite tensor list failed";
    return nullptr;
  }
  const bool has_payload = !data_info.data_.empty();
  if (has_payload && list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != lite::RET_OK) {
    delete list;
    MS_LOG(ERROR) << "decode tensor list failed.";
    return nullptr;
  }
  return list;
}
374 
GetCNodeOutputTensors(const CNodePtr & cnode,std::vector<lite::Tensor * > * outputs)375 STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) {
376   MS_ASSERT(cnode != nullptr);
377   MS_ASSERT(outputs != nullptr);
378   std::vector<lite::DataInfo> data_infos;
379   if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
380     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
381     if (tuple == nullptr) {
382       MS_LOG(ERROR) << "tuple is nullptr";
383       return lite::RET_ERROR;
384     }
385     auto elements = tuple->elements();
386     for (size_t i = 0; i < elements.size(); i++) {
387       lite::DataInfo data_info;
388       data_info.node_type_ = lite::NodeType_CNode;
389       if (train_flag_) {
390         data_infos.emplace_back(data_info);
391         if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || CheckPrimitiveType(cnode, prim::kPrimAdam)) {
392           break;
393         }
394       } else {
395         if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
396           MS_LOG(ERROR) << "abstract is not AbstractTensor";
397           return lite::RET_ERROR;
398         }
399         auto type = kNumberTypeFloat32;
400         if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
401           auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
402           auto typePtr = abstract_tensor->element()->GetTypeTrack();
403           type = typePtr->type_id();
404         }
405         data_info.data_type_ = type;
406         data_infos.emplace_back(data_info);
407         if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
408             CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) {
409           break;
410         }
411       }
412     }
413   } else {
414     lite::DataInfo data_info;
415     auto type = kNumberTypeFloat32;
416     if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
417       auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
418       auto typePtr = abstract_tensor->element()->GetTypeTrack();
419       type = typePtr->type_id();
420     }
421     data_info.data_type_ = type;
422     data_info.node_type_ = lite::NodeType_CNode;
423     data_infos.emplace_back(data_info);
424   }
425   return ConvertToLiteTensor(data_infos, outputs);
426 }
427 
ConvertToLiteTensor(const std::vector<lite::DataInfo> & data_infos,std::vector<lite::Tensor * > * tensors)428 STATUS NodeInferShape::ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos,
429                                            std::vector<lite::Tensor *> *tensors) {
430   MS_ASSERT(tensors != nullptr);
431   for (auto &data_info : data_infos) {
432     auto tensor_category = lite::TensorCategory(lite::NodeType(data_info.node_type_), data_info.shape_.size(),
433                                                 TypeId(data_info.data_type_), data_info.data_.size());
434     lite::Tensor *tensor = nullptr;
435     if (data_info.data_type_ != kObjectTypeTensorType) {
436       tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_,
437                                                (mindspore::Format)data_info.format_, tensor_category);
438     } else {
439       tensor = new (std::nothrow) lite::TensorList(data_info.shape_, std::vector<int>(), tensor_category);
440     }
441     if (tensor == nullptr) {
442       MS_LOG(ERROR) << "new a lite tensor failed";
443       return lite::RET_ERROR;
444     }
445     auto tensor_size = data_info.data_.size();
446     if (tensor_size > 0) {
447       if (data_info.data_type_ == kObjectTypeTensorType) {
448         auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
449         if (tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != RET_OK) {
450           MS_LOG(ERROR) << "Decode tensorlist data failed";
451           return RET_ERROR;
452         }
453       } else {
454         auto tensor_data = reinterpret_cast<char *>(malloc(tensor_size));
455         if (tensor_data == nullptr) {
456           MS_LOG(ERROR) << "tensor_data is nullptr";
457           delete tensor;
458           return lite::RET_ERROR;
459         }
460         if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) {
461           delete tensor;
462           free(tensor_data);
463           tensor_data = nullptr;
464           MS_LOG(ERROR) << "memcpy error: ";
465           return lite::RET_ERROR;
466         }
467         tensor->set_data(tensor_data);
468       }
469     }
470     tensors->emplace_back(tensor);
471   }
472   return lite::RET_OK;
473 }
474 
SetCNodeAbstract(const std::shared_ptr<CNode> & cnode,const std::vector<lite::Tensor * > & outputs,int status)475 STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs,
476                                         int status) {
477   MS_ASSERT(cnode != nullptr);
478   if (outputs.size() == 0) {
479     MS_LOG(ERROR) << "empty output_tensors";
480     return RET_ERROR;
481   }
482   auto origin_abstract = cnode->abstract();
483   MS_ASSERT(origin_abstract != nullptr);
484   if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) {
485     auto tensor = outputs.front();
486     auto new_abstract = ConvertLiteTensorToAbstract(tensor);
487     if (new_abstract == nullptr) {
488       MS_LOG(ERROR) << "new abstract failed.";
489       return RET_ERROR;
490     }
491     if (status == lite::RET_INFER_INVALID) {
492       ShapeVector shape;
493       if (tensor->data_type() == kObjectTypeTensorType) {
494         shape = {0};
495       }
496       auto abstract_shape = std::make_shared<abstract::Shape>(shape);
497       CHECK_NULL_RETURN(abstract_shape);
498       new_abstract->set_shape(abstract_shape);
499     }
500     cnode->set_abstract(new_abstract);
501   } else {
502     AbstractBasePtrList abstract_list;
503     for (size_t i = 0; i < outputs.size(); i++) {
504       auto tensor = outputs.at(i);
505       auto new_abstract = ConvertLiteTensorToAbstract(tensor);
506       if (new_abstract == nullptr) {
507         MS_LOG(ERROR) << "new abstract failed.";
508         return RET_ERROR;
509       }
510       if (status == lite::RET_INFER_INVALID) {
511         ShapeVector shape;
512         if (tensor->data_type() == kObjectTypeTensorType) {
513           shape = {0};
514         }
515         auto abstract_shape = std::make_shared<abstract::Shape>(shape);
516         CHECK_NULL_RETURN(abstract_shape);
517         new_abstract->set_shape(abstract_shape);
518       }
519       abstract_list.emplace_back(new_abstract);
520     }
521     auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
522     CHECK_NULL_RETURN(new_abstract_list);
523     cnode->set_abstract(new_abstract_list);
524   }
525   return RET_OK;
526 }
527 
ConvertLiteTensorToAbstract(lite::Tensor * tensor)528 abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
529   MS_ASSERT(tensor != nullptr);
530   if (tensor->data_type() == kObjectTypeTensorType) {
531     return ConvertTensorListToAbstract(tensor);
532   }
533   auto tensor_info = NewTensorInfo(tensor);
534   if (tensor_info == nullptr) {
535     MS_LOG(ERROR) << "new tensor::Tensor failed";
536     return nullptr;
537   }
538   return tensor_info->ToAbstract();
539 }
540 
541 // stract save tensorlist's type and shape. tensor_info save tensorlist's data and data type.
542 // both of them is different in term of shape and type.
ConvertTensorListToAbstract(lite::Tensor * tensor)543 abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) {
544   MS_ASSERT(tensor != nullptr);
545   auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
546   if (tensor_list == nullptr) {
547     MS_LOG(ERROR) << "cast tensor_list failed";
548     return nullptr;
549   }
550   std::vector<int> shape(tensor->shape());
551   std::vector<int64_t> shape_vector(shape.begin(), shape.end());
552   auto tensor_list_abstract =
553     std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector);
554   if (tensor_list_abstract == nullptr) {
555     MS_LOG(ERROR) << "new AbstractTensor failed";
556     return nullptr;
557   }
558   auto elememt_shape = tensor_list->element_shape();
559   std::vector<int> data_info;
560   data_info.push_back(tensor_list->tensors_data_type());
561   data_info.push_back(elememt_shape.size());
562   std::copy(elememt_shape.begin(), elememt_shape.end(), std::back_inserter(data_info));
563   data_info.push_back(tensor_list->tensors().size());
564   for (size_t i = 0; i < tensor_list->tensors().size(); ++i) {
565     auto tensor_mem = tensor_list->tensors()[i];
566     auto tensor_mem_shape = tensor_mem->shape();
567     data_info.push_back(tensor_mem_shape.size());
568     std::copy(tensor_mem_shape.begin(), tensor_mem_shape.end(), std::back_inserter(data_info));
569   }
570   std::vector<int64_t> data_shape;
571   data_shape.push_back(data_info.size());
572   auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32);
573   if (tensor_info == nullptr) {
574     MS_LOG(ERROR) << "new tensor::Tensor failed";
575     return nullptr;
576   }
577   tensor_list_abstract->set_value(tensor_info);
578   return tensor_list_abstract;
579 }
580 }  // namespace opt
581 }  // namespace mindspore
582