1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/node_infershape.h"
19 #include <memory>
20 #include <vector>
21 #include <algorithm>
22 #include <unordered_set>
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "mindspore/core/ops/comparison_ops.h"
27 #include "mindspore/core/ops/random_ops.h"
28 #include "src/common/primitive_t_utils.h"
29 #include "tools/common/node_util.h"
30 #include "tools/common/tensor_util.h"
31 #include "src/common/utils.h"
32 #include "src/common/ops/populate/populate_register.h"
33 #include "src/common/ops/anf_utils.h"
34 #include "src/litert/infer_manager.h"
35 #include "src/tensorlist.h"
36 #include "src/registry/kernel_interface_registry.h"
37 #include "tools/optimizer/graph/lite_tensor_extractor.h"
38 #include "nnacl/op_base.h"
39 #include "ops/op_utils.h"
40 #include "tools/optimizer/format/to_nchw_format.h"
41 #include "tools/optimizer/format/to_nhwc_format.h"
42 #include "tools/common/graph_util.h"
43 #include "src/common/common.h"
44
45 namespace mindspore {
46 namespace opt {
47 static const std::unordered_set<PrimitivePtr> kNNACLToOpsInfer = {
48 // arithmetic_self
49 prim::kPrimAbs,
50 prim::kPrimAsin,
51 prim::kPrimAsinh,
52 prim::kPrimACos,
53 prim::kPrimAcosh,
54 prim::kPrimAtanh,
55 prim::kPrimCos,
56 prim::kPrimCosh,
57 prim::kPrimCeLU,
58 prim::kPrimSeLU,
59 prim::kPrimHSwish,
60 prim::kPrimMatrixDeterminant,
61 prim::kPrimLog,
62 prim::kPrimLog1p,
63 prim::kPrimSquare,
64 prim::kPrimSqrt,
65 prim::kPrimRsqrt,
66 prim::kPrimSin,
67 prim::kPrimSinh,
68 prim::kPrimFloor,
69 prim::kPrimCeil,
70 prim::kPrimRound,
71 prim::kPrimNeg,
72 prim::kPrimReciprocal,
73 prim::kPrimErf,
74 prim::kPrimSign,
75 prim::kPrimSoftsign,
76 prim::kPrimMultinomial,
77 // arithmetic
78 prim::kPrimFloorDiv,
79 prim::kPrimFloorMod,
80 prim::kPrimLogicalAnd,
81 prim::kPrimLogicalNot,
82 prim::kPrimLogicalOr,
83 prim::kPrimLogicalXor,
84 prim::kPrimMaximum,
85 prim::kPrimMinimum,
86 prim::kPrimMod,
87 prim::kPrimSquaredDifference,
88 prim::kPrimLeftShift,
89 prim::kPrimRightShift,
90 prim::kPrimROIAlign,
91 // tuple op
92 prim::kPrimTupleGetItem,
93 prim::kPrimMakeTuple,
94 prim::kPrimMakeTupleV2,
95 // nnacl/infer/common_infer.c
96 prim::kPrimClip,
97 prim::kPrimElu,
98 prim::kPrimLeakyRelu,
99 prim::kPrimLrn,
100 prim::kPrimOnesLike,
101 prim::kPrimReverseSequence,
102 prim::kPrimReverseV2,
103 prim::kPrimSmoothL1Loss,
104 prim::kPrimZerosLike,
105 // format op
106 prim::kPrimResize,
107 // compare op
108 prim::kPrimEqual,
109 prim::kPrimGreater,
110 prim::kPrimGreaterEqual,
111 prim::kPrimLess,
112 prim::kPrimLessEqual,
113 prim::kPrimNotEqual,
114
115 prim::kPrimActivation,
116 prim::kPrimArgMaxFusion,
117 prim::kPrimArgMinFusion,
118 prim::kPrimGLU,
119 prim::kPrimGridSampler2D,
120 prim::kPrimGridSampler3D,
121 prim::kPrimDeformableConv2d,
122 // grad op
123 prim::kPrimActivationGrad,
124 prim::kPrimAbsGrad,
125 prim::kPrimBinaryCrossEntropyGrad,
126 prim::kPrimLogGrad,
127 prim::kPrimMaximumGrad,
128 prim::kPrimMinimumGrad,
129 prim::kPrimNegGrad,
130 prim::kPrimRsqrtGrad,
131 prim::kPrimSqrtGrad,
132 prim::kPrimSmoothL1LossGrad,
133 prim::kPrimGridSampler2D,
134 };
135
136 namespace {
// Channel count used by RectifyFormat to recognize 4-D NHWC-layout inputs (C == 3).
constexpr int kInputChannal = 3;
// Initial byte capacity of the FlatBufferBuilder used to serialize primitives.
constexpr size_t INITIAL_SIZE = 1024;
RectifyFormat(const std::vector<lite::Tensor * > & inputs,FmkType fmk_type)139 void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
140 MS_ASSERT(cnode != nullptr);
141 if (fmk_type != converter::kFmkTypeOnnx) {
142 return;
143 }
144 for (auto &input : inputs) {
145 auto shape = input->shape();
146 if (shape.size() == kInputSizeFour && shape[kInputIndexThree] == kInputChannal && shape[1] == -1) {
147 input->set_format(mindspore::NHWC);
148 }
149 }
150 }
151
NewTensorInfo(const lite::Tensor * tensor)152 tensor::TensorPtr NewTensorInfo(const lite::Tensor *tensor) {
153 std::vector<int> shape(tensor->shape());
154 std::vector<int64_t> shape_vector(shape.begin(), shape.end());
155 auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
156 if (tensor_info == nullptr) {
157 MS_LOG(ERROR) << "new tensor::Tensor failed";
158 return nullptr;
159 }
160 return tensor_info;
161 }
162
// Prepares src_abs for attachment to a cnode: ensures it carries a tensor info,
// optionally permutes its shape by `perm` (NHWC<->NCHW) when `change` is true,
// and normalizes core/ops dynamic-rank shapes ({-2}) to Lite's dynamic shape
// ({-1}). Note: *dst_abs aliases src_abs (shared pointer, no deep copy), so the
// mutations below also affect the caller's abstract.
STATUS ConvertAbstract(const AbstractBasePtr &src_abs, AbstractBasePtr *dst_abs, bool change,
                       FormatTransNodeType perm) {
  // Attach a concrete tensor value so downstream passes can read type/shape.
  if (SetAbstractTensorInfo(src_abs) != RET_OK) {
    MS_LOG(ERROR) << "SetAbstractTensorInfo failed";
    return lite::RET_ERROR;
  }
  *dst_abs = src_abs;
  if (change) {
    // Transpose the stored shape according to the requested permutation.
    if (ConvertAbstractFormatShape(*dst_abs, perm) != RET_OK) {
      MS_LOG(ERROR) << "ConvertAbstractFormatShape failed";
      return lite::RET_ERROR;
    }
  }

  // change core/ops dynamic rank {-2} to Lite dynamic shape {-1}, will be removed after calling core/infer
  ShapeVector shape;
  if (opt::FetchShapeFromAbstract(*dst_abs, &shape) != RET_OK) {
    MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
    return RET_ERROR;
  }
  if (IsDynamicRank(shape)) {
    auto nnacl_dynamic_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny});
    (*dst_abs)->set_shape(nnacl_dynamic_shape);
  }
  return RET_OK;
}
189 } // namespace
190
JudgeOpSupportNNACLInfer(const CNodePtr & cnode)191 bool JudgeOpSupportNNACLInfer(const CNodePtr &cnode) {
192 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
193 if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
194 return true;
195 }
196 auto prim_t = lite::GetPrimitiveT(cnode->input(0));
197 if (prim_t == nullptr) {
198 return false;
199 }
200 auto parameter_gen =
201 lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim_t->value.type), lite::SCHEMA_CUR);
202 if (parameter_gen == nullptr) {
203 prim_t.reset();
204 return false;
205 }
206 return true;
207 }
208
JudgeOpSupportOpsInfer(const CNodePtr & cnode)209 bool JudgeOpSupportOpsInfer(const CNodePtr &cnode) {
210 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
211 for (const auto &type : kNNACLToOpsInfer) {
212 if (CheckPrimitiveType(cnode, type)) {
213 return true;
214 }
215 }
216 return false;
217 }
218
JudgeOpSupportInfer(const CNodePtr & cnode)219 bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
220 return JudgeOpSupportOpsInfer(cnode) || JudgeOpSupportNNACLInfer(cnode);
221 }
222
// Runs shape inference for `cnode` via the NNACL (lite kernel) path: extracts
// input/output lite tensors, serializes the primitive to its flatbuffer schema
// form, and calls KernelInferShape. On success (or RET_INFER_INVALID) the
// inferred shapes are written back to the cnode's abstract, and the primitive's
// kInferDone / ops::kFormat / kOutputsFormat attributes are updated.
STATUS NodeInferShape::InferShapeByNNACL(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (anf_prim == nullptr) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << "'s cnode primitive is nullptr";
    return lite::RET_ERROR;
  }
  // Pessimistically mark infer as not done; flipped to true only on RET_OK below.
  (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
  std::vector<TensorPtr> inputs_ptr;
  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, false) != lite::RET_OK) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " get inputs failed.";
    return lite::RET_ERROR;
  }
  std::vector<TensorPtr> outputs_ptr;
  if (LiteTensorExtractor::GetCNodeOutputTensors(cnode, &outputs_ptr, train_flag_) != lite::RET_OK) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " get outputs failed.";
    return lite::RET_ERROR;
  }
  auto prim_t = lite::GetPrimitiveT(cnode->input(0));
  if (prim_t == nullptr) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " get lite prim_t is nullptr";
    return lite::RET_ERROR;
  }
  // Serialize the primitive into a flatbuffer so schema-level infer can read it;
  // `prim` points into `fbb`'s buffer, so fbb must outlive all uses of prim.
  flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
  auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
  if (prim == nullptr) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " get primitive failed.";
    fbb.Clear();
    return lite::RET_ERROR;
  }
  // Raw-pointer views over the extracted tensors; ownership stays in *_ptr vectors.
  std::vector<lite::Tensor *> inputs;
  (void)std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(inputs),
                       [](const TensorPtr &input) { return input.get(); });
  std::vector<lite::Tensor *> outputs;
  (void)std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(outputs),
                       [](const TensorPtr &output) { return output.get(); });
  auto ret = KernelInferShape(inputs, outputs, prim, {}, lite::SCHEMA_CUR);
  if (ret == lite::RET_NOT_SUPPORT) {
    // No kernel-interface infer registered: fall back to the OpParameter-based infer.
    auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
      static_cast<int>(prim->value_type()), lite::SCHEMA_CUR);
    if (parameter_gen == nullptr) {
      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
      fbb.Clear();
      return lite::RET_ERROR;
    }
    auto parameter = parameter_gen(prim);
    if (parameter == nullptr) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " generate nullptr lite op parameter.";
      fbb.Clear();
      return lite::RET_ERROR;
    }
    // ONNX-specific: re-mark dynamic 4-D channel-last inputs as NHWC before infer.
    RectifyFormat(inputs, fmk_type_);
    ret = KernelInferShape(inputs, outputs, parameter);
    // The OpParameter is heap-allocated by the creator: run its destroy hook
    // for owned sub-resources, then free the parameter struct itself.
    if (parameter->destroy_func_ != nullptr) {
      parameter->destroy_func_(parameter);
    }
    free(parameter);
    parameter = nullptr;
  }
  fbb.Clear();
  if (ret == lite::RET_OK) {
    (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
  }
  if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
    // RET_INFER_INVALID still writes back abstracts (shapes may be placeholders).
    auto set_status = SetCNodeAbstract(cnode, outputs, ret);
    // NOTE(review): assumes at least one input tensor exists here; an empty
    // `inputs` would fault on inputs[0] — confirm callers guarantee inputs.
    (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int64_t>(inputs[0]->format())));
    if (set_status != lite::RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " set CNode abstract failed.";
      return set_status;
    }
  } else {
    MS_LOG(WARNING) << "InferShapeByNNACL for op: " << cnode->fullname_with_scope() << " failed.";
  }
  // Record per-output formats so later format-normalization passes can use them.
  std::vector<int64_t> outputs_format;
  (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_format),
                       [](const lite::Tensor *output) { return output->format(); });
  (void)anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
  return ret;
}
302
InferShape(const CNodePtr & cnode)303 STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
304 MS_ASSERT(cnode != nullptr);
305 STATUS status;
306 if (JudgeOpSupportNNACLInfer(cnode)) {
307 status = InferShapeByNNACL(cnode);
308 } else {
309 MS_LOG(ERROR) << "Unsupported node: " << cnode->fullname_with_scope() << " for infershape.";
310 return RET_ERROR;
311 }
312 return status;
313 }
314
OpsInferShape(const PrimitivePtr & anf_prim,const AbstractBasePtrList & abs_list,AbstractBasePtr * result,bool invalid)315 STATUS NodeInferShape::OpsInferShape(const PrimitivePtr &anf_prim, const AbstractBasePtrList &abs_list,
316 AbstractBasePtr *result, bool invalid) {
317 auto found = abstract::GetPrimitiveInferImpl(anf_prim);
318 if (!found.has_value()) {
319 MS_LOG(ERROR) << "Can't find the infer impl for ops: " << anf_prim->name();
320 return lite::RET_ERROR;
321 }
322 auto infer = found.value();
323 if (!infer.IsImplInferShapeAndType()) {
324 MS_LOG(ERROR) << "For ops: " << anf_prim->name() << ", the InferShapeAndType is not implemented.";
325 return lite::RET_ERROR;
326 }
327
328 *result = found->InferShapeAndType(nullptr, anf_prim, abs_list);
329 if (*result == nullptr) {
330 MS_LOG(ERROR) << "For ops: " << anf_prim->name() << ", call InferShapeAndType failed.";
331 return lite::RET_ERROR;
332 }
333 return RET_OK;
334 }
335
ConvertAbstractListToNCOrNH(const CNodePtr & cnode,AbstractBasePtrList abs_list,FormatTransNodeType perm,bool * changed)336 STATUS NodeInferShape::ConvertAbstractListToNCOrNH(const CNodePtr &cnode, AbstractBasePtrList abs_list,
337 FormatTransNodeType perm, bool *changed) {
338 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
339 MS_ERROR_IF_NULL_W_RET_VAL(changed, lite::RET_ERROR);
340 std::vector<size_t> insert_index;
341 *changed = false;
342 if (GetFormatSensitiveOpInsertIndex(cnode, &insert_index) != RET_OK) {
343 MS_LOG(ERROR) << "GetFormatSensitiveOpInsertIndex failed.";
344 return RET_ERROR;
345 }
346 if (insert_index.size() == 0) {
347 MS_LOG(DEBUG) << "op don't meet condition.";
348 return lite::RET_OK;
349 }
350 *changed = true;
351 for (auto &index : insert_index) {
352 if ((index < 1) || index > abs_list.size()) {
353 MS_LOG(ERROR) << "index is invalid.";
354 return lite::RET_ERROR;
355 }
356 if (ConvertAbstractFormatShape(abs_list[index - 1], perm) != lite::RET_OK) {
357 MS_LOG(ERROR) << "ConvertAbstract failed.";
358 return lite::RET_ERROR;
359 }
360 }
361 return lite::RET_OK;
362 }
363
// Writes `result` (or the cnode's existing abstract when result is nullptr)
// back onto the cnode. Every output abstract is run through ConvertAbstract
// (optionally transposing its shape by `perm` when `change` is true), scalar
// abstracts are promoted to 1-D tensor abstracts, and the primitive's
// kOutputsFormat attribute is set to `format` for each output.
// NOTE(review): the infer_ret parameter is currently unused in this body.
STATUS NodeInferShape::SetCNodeAbstractByConvert(const CNodePtr &cnode, const AbstractBasePtr &result, STATUS infer_ret,
                                                 bool change, FormatTransNodeType perm, const Format &format) {
  AbstractBasePtr abs = result;
  if (abs == nullptr) {
    // Fall back to the abstract already attached to the node.
    abs = cnode->abstract();
    if (abs == nullptr) {
      MS_LOG(ERROR) << "abstract is nullptr.";
      return lite::RET_ERROR;
    }
  }
  size_t output_size;
  if (utils::isa<abstract::AbstractTuple>(abs)) {
    auto abs_tuple = abs->cast_ptr<abstract::AbstractTuple>();
    AbstractBasePtrList abstract_list;
    output_size = abs_tuple->size();
    if (output_size == 0 || (*abs_tuple)[0]->isa<abstract::AbstractScalar>()) {
      // An empty tuple or a tuple of scalars collapses into one 1-D tensor
      // abstract of length output_size.
      ShapeVector ori_shape = {static_cast<int64_t>(output_size)};
      BaseShapePtr new_shape = std::make_shared<abstract::Shape>(ori_shape);
      TypeId type_id = static_cast<TypeId>(kNumberTypeFloat32);  // default when the tuple is empty

      if (output_size != 0) {
        // Use the first scalar's tracked type for the collapsed tensor.
        auto scalar_type_ptr = (*abs_tuple)[0]->cast<abstract::AbstractScalarPtr>()->GetTypeTrack();
        MS_CHECK_TRUE_MSG(scalar_type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
        type_id = scalar_type_ptr->type_id();
      }
      auto type_ptr = TypeIdToType(type_id);
      auto out_abs = std::make_shared<abstract::AbstractTensor>(type_ptr, new_shape);
      AbstractBasePtr new_result;
      if (ConvertAbstract(out_abs, &new_result, change, perm) != RET_OK) {
        MS_LOG(ERROR) << "ConvertAbstract failed.";
        return lite::RET_ERROR;
      }
      cnode->set_abstract(new_result);
      output_size = 1;
    } else {
      // General tuple: convert each element individually and rebuild the tuple.
      for (size_t it = 0; it < output_size; ++it) {
        auto abs_temp = (*abs_tuple)[it];
        AbstractBasePtr new_result;
        if (ConvertAbstract(abs_temp, &new_result, change, perm) != RET_OK) {
          MS_LOG(ERROR) << "ConvertAbstract failed.";
          return lite::RET_ERROR;
        }
        abstract_list.emplace_back(new_result);
      }
      auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
      CHECK_NULL_RETURN(new_abstract_list);
      cnode->set_abstract(new_abstract_list);
    }
  } else if (utils::isa<abstract::AbstractTensor>(abs)) {
    // Single tensor output: convert in place.
    AbstractBasePtr new_result;
    if (ConvertAbstract(abs, &new_result, change, perm) != RET_OK) {
      MS_LOG(ERROR) << "ConvertAbstract failed.";
      return lite::RET_ERROR;
    }
    cnode->set_abstract(new_result);
    output_size = 1;
  } else if (utils::isa<abstract::AbstractScalar>(abs)) {
    // Single scalar output: promote to a 1-element tensor abstract.
    ShapeVector ori_shape = {1};
    BaseShapePtr new_shape = std::make_shared<abstract::Shape>(ori_shape);
    auto scalar_type_ptr = abs->cast<abstract::AbstractScalarPtr>()->GetTypeTrack();
    MS_CHECK_TRUE_MSG(scalar_type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
    auto out_abs = std::make_shared<abstract::AbstractTensor>(scalar_type_ptr, new_shape);
    AbstractBasePtr new_result;
    if (ConvertAbstract(out_abs, &new_result, change, perm) != RET_OK) {
      MS_LOG(ERROR) << "ConvertAbstract failed.";
      return lite::RET_ERROR;
    }
    cnode->set_abstract(new_result);
    output_size = 1;
  } else {
    MS_LOG(ERROR) << "Unknown abstract type :" << abs;
    return lite::RET_ERROR;
  }

  // All outputs are stamped with the same format after conversion.
  std::vector<int64_t> outputs_format(output_size, static_cast<int64_t>(format));
  auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (anf_prim == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return lite::RET_ERROR;
  }
  (void)anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
  return lite::RET_OK;
}
447
// Runs shape inference for `cnode` through the core/ops infer registry:
// gathers input abstracts (materializing const inputs), temporarily transposes
// NHWC abstracts to NCHW (core/ops infer assumes NCHW), calls the registered
// infer, then restores the original format attribute and converts inferred
// shapes back before writing them onto the cnode.
STATUS NodeInferShape::InferShapeByOps(const CNodePtr &cnode, bool invalid) {
  CHECK_NULL_RETURN(cnode);
  STATUS infer_ret = RET_OK;
  AbstractBasePtrList abs_list;
  if (LiteTensorExtractor::GetCNodeInputAbstractLists(cnode, &abs_list) != RET_OK) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " GetCNodeInputAbstractLists failed.";
    return lite::RET_ERROR;
  }

  auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (anf_prim == nullptr) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is nullptr";
    return lite::RET_ERROR;
  }
  // Pessimistically mark infer as not done; flipped to true only on success.
  (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));

  if (LiteTensorExtractor::GetCNodeConstInputToAbstract(cnode, abs_list, fmk_type_, train_flag_) != RET_OK) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " GetCNodeConstInputToAbstract failed.";
    return RET_ERROR;
  }
  // Lite graphs default to NHWC unless the primitive carries a format attr.
  Format ori_format = Format::NHWC;
  if (anf_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
    ori_format = static_cast<Format>(GetValue<int64_t>(anf_prim->GetAttr(mindspore::ops::kFormat)));
  }
  bool changed = false;
  if (ori_format == Format::NHWC) {
    // Transpose format-sensitive input abstracts to NCHW for the ops infer.
    if (ConvertAbstractListToNCOrNH(cnode, abs_list, kNHWC2NCHW, &changed) != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " ConvertAbstractToNCOrNH failed.";
      return RET_ERROR;
    }
  }
  // Temporarily present the primitive as NCHW during the ops infer call.
  (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int>(Format::NCHW)));
  AbstractBasePtr result;
  try {
    infer_ret = OpsInferShape(anf_prim, abs_list, &result, invalid);
  } catch (const std::exception &e) {
    // Log the failing node before rethrowing so the culprit is identifiable.
    MS_LOG(WARNING) << "InferShapeByOps for op: " << cnode->fullname_with_scope() << " failed. " << e.what();
    throw;
  }
  // Restore the node's original format attribute after the infer call.
  (void)anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(static_cast<int64_t>(ori_format)));
  if (infer_ret == lite::RET_OK) {
    (void)anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
    auto input_format = NHWC;
    (void)opt::DetermineCertainVarInputFormat(cnode, 1, &input_format);
    // Convert inferred abstracts back (NCHW -> NHWC) if they were transposed.
    auto set_status = SetCNodeAbstractByConvert(cnode, result, infer_ret, changed, kNCHW2NHWC, input_format);
    if (set_status != lite::RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " SetCNodeAbstractByConvert failed.";
      return set_status;
    }
  }

  return infer_ret;
}
501
GetInputShape(const CNodePtr & cnode,size_t index)502 std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {
503 MS_ASSERT(cnode != nullptr);
504 if (index >= cnode->size()) {
505 return {};
506 }
507 lite::DataInfo data_info;
508 int status = lite::RET_OK;
509 CNodePtr base_node = cnode;
510 size_t position = index;
511 if (CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTuple) ||
512 CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTupleV2)) {
513 base_node = cnode->input(index)->cast<CNodePtr>();
514 position = 1;
515 }
516 if (utils::isa<CNode>(base_node->input(position))) {
517 status = lite::FetchDataFromCNode(base_node, position, &data_info);
518 } else if (utils::isa<Parameter>(base_node->input(position))) {
519 status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, &data_info, false);
520 } else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
521 status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
522 } else {
523 MS_LOG(ERROR) << "input node is invalid.";
524 return {};
525 }
526 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
527 MS_LOG(ERROR) << "fetch data failed.";
528 return {};
529 }
530 return data_info.shape_;
531 }
532
// Reads the input at `index` as a 1-D int32 vector (e.g. a shape/axes operand).
// Returns an empty vector when the input is out of range, cannot be
// materialized as constant data, or is not 1-D int32.
std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) {
  MS_ASSERT(cnode != nullptr);
  if (index >= cnode->size()) {
    return {};
  }
  // Temporarily narrow the cnode to [primitive, target input] so the tensor
  // extractor materializes only the one operand; inputs are restored on every
  // exit path below.
  auto origin_inputs = cnode->inputs();
  std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
  cnode->set_inputs(specify_inputs);
  std::vector<TensorPtr> specify_tensors;
  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_, false) !=
        lite::RET_OK ||
      specify_tensors.empty()) {
    cnode->set_inputs(origin_inputs);
    return {};
  }
  cnode->set_inputs(origin_inputs);
  std::vector<int> tensor_data;
  // Only plain int32 data is supported.
  if (specify_tensors.front()->data_type() != kNumberTypeInt32 &&
      specify_tensors.front()->data_type() != kNumberTypeInt) {
    return {};
  }
  if (specify_tensors.front()->shape().size() != 1) {
    return {};
  }
  MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
  tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
  // memcpy_s validates that the source byte count fits the destination buffer.
  if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
               specify_tensors.front()->Size()) != EOK) {
    return {};
  }
  return tensor_data;
}
565
SetCNodeAbstract(const std::shared_ptr<CNode> & cnode,const std::vector<lite::Tensor * > & outputs,int status)566 STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs,
567 int status) {
568 MS_ASSERT(cnode != nullptr);
569 if (outputs.size() == 0) {
570 MS_LOG(ERROR) << "empty output_tensors";
571 return RET_ERROR;
572 }
573 auto origin_abstract = cnode->abstract();
574 MS_ASSERT(origin_abstract != nullptr);
575 if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) {
576 auto tensor = outputs.front();
577 auto new_abstract = ConvertLiteTensorToAbstract(tensor);
578 if (new_abstract == nullptr) {
579 MS_LOG(ERROR) << "new abstract failed.";
580 return RET_ERROR;
581 }
582 if (status == lite::RET_INFER_INVALID) {
583 if (tensor->data_type() == kObjectTypeTensorType) {
584 ShapeVector shape = {0};
585 auto abstract_shape = std::make_shared<abstract::Shape>(shape);
586 CHECK_NULL_RETURN(abstract_shape);
587 new_abstract->set_shape(abstract_shape);
588 }
589 }
590 cnode->set_abstract(new_abstract);
591 } else {
592 AbstractBasePtrList abstract_list;
593 for (size_t i = 0; i < outputs.size(); i++) {
594 auto tensor = outputs.at(i);
595 auto new_abstract = ConvertLiteTensorToAbstract(tensor);
596 if (new_abstract == nullptr) {
597 MS_LOG(ERROR) << "new abstract failed.";
598 return RET_ERROR;
599 }
600 if (status == lite::RET_INFER_INVALID) {
601 if (tensor->data_type() == kObjectTypeTensorType) {
602 ShapeVector shape = {0};
603 auto abstract_shape = std::make_shared<abstract::Shape>(shape);
604 CHECK_NULL_RETURN(abstract_shape);
605 new_abstract->set_shape(abstract_shape);
606 }
607 }
608 abstract_list.emplace_back(new_abstract);
609 }
610 auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
611 CHECK_NULL_RETURN(new_abstract_list);
612 cnode->set_abstract(new_abstract_list);
613 }
614 return RET_OK;
615 }
616
ConvertLiteTensorToAbstract(lite::Tensor * tensor)617 abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
618 MS_ASSERT(tensor != nullptr);
619 if (tensor->data_type() == kObjectTypeTensorType) {
620 return ConvertTensorListToAbstract(tensor);
621 }
622 auto tensor_info = NewTensorInfo(tensor);
623 if (tensor_info == nullptr) {
624 MS_LOG(ERROR) << "new tensor::Tensor failed";
625 return nullptr;
626 }
627 return tensor_info->ToAbstract();
628 }
629
// abstract saves the tensorlist's type and shape. tensor_info saves the tensorlist's data and data type.
// the two differ in terms of shape and type.
// Converts a lite::TensorList into an AbstractTensor. The abstract's own type
// and shape describe the list object itself, while its attached value tensor
// packs the list metadata as a flat int32 vector:
//   [tensors_data_type, element_rank, element_shape...,
//    tensor_count, {tensor_i_rank, tensor_i_shape...} per tensor]
abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) {
  MS_ASSERT(tensor != nullptr);
  // Caller guarantees data_type() == kObjectTypeTensorType, so this downcast
  // is expected to be valid — NOTE(review): no runtime type check is done here.
  auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
  if (tensor_list == nullptr) {
    MS_LOG(ERROR) << "cast tensor_list failed";
    return nullptr;
  }
  std::vector<int> shape(tensor->shape());
  std::vector<int64_t> shape_vector(shape.begin(), shape.end());
  auto tensor_list_abstract =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector);
  if (tensor_list_abstract == nullptr) {
    MS_LOG(ERROR) << "new AbstractTensor failed";
    return nullptr;
  }
  auto elememt_shape = tensor_list->element_shape();
  // Serialize the tensor-list layout into the flat int vector (format above).
  std::vector<int> data_info;
  data_info.push_back(tensor_list->tensors_data_type());
  data_info.push_back(elememt_shape.size());
  std::copy(elememt_shape.begin(), elememt_shape.end(), std::back_inserter(data_info));
  data_info.push_back(tensor_list->tensors().size());
  for (size_t i = 0; i < tensor_list->tensors().size(); ++i) {
    auto tensor_mem = tensor_list->tensors()[i];
    auto tensor_mem_shape = tensor_mem->shape();
    data_info.push_back(tensor_mem_shape.size());
    std::copy(tensor_mem_shape.begin(), tensor_mem_shape.end(), std::back_inserter(data_info));
  }
  std::vector<int64_t> data_shape;
  data_shape.push_back(data_info.size());
  // The packed metadata becomes the abstract's value tensor.
  auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "new tensor::Tensor failed";
    return nullptr;
  }
  tensor_list_abstract->set_value(tensor_info);
  return tensor_list_abstract;
}
669 } // namespace opt
670 } // namespace mindspore
671