• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/node_infershape.h"
19 #include <memory>
20 #include <vector>
21 #include <algorithm>
22 #include <unordered_set>
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "mindspore/core/ops/comparison_ops.h"
27 #include "mindspore/core/ops/random_ops.h"
28 #include "src/common/primitive_t_utils.h"
29 #include "tools/common/node_util.h"
30 #include "tools/common/tensor_util.h"
31 #include "src/common/utils.h"
32 #include "src/common/ops/populate/populate_register.h"
33 #include "src/common/ops/anf_utils.h"
34 #include "src/litert/infer_manager.h"
35 #include "src/tensorlist.h"
36 #include "src/registry/kernel_interface_registry.h"
37 #include "tools/optimizer/graph/lite_tensor_extractor.h"
38 #include "nnacl/op_base.h"
39 #include "ops/op_utils.h"
40 #include "tools/optimizer/format/to_nchw_format.h"
41 #include "tools/optimizer/format/to_nhwc_format.h"
42 #include "tools/common/graph_util.h"
43 #include "src/common/common.h"
44 
45 namespace mindspore {
46 namespace opt {
47 static const std::unordered_set<PrimitivePtr> kNNACLToOpsInfer = {
48   // arithmetic_self
49   prim::kPrimAbs,
50   prim::kPrimAsin,
51   prim::kPrimAsinh,
52   prim::kPrimACos,
53   prim::kPrimAcosh,
54   prim::kPrimAtanh,
55   prim::kPrimCos,
56   prim::kPrimCosh,
57   prim::kPrimCeLU,
58   prim::kPrimSeLU,
59   prim::kPrimHSwish,
60   prim::kPrimMatrixDeterminant,
61   prim::kPrimLog,
62   prim::kPrimLog1p,
63   prim::kPrimSquare,
64   prim::kPrimSqrt,
65   prim::kPrimRsqrt,
66   prim::kPrimSin,
67   prim::kPrimSinh,
68   prim::kPrimFloor,
69   prim::kPrimCeil,
70   prim::kPrimRound,
71   prim::kPrimNeg,
72   prim::kPrimReciprocal,
73   prim::kPrimErf,
74   prim::kPrimSign,
75   prim::kPrimSoftsign,
76   prim::kPrimMultinomial,
77   // arithmetic
78   prim::kPrimFloorDiv,
79   prim::kPrimFloorMod,
80   prim::kPrimLogicalAnd,
81   prim::kPrimLogicalNot,
82   prim::kPrimLogicalOr,
83   prim::kPrimLogicalXor,
84   prim::kPrimMaximum,
85   prim::kPrimMinimum,
86   prim::kPrimMod,
87   prim::kPrimSquaredDifference,
88   prim::kPrimLeftShift,
89   prim::kPrimRightShift,
90   prim::kPrimROIAlign,
91   // tuple op
92   prim::kPrimTupleGetItem,
93   prim::kPrimMakeTuple,
94   prim::kPrimMakeTupleV2,
95   // nnacl/infer/common_infer.c
96   prim::kPrimClip,
97   prim::kPrimElu,
98   prim::kPrimLeakyRelu,
99   prim::kPrimLrn,
100   prim::kPrimOnesLike,
101   prim::kPrimReverseSequence,
102   prim::kPrimReverseV2,
103   prim::kPrimSmoothL1Loss,
104   prim::kPrimZerosLike,
105   // format op
106   prim::kPrimResize,
107   // compare op
108   prim::kPrimEqual,
109   prim::kPrimGreater,
110   prim::kPrimGreaterEqual,
111   prim::kPrimLess,
112   prim::kPrimLessEqual,
113   prim::kPrimNotEqual,
114 
115   prim::kPrimActivation,
116   prim::kPrimArgMaxFusion,
117   prim::kPrimArgMinFusion,
118   prim::kPrimGLU,
119   prim::kPrimGridSampler2D,
120   prim::kPrimGridSampler3D,
121   prim::kPrimDeformableConv2d,
122   // grad op
123   prim::kPrimActivationGrad,
124   prim::kPrimAbsGrad,
125   prim::kPrimBinaryCrossEntropyGrad,
126   prim::kPrimLogGrad,
127   prim::kPrimMaximumGrad,
128   prim::kPrimMinimumGrad,
129   prim::kPrimNegGrad,
130   prim::kPrimRsqrtGrad,
131   prim::kPrimSqrtGrad,
132   prim::kPrimSmoothL1LossGrad,
133   prim::kPrimGridSampler2D,
134 };
135 
136 namespace {
137 constexpr int kInputChannal = 3;
138 constexpr size_t INITIAL_SIZE = 1024;
RectifyFormat(const std::vector<lite::Tensor * > & inputs,FmkType fmk_type)139 void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
140   MS_ASSERT(cnode != nullptr);
141   if (fmk_type != converter::kFmkTypeOnnx) {
142     return;
143   }
144   for (auto &input : inputs) {
145     auto shape = input->shape();
146     if (shape.size() == kInputSizeFour && shape[kInputIndexThree] == kInputChannal && shape[1] == -1) {
147       input->set_format(mindspore::NHWC);
148     }
149   }
150 }
151 
NewTensorInfo(const lite::Tensor * tensor)152 tensor::TensorPtr NewTensorInfo(const lite::Tensor *tensor) {
153   std::vector<int> shape(tensor->shape());
154   std::vector<int64_t> shape_vector(shape.begin(), shape.end());
155   auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
156   if (tensor_info == nullptr) {
157     MS_LOG(ERROR) << "new tensor::Tensor failed";
158     return nullptr;
159   }
160   return tensor_info;
161 }
162 
ConvertAbstract(const AbstractBasePtr & src_abs,AbstractBasePtr * dst_abs,bool change,FormatTransNodeType perm)163 STATUS ConvertAbstract(const AbstractBasePtr &src_abs, AbstractBasePtr *dst_abs, bool change,
164                        FormatTransNodeType perm) {
165   if (SetAbstractTensorInfo(src_abs) != RET_OK) {
166     MS_LOG(ERROR) << "SetAbstractTensorInfo failed";
167     return lite::RET_ERROR;
168   }
169   *dst_abs = src_abs;
170   if (change) {
171     if (ConvertAbstractFormatShape(*dst_abs, perm) != RET_OK) {
172       MS_LOG(ERROR) << "ConvertAbstractFormatShape failed";
173       return lite::RET_ERROR;
174     }
175   }
176 
177   // change core/ops dynamic rank {-2} to Lite dynamic shape {-1}, will be removed after calling core/infer
178   ShapeVector shape;
179   if (opt::FetchShapeFromAbstract(*dst_abs, &shape) != RET_OK) {
180     MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
181     return RET_ERROR;
182   }
183   if (IsDynamicRank(shape)) {
184     auto nnacl_dynamic_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny});
185     (*dst_abs)->set_shape(nnacl_dynamic_shape);
186   }
187   return RET_OK;
188 }
189 }  // namespace
190 
JudgeOpSupportNNACLInfer(const CNodePtr & cnode)191 bool JudgeOpSupportNNACLInfer(const CNodePtr &cnode) {
192   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
193   if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
194     return true;
195   }
196   auto prim_t = lite::GetPrimitiveT(cnode->input(0));
197   if (prim_t == nullptr) {
198     return false;
199   }
200   auto parameter_gen =
201     lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim_t->value.type), lite::SCHEMA_CUR);
202   if (parameter_gen == nullptr) {
203     prim_t.reset();
204     return false;
205   }
206   return true;
207 }
208 
JudgeOpSupportOpsInfer(const CNodePtr & cnode)209 bool JudgeOpSupportOpsInfer(const CNodePtr &cnode) {
210   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
211   for (const auto &type : kNNACLToOpsInfer) {
212     if (CheckPrimitiveType(cnode, type)) {
213       return true;
214     }
215   }
216   return false;
217 }
218 
JudgeOpSupportInfer(const CNodePtr & cnode)219 bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
220   return JudgeOpSupportOpsInfer(cnode) || JudgeOpSupportNNACLInfer(cnode);
221 }
222 
InferShapeByNNACL(const CNodePtr & cnode)223 STATUS NodeInferShape::InferShapeByNNACL(const CNodePtr &cnode) {
224   MS_ASSERT(cnode != nullptr);
225   auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
226   if (anf_prim == nullptr) {
227     MS_LOG(DEBUG) << cnode->fullname_with_scope() << "'s cnode primitive is nullptr";
228     return lite::RET_ERROR;
229   }
230   (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
231   std::vector<TensorPtr> inputs_ptr;
232   if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, false) != lite::RET_OK) {
233     MS_LOG(ERROR) << cnode->fullname_with_scope() << " get inputs failed.";
234     return lite::RET_ERROR;
235   }
236   std::vector<TensorPtr> outputs_ptr;
237   if (LiteTensorExtractor::GetCNodeOutputTensors(cnode, &outputs_ptr, train_flag_) != lite::RET_OK) {
238     MS_LOG(ERROR) << cnode->fullname_with_scope() << " get outputs failed.";
239     return lite::RET_ERROR;
240   }
241   auto prim_t = lite::GetPrimitiveT(cnode->input(0));
242   if (prim_t == nullptr) {
243     MS_LOG(DEBUG) << cnode->fullname_with_scope() << " get lite prim_t is nullptr";
244     return lite::RET_ERROR;
245   }
246   flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
247   auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
248   if (prim == nullptr) {
249     MS_LOG(ERROR) << cnode->fullname_with_scope() << " get primitive failed.";
250     fbb.Clear();
251     return lite::RET_ERROR;
252   }
253   std::vector<lite::Tensor *> inputs;
254   (void)std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(inputs),
255                        [](const TensorPtr &input) { return input.get(); });
256   std::vector<lite::Tensor *> outputs;
257   (void)std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(outputs),
258                        [](const TensorPtr &output) { return output.get(); });
259   auto ret = KernelInferShape(inputs, outputs, prim, {}, lite::SCHEMA_CUR);
260   if (ret == lite::RET_NOT_SUPPORT) {
261     auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
262       static_cast<int>(prim->value_type()), lite::SCHEMA_CUR);
263     if (parameter_gen == nullptr) {
264       MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
265       fbb.Clear();
266       return lite::RET_ERROR;
267     }
268     auto parameter = parameter_gen(prim);
269     if (parameter == nullptr) {
270       MS_LOG(ERROR) << cnode->fullname_with_scope() << " generate nullptr lite op parameter.";
271       fbb.Clear();
272       return lite::RET_ERROR;
273     }
274     RectifyFormat(inputs, fmk_type_);
275     ret = KernelInferShape(inputs, outputs, parameter);
276     if (parameter->destroy_func_ != nullptr) {
277       parameter->destroy_func_(parameter);
278     }
279     free(parameter);
280     parameter = nullptr;
281   }
282   fbb.Clear();
283   if (ret == lite::RET_OK) {
284     (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
285   }
286   if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
287     auto set_status = SetCNodeAbstract(cnode, outputs, ret);
288     (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int64_t>(inputs[0]->format())));
289     if (set_status != lite::RET_OK) {
290       MS_LOG(ERROR) << cnode->fullname_with_scope() << " set CNode abstract failed.";
291       return set_status;
292     }
293   } else {
294     MS_LOG(WARNING) << "InferShapeByNNACL for op: " << cnode->fullname_with_scope() << " failed.";
295   }
296   std::vector<int64_t> outputs_format;
297   (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_format),
298                        [](const lite::Tensor *output) { return output->format(); });
299   (void)anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
300   return ret;
301 }
302 
InferShape(const CNodePtr & cnode)303 STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
304   MS_ASSERT(cnode != nullptr);
305   STATUS status;
306   if (JudgeOpSupportNNACLInfer(cnode)) {
307     status = InferShapeByNNACL(cnode);
308   } else {
309     MS_LOG(ERROR) << "Unsupported node: " << cnode->fullname_with_scope() << " for infershape.";
310     return RET_ERROR;
311   }
312   return status;
313 }
314 
OpsInferShape(const PrimitivePtr & anf_prim,const AbstractBasePtrList & abs_list,AbstractBasePtr * result,bool invalid)315 STATUS NodeInferShape::OpsInferShape(const PrimitivePtr &anf_prim, const AbstractBasePtrList &abs_list,
316                                      AbstractBasePtr *result, bool invalid) {
317   auto found = abstract::GetPrimitiveInferImpl(anf_prim);
318   if (!found.has_value()) {
319     MS_LOG(ERROR) << "Can't find the infer impl for ops: " << anf_prim->name();
320     return lite::RET_ERROR;
321   }
322   auto infer = found.value();
323   if (!infer.IsImplInferShapeAndType()) {
324     MS_LOG(ERROR) << "For ops: " << anf_prim->name() << ", the InferShapeAndType is not implemented.";
325     return lite::RET_ERROR;
326   }
327 
328   *result = found->InferShapeAndType(nullptr, anf_prim, abs_list);
329   if (*result == nullptr) {
330     MS_LOG(ERROR) << "For ops: " << anf_prim->name() << ", call InferShapeAndType failed.";
331     return lite::RET_ERROR;
332   }
333   return RET_OK;
334 }
335 
ConvertAbstractListToNCOrNH(const CNodePtr & cnode,AbstractBasePtrList abs_list,FormatTransNodeType perm,bool * changed)336 STATUS NodeInferShape::ConvertAbstractListToNCOrNH(const CNodePtr &cnode, AbstractBasePtrList abs_list,
337                                                    FormatTransNodeType perm, bool *changed) {
338   MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
339   MS_ERROR_IF_NULL_W_RET_VAL(changed, lite::RET_ERROR);
340   std::vector<size_t> insert_index;
341   *changed = false;
342   if (GetFormatSensitiveOpInsertIndex(cnode, &insert_index) != RET_OK) {
343     MS_LOG(ERROR) << "GetFormatSensitiveOpInsertIndex failed.";
344     return RET_ERROR;
345   }
346   if (insert_index.size() == 0) {
347     MS_LOG(DEBUG) << "op don't meet condition.";
348     return lite::RET_OK;
349   }
350   *changed = true;
351   for (auto &index : insert_index) {
352     if ((index < 1) || index > abs_list.size()) {
353       MS_LOG(ERROR) << "index is invalid.";
354       return lite::RET_ERROR;
355     }
356     if (ConvertAbstractFormatShape(abs_list[index - 1], perm) != lite::RET_OK) {
357       MS_LOG(ERROR) << "ConvertAbstract failed.";
358       return lite::RET_ERROR;
359     }
360   }
361   return lite::RET_OK;
362 }
363 
SetCNodeAbstractByConvert(const CNodePtr & cnode,const AbstractBasePtr & result,STATUS infer_ret,bool change,FormatTransNodeType perm,const Format & format)364 STATUS NodeInferShape::SetCNodeAbstractByConvert(const CNodePtr &cnode, const AbstractBasePtr &result, STATUS infer_ret,
365                                                  bool change, FormatTransNodeType perm, const Format &format) {
366   AbstractBasePtr abs = result;
367   if (abs == nullptr) {
368     abs = cnode->abstract();
369     if (abs == nullptr) {
370       MS_LOG(ERROR) << "abstract is nullptr.";
371       return lite::RET_ERROR;
372     }
373   }
374   size_t output_size;
375   if (utils::isa<abstract::AbstractTuple>(abs)) {
376     auto abs_tuple = abs->cast_ptr<abstract::AbstractTuple>();
377     AbstractBasePtrList abstract_list;
378     output_size = abs_tuple->size();
379     if (output_size == 0 || (*abs_tuple)[0]->isa<abstract::AbstractScalar>()) {
380       ShapeVector ori_shape = {static_cast<int64_t>(output_size)};
381       BaseShapePtr new_shape = std::make_shared<abstract::Shape>(ori_shape);
382       TypeId type_id = static_cast<TypeId>(kNumberTypeFloat32);
383 
384       if (output_size != 0) {
385         auto scalar_type_ptr = (*abs_tuple)[0]->cast<abstract::AbstractScalarPtr>()->GetTypeTrack();
386         MS_CHECK_TRUE_MSG(scalar_type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
387         type_id = scalar_type_ptr->type_id();
388       }
389       auto type_ptr = TypeIdToType(type_id);
390       auto out_abs = std::make_shared<abstract::AbstractTensor>(type_ptr, new_shape);
391       AbstractBasePtr new_result;
392       if (ConvertAbstract(out_abs, &new_result, change, perm) != RET_OK) {
393         MS_LOG(ERROR) << "ConvertAbstract failed.";
394         return lite::RET_ERROR;
395       }
396       cnode->set_abstract(new_result);
397       output_size = 1;
398     } else {
399       for (size_t it = 0; it < output_size; ++it) {
400         auto abs_temp = (*abs_tuple)[it];
401         AbstractBasePtr new_result;
402         if (ConvertAbstract(abs_temp, &new_result, change, perm) != RET_OK) {
403           MS_LOG(ERROR) << "ConvertAbstract failed.";
404           return lite::RET_ERROR;
405         }
406         abstract_list.emplace_back(new_result);
407       }
408       auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
409       CHECK_NULL_RETURN(new_abstract_list);
410       cnode->set_abstract(new_abstract_list);
411     }
412   } else if (utils::isa<abstract::AbstractTensor>(abs)) {
413     AbstractBasePtr new_result;
414     if (ConvertAbstract(abs, &new_result, change, perm) != RET_OK) {
415       MS_LOG(ERROR) << "ConvertAbstract failed.";
416       return lite::RET_ERROR;
417     }
418     cnode->set_abstract(new_result);
419     output_size = 1;
420   } else if (utils::isa<abstract::AbstractScalar>(abs)) {
421     ShapeVector ori_shape = {1};
422     BaseShapePtr new_shape = std::make_shared<abstract::Shape>(ori_shape);
423     auto scalar_type_ptr = abs->cast<abstract::AbstractScalarPtr>()->GetTypeTrack();
424     MS_CHECK_TRUE_MSG(scalar_type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
425     auto out_abs = std::make_shared<abstract::AbstractTensor>(scalar_type_ptr, new_shape);
426     AbstractBasePtr new_result;
427     if (ConvertAbstract(out_abs, &new_result, change, perm) != RET_OK) {
428       MS_LOG(ERROR) << "ConvertAbstract failed.";
429       return lite::RET_ERROR;
430     }
431     cnode->set_abstract(new_result);
432     output_size = 1;
433   } else {
434     MS_LOG(ERROR) << "Unknown abstract type :" << abs;
435     return lite::RET_ERROR;
436   }
437 
438   std::vector<int64_t> outputs_format(output_size, static_cast<int64_t>(format));
439   auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
440   if (anf_prim == nullptr) {
441     MS_LOG(ERROR) << "primitive is nullptr";
442     return lite::RET_ERROR;
443   }
444   (void)anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
445   return lite::RET_OK;
446 }
447 
InferShapeByOps(const CNodePtr & cnode,bool invalid)448 STATUS NodeInferShape::InferShapeByOps(const CNodePtr &cnode, bool invalid) {
449   CHECK_NULL_RETURN(cnode);
450   STATUS infer_ret = RET_OK;
451   AbstractBasePtrList abs_list;
452   if (LiteTensorExtractor::GetCNodeInputAbstractLists(cnode, &abs_list) != RET_OK) {
453     MS_LOG(ERROR) << cnode->fullname_with_scope() << " GetCNodeInputAbstractLists failed.";
454     return lite::RET_ERROR;
455   }
456 
457   auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
458   if (anf_prim == nullptr) {
459     MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is nullptr";
460     return lite::RET_ERROR;
461   }
462   (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
463 
464   if (LiteTensorExtractor::GetCNodeConstInputToAbstract(cnode, abs_list, fmk_type_, train_flag_) != RET_OK) {
465     MS_LOG(ERROR) << cnode->fullname_with_scope() << " GetCNodeConstInputToAbstract failed.";
466     return RET_ERROR;
467   }
468   Format ori_format = Format::NHWC;
469   if (anf_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
470     ori_format = static_cast<Format>(GetValue<int64_t>(anf_prim->GetAttr(mindspore::ops::kFormat)));
471   }
472   bool changed = false;
473   if (ori_format == Format::NHWC) {
474     if (ConvertAbstractListToNCOrNH(cnode, abs_list, kNHWC2NCHW, &changed) != RET_OK) {
475       MS_LOG(ERROR) << cnode->fullname_with_scope() << " ConvertAbstractToNCOrNH failed.";
476       return RET_ERROR;
477     }
478   }
479   (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int>(Format::NCHW)));
480   AbstractBasePtr result;
481   try {
482     infer_ret = OpsInferShape(anf_prim, abs_list, &result, invalid);
483   } catch (const std::exception &e) {
484     MS_LOG(WARNING) << "InferShapeByOps for op: " << cnode->fullname_with_scope() << " failed. " << e.what();
485     throw;
486   }
487   (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int64_t>(ori_format)));
488   if (infer_ret == lite::RET_OK) {
489     (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
490     auto input_format = NHWC;
491     (void)opt::DetermineCertainVarInputFormat(cnode, 1, &input_format);
492     auto set_status = SetCNodeAbstractByConvert(cnode, result, infer_ret, changed, kNCHW2NHWC, input_format);
493     if (set_status != lite::RET_OK) {
494       MS_LOG(ERROR) << cnode->fullname_with_scope() << " SetCNodeAbstractByConvert failed.";
495       return set_status;
496     }
497   }
498 
499   return infer_ret;
500 }
501 
GetInputShape(const CNodePtr & cnode,size_t index)502 std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {
503   MS_ASSERT(cnode != nullptr);
504   if (index >= cnode->size()) {
505     return {};
506   }
507   lite::DataInfo data_info;
508   int status = lite::RET_OK;
509   CNodePtr base_node = cnode;
510   size_t position = index;
511   if (CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTuple) ||
512       CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTupleV2)) {
513     base_node = cnode->input(index)->cast<CNodePtr>();
514     position = 1;
515   }
516   if (utils::isa<CNode>(base_node->input(position))) {
517     status = lite::FetchDataFromCNode(base_node, position, &data_info);
518   } else if (utils::isa<Parameter>(base_node->input(position))) {
519     status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, &data_info, false);
520   } else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
521     status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
522   } else {
523     MS_LOG(ERROR) << "input node is invalid.";
524     return {};
525   }
526   if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
527     MS_LOG(ERROR) << "fetch data failed.";
528     return {};
529   }
530   return data_info.shape_;
531 }
532 
GetIntVecInput(const CNodePtr & cnode,size_t index)533 std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) {
534   MS_ASSERT(cnode != nullptr);
535   if (index >= cnode->size()) {
536     return {};
537   }
538   auto origin_inputs = cnode->inputs();
539   std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
540   cnode->set_inputs(specify_inputs);
541   std::vector<TensorPtr> specify_tensors;
542   if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_, false) !=
543         lite::RET_OK ||
544       specify_tensors.empty()) {
545     cnode->set_inputs(origin_inputs);
546     return {};
547   }
548   cnode->set_inputs(origin_inputs);
549   std::vector<int> tensor_data;
550   if (specify_tensors.front()->data_type() != kNumberTypeInt32 &&
551       specify_tensors.front()->data_type() != kNumberTypeInt) {
552     return {};
553   }
554   if (specify_tensors.front()->shape().size() != 1) {
555     return {};
556   }
557   MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
558   tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
559   if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
560                specify_tensors.front()->Size()) != EOK) {
561     return {};
562   }
563   return tensor_data;
564 }
565 
SetCNodeAbstract(const std::shared_ptr<CNode> & cnode,const std::vector<lite::Tensor * > & outputs,int status)566 STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs,
567                                         int status) {
568   MS_ASSERT(cnode != nullptr);
569   if (outputs.size() == 0) {
570     MS_LOG(ERROR) << "empty output_tensors";
571     return RET_ERROR;
572   }
573   auto origin_abstract = cnode->abstract();
574   MS_ASSERT(origin_abstract != nullptr);
575   if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) {
576     auto tensor = outputs.front();
577     auto new_abstract = ConvertLiteTensorToAbstract(tensor);
578     if (new_abstract == nullptr) {
579       MS_LOG(ERROR) << "new abstract failed.";
580       return RET_ERROR;
581     }
582     if (status == lite::RET_INFER_INVALID) {
583       if (tensor->data_type() == kObjectTypeTensorType) {
584         ShapeVector shape = {0};
585         auto abstract_shape = std::make_shared<abstract::Shape>(shape);
586         CHECK_NULL_RETURN(abstract_shape);
587         new_abstract->set_shape(abstract_shape);
588       }
589     }
590     cnode->set_abstract(new_abstract);
591   } else {
592     AbstractBasePtrList abstract_list;
593     for (size_t i = 0; i < outputs.size(); i++) {
594       auto tensor = outputs.at(i);
595       auto new_abstract = ConvertLiteTensorToAbstract(tensor);
596       if (new_abstract == nullptr) {
597         MS_LOG(ERROR) << "new abstract failed.";
598         return RET_ERROR;
599       }
600       if (status == lite::RET_INFER_INVALID) {
601         if (tensor->data_type() == kObjectTypeTensorType) {
602           ShapeVector shape = {0};
603           auto abstract_shape = std::make_shared<abstract::Shape>(shape);
604           CHECK_NULL_RETURN(abstract_shape);
605           new_abstract->set_shape(abstract_shape);
606         }
607       }
608       abstract_list.emplace_back(new_abstract);
609     }
610     auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
611     CHECK_NULL_RETURN(new_abstract_list);
612     cnode->set_abstract(new_abstract_list);
613   }
614   return RET_OK;
615 }
616 
ConvertLiteTensorToAbstract(lite::Tensor * tensor)617 abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
618   MS_ASSERT(tensor != nullptr);
619   if (tensor->data_type() == kObjectTypeTensorType) {
620     return ConvertTensorListToAbstract(tensor);
621   }
622   auto tensor_info = NewTensorInfo(tensor);
623   if (tensor_info == nullptr) {
624     MS_LOG(ERROR) << "new tensor::Tensor failed";
625     return nullptr;
626   }
627   return tensor_info->ToAbstract();
628 }
629 
630 // stract save tensorlist's type and shape. tensor_info save tensorlist's data and data type.
631 // both of them is different in term of shape and type.
ConvertTensorListToAbstract(lite::Tensor * tensor)632 abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) {
633   MS_ASSERT(tensor != nullptr);
634   auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
635   if (tensor_list == nullptr) {
636     MS_LOG(ERROR) << "cast tensor_list failed";
637     return nullptr;
638   }
639   std::vector<int> shape(tensor->shape());
640   std::vector<int64_t> shape_vector(shape.begin(), shape.end());
641   auto tensor_list_abstract =
642     std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector);
643   if (tensor_list_abstract == nullptr) {
644     MS_LOG(ERROR) << "new AbstractTensor failed";
645     return nullptr;
646   }
647   auto elememt_shape = tensor_list->element_shape();
648   std::vector<int> data_info;
649   data_info.push_back(tensor_list->tensors_data_type());
650   data_info.push_back(elememt_shape.size());
651   std::copy(elememt_shape.begin(), elememt_shape.end(), std::back_inserter(data_info));
652   data_info.push_back(tensor_list->tensors().size());
653   for (size_t i = 0; i < tensor_list->tensors().size(); ++i) {
654     auto tensor_mem = tensor_list->tensors()[i];
655     auto tensor_mem_shape = tensor_mem->shape();
656     data_info.push_back(tensor_mem_shape.size());
657     std::copy(tensor_mem_shape.begin(), tensor_mem_shape.end(), std::back_inserter(data_info));
658   }
659   std::vector<int64_t> data_shape;
660   data_shape.push_back(data_info.size());
661   auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32);
662   if (tensor_info == nullptr) {
663     MS_LOG(ERROR) << "new tensor::Tensor failed";
664     return nullptr;
665   }
666   tensor_list_abstract->set_value(tensor_info);
667   return tensor_list_abstract;
668 }
669 }  // namespace opt
670 }  // namespace mindspore
671