1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/graph/node_infershape.h"
18 #include <memory>
19 #include <vector>
20 #include "tools/common/node_util.h"
21 #include "tools/common/tensor_util.h"
22 #include "src/common/utils.h"
23 #include "src/ops/populate/populate_register.h"
24 #include "src/ops/ops_utils.h"
25 #include "src/runtime/infer_manager.h"
26 #include "src/tensorlist.h"
27 #include "src/registry/kernel_interface_registry.h"
28 #include "nnacl/op_base.h"
29
30 namespace mindspore {
31 namespace opt {
32 namespace {
// Channel count used to recognize channel-last image inputs (see RectifyFormat).
constexpr int kInputChannal = 3;
// Initial buffer size (bytes) for the FlatBufferBuilder used to convert primitives.
constexpr size_t INITIAL_SIZE = 1024;
FreeTensors(std::vector<lite::Tensor * > * tensors)35 void FreeTensors(std::vector<lite::Tensor *> *tensors) {
36 if (tensors == nullptr) {
37 return;
38 }
39 for (auto &v : *tensors) {
40 delete v;
41 v = nullptr;
42 }
43 tensors->resize(0);
44 }
45
RectifyFormat(const std::vector<lite::Tensor * > & inputs,FmkType fmk_type)46 void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
47 MS_ASSERT(cnode != nullptr);
48 if (fmk_type != converter::kFmkTypeOnnx) {
49 return;
50 }
51 for (auto &input : inputs) {
52 auto shape = input->shape();
53 if (shape.size() == kInputSizeFour && shape[kInputIndexThree] == kInputChannal && shape[1] == -1) {
54 input->set_format(mindspore::NHWC);
55 }
56 }
57 }
58
NewTensorInfo(lite::Tensor * tensor)59 tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
60 std::vector<int> shape(tensor->shape());
61 std::vector<int64_t> shape_vector(shape.begin(), shape.end());
62 auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
63 if (tensor_info == nullptr) {
64 MS_LOG(ERROR) << "new tensor::Tensor failed";
65 return nullptr;
66 }
67 return tensor_info;
68 }
69 } // namespace
70
JudgeOpSupportInfer(const CNodePtr & cnode)71 bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
72 MS_ASSERT(cnode != nullptr);
73 if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
74 return true;
75 }
76 auto prim_t = lite::GetPrimitiveT(cnode->input(0));
77 if (prim_t == nullptr) {
78 return false;
79 }
80 auto parameter_gen =
81 lite::PopulateRegistry::GetInstance()->GetParameterCreator(static_cast<int>(prim_t->value.type), lite::SCHEMA_CUR);
82 if (parameter_gen == nullptr) {
83 prim_t.reset();
84 return false;
85 }
86 return true;
87 }
88
InferShape(const CNodePtr & cnode)89 STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
90 MS_ASSERT(cnode != nullptr);
91 auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
92 if (anf_prim == nullptr) {
93 MS_LOG(DEBUG) << "primitive is nullptr";
94 return lite::RET_ERROR;
95 }
96 anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
97 std::vector<lite::Tensor *> inputs;
98 std::vector<lite::Tensor *> outputs;
99 if (GetCNodeInputTensors(cnode, &inputs) != lite::RET_OK) {
100 FreeTensors(&inputs);
101 MS_LOG(ERROR) << "get inputs failed.";
102 return lite::RET_ERROR;
103 }
104 if (GetCNodeOutputTensors(cnode, &outputs) != lite::RET_OK) {
105 FreeTensors(&inputs);
106 FreeTensors(&outputs);
107 MS_LOG(ERROR) << "get outputs failed.";
108 return lite::RET_ERROR;
109 }
110 auto prim_t = lite::GetPrimitiveT(cnode->input(0));
111 if (prim_t == nullptr) {
112 MS_LOG(DEBUG) << "prim_t is nullptr";
113 FreeTensors(&inputs);
114 FreeTensors(&outputs);
115 return lite::RET_ERROR;
116 }
117 flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
118 auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
119 if (prim == nullptr) {
120 MS_LOG(ERROR) << "get primitive failed.";
121 FreeTensors(&inputs);
122 FreeTensors(&outputs);
123 fbb.Clear();
124 return lite::RET_ERROR;
125 }
126 auto ret = KernelInferShape(inputs, outputs, prim, {}, lite::SCHEMA_CUR);
127 if (ret == lite::RET_NOT_SUPPORT) {
128 auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
129 static_cast<int>(prim->value_type()), lite::SCHEMA_CUR);
130 if (parameter_gen == nullptr) {
131 MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
132 FreeTensors(&inputs);
133 FreeTensors(&outputs);
134 fbb.Clear();
135 return lite::RET_ERROR;
136 }
137 auto parameter = parameter_gen(prim);
138 if (parameter == nullptr) {
139 MS_LOG(ERROR) << "parameter is nullptr.";
140 FreeTensors(&inputs);
141 FreeTensors(&outputs);
142 fbb.Clear();
143 return lite::RET_ERROR;
144 }
145 RectifyFormat(inputs, fmk_type_);
146 ret = KernelInferShape(inputs, outputs, parameter);
147 if (parameter->destroy_func_ != nullptr) {
148 parameter->destroy_func_(parameter);
149 }
150 free(parameter);
151 parameter = nullptr;
152 }
153 fbb.Clear();
154 if (ret == lite::RET_OK) {
155 anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
156 }
157 if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
158 auto set_status = SetCNodeAbstract(cnode, outputs, ret);
159 auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
160 MS_CHECK_TRUE_MSG(cnode_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
161 cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
162 if (set_status != lite::RET_OK) {
163 MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
164 FreeTensors(&inputs);
165 FreeTensors(&outputs);
166 return set_status;
167 }
168 } else {
169 MS_LOG(ERROR) << "infer shape failed.";
170 }
171 FreeTensors(&inputs);
172 FreeTensors(&outputs);
173 return ret;
174 }
175
GetInputShape(const CNodePtr & cnode,size_t index)176 std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {
177 MS_ASSERT(cnode != nullptr);
178 if (index >= cnode->size()) {
179 return {};
180 }
181 lite::DataInfo data_info;
182 int status = lite::RET_OK;
183 CNodePtr base_node = cnode;
184 size_t position = index;
185 if (CheckPrimitiveType(cnode->input(index), prim::kPrimMakeTuple) ||
186 CheckPrimitiveType(cnode->input(index), kPrimMakeTupleV2)) {
187 base_node = cnode->input(index)->cast<CNodePtr>();
188 position = 1;
189 }
190 if (utils::isa<CNode>(base_node->input(position))) {
191 status = lite::FetchDataFromCNode(base_node, position, fmk_type_, train_flag_, &data_info);
192 } else if (utils::isa<Parameter>(base_node->input(position))) {
193 status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info);
194 } else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
195 status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info);
196 } else {
197 MS_LOG(ERROR) << "input node is invalid.";
198 return {};
199 }
200 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
201 MS_LOG(ERROR) << "fetch data failed.";
202 return {};
203 }
204 return data_info.shape_;
205 }
206
GetIntVecInput(const CNodePtr & cnode,size_t index)207 std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) {
208 MS_ASSERT(cnode != nullptr);
209 if (index >= cnode->size()) {
210 return {};
211 }
212 auto origin_inputs = cnode->inputs();
213 std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
214 cnode->set_inputs(specify_inputs);
215 std::vector<lite::Tensor *> specify_tensors;
216 if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) {
217 cnode->set_inputs(origin_inputs);
218 return {};
219 }
220 cnode->set_inputs(origin_inputs);
221 std::vector<int> tensor_data;
222 if (specify_tensors.front()->data_type() != kNumberTypeInt32 &&
223 specify_tensors.front()->data_type() != kNumberTypeInt) {
224 FreeTensors(&specify_tensors);
225 return {};
226 }
227 if (specify_tensors.front()->shape().size() != 1) {
228 FreeTensors(&specify_tensors);
229 return {};
230 }
231 MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
232 tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
233 if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
234 tensor_data.size() * sizeof(int)) != EOK) {
235 FreeTensors(&specify_tensors);
236 return {};
237 }
238 return tensor_data;
239 }
240
GetCNodeInputTensors(const CNodePtr & cnode,std::vector<lite::Tensor * > * inputs)241 STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs) {
242 MS_ASSERT(cnode != nullptr);
243 MS_ASSERT(inputs != nullptr);
244 auto origin_inputs = cnode->inputs();
245 lite::RemoveIfDepend(cnode);
246 lite::RemoveIfMakeTuple(cnode);
247 RemoveIfMonad(cnode);
248 std::vector<lite::Tensor *> const_inputs;
249 if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) {
250 MS_LOG(ERROR) << "get const inputs failed.";
251 FreeTensors(&const_inputs);
252 cnode->set_inputs(origin_inputs);
253 return lite::RET_ERROR;
254 }
255 std::vector<lite::Tensor *> var_inputs;
256 if (GetCNodeVarInput(cnode, &var_inputs) != lite::RET_OK) {
257 MS_LOG(ERROR) << "get var inputs failed.";
258 FreeTensors(&var_inputs);
259 cnode->set_inputs(origin_inputs);
260 return lite::RET_ERROR;
261 }
262 size_t const_index = 0;
263 size_t var_index = 0;
264 bool input_valid = true;
265 for (size_t i = 1; i < cnode->size(); ++i) {
266 if (utils::isa<CNodePtr>(cnode->input(i))) {
267 if (var_index >= var_inputs.size()) {
268 MS_LOG(ERROR) << "var inputs size invalid.";
269 input_valid = false;
270 break;
271 }
272 inputs->emplace_back(var_inputs[var_index++]);
273 } else {
274 if (const_index >= const_inputs.size()) {
275 MS_LOG(ERROR) << "const inputs size invalid.";
276 input_valid = false;
277 break;
278 }
279 inputs->emplace_back(const_inputs[const_index++]);
280 }
281 }
282 cnode->set_inputs(origin_inputs);
283 if (!input_valid) {
284 FreeTensors(&const_inputs);
285 FreeTensors(&var_inputs);
286 inputs->resize(0);
287 }
288 return lite::RET_OK;
289 }
290
GetCNodeConstInput(const CNodePtr & cnode,std::vector<lite::Tensor * > * const_ms_inputs)291 STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) {
292 MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr);
293 std::vector<lite::DataInfo> data_infos;
294 for (size_t i = 1; i < cnode->size(); ++i) {
295 if (utils::isa<CNodePtr>(cnode->input(i))) {
296 continue;
297 }
298 STATUS status;
299 lite::DataInfo data_info;
300 if (utils::isa<ParameterPtr>(cnode->input(i))) {
301 status = lite::FetchDataFromParameterNode(cnode, i, fmk_type_, train_flag_, &data_info);
302 } else {
303 status = lite::FetchDataFromValueNode(cnode, i, fmk_type_, train_flag_, &data_info);
304 }
305 if (status == lite::RET_NO_CHANGE) {
306 continue;
307 }
308 if (status != lite::RET_OK) {
309 MS_LOG(ERROR) << "fetch const input data failed.";
310 return status;
311 }
312 data_infos.emplace_back(data_info);
313 }
314 return ConvertToLiteTensor(data_infos, const_ms_inputs);
315 }
316
GetCNodeVarInput(const CNodePtr & cnode,std::vector<lite::Tensor * > * var_ms_inputs)317 STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) {
318 MS_ASSERT(cnode != nullptr);
319 MS_ASSERT(var_ms_inputs != nullptr);
320 for (size_t i = 1; i < cnode->size(); ++i) {
321 if (!utils::isa<CNodePtr>(cnode->input(i))) {
322 continue;
323 }
324 lite::DataInfo data_info;
325 if (lite::FetchDataFromCNode(cnode, i, fmk_type_, train_flag_, &data_info) != lite::RET_OK) {
326 MS_LOG(ERROR) << "parse cnode failed.";
327 return lite::RET_ERROR;
328 }
329 lite::Tensor *tensor = nullptr;
330 if (data_info.data_type_ == kObjectTypeTensorType) {
331 tensor = GetCNodeTensorListVarInput(data_info);
332 } else {
333 tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
334 tensor->set_format((Format)(data_info.format_));
335 }
336 if (tensor == nullptr) {
337 MS_LOG(ERROR) << "new a lite tensor failed";
338 return lite::RET_ERROR;
339 }
340 auto input_cnode = cnode->input(i)->cast<CNodePtr>();
341 MS_ASSERT(input_cnode != nullptr);
342 PrimitivePtr input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
343 if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
344 auto item_input_cnode = input_cnode->input(1)->cast<CNodePtr>();
345 MS_ASSERT(item_input_cnode != nullptr);
346 input_prim = GetValueNode<PrimitivePtr>(item_input_cnode->input(0));
347 }
348 MS_ASSERT(input_prim != nullptr);
349 if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
350 tensor->set_shape({-1});
351 }
352 var_ms_inputs->emplace_back(tensor);
353 }
354 return lite::RET_OK;
355 }
356
// Builds a lite::TensorList from the fetched DataInfo, decoding the serialized
// payload when present. Returns nullptr on allocation or decode failure.
lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(const lite::DataInfo &data_info) {
  auto *tensor_list = new (std::nothrow) lite::TensorList(data_info.shape_, {});
  if (tensor_list == nullptr) {
    MS_LOG(ERROR) << "new a lite tensor list failed";
    return nullptr;
  }
  if (data_info.data_.empty()) {
    return tensor_list;  // nothing serialized: an empty tensor list is valid
  }
  if (tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != lite::RET_OK) {
    MS_LOG(ERROR) << "decode tensor list failed.";
    delete tensor_list;
    return nullptr;
  }
  return tensor_list;
}
374
GetCNodeOutputTensors(const CNodePtr & cnode,std::vector<lite::Tensor * > * outputs)375 STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) {
376 MS_ASSERT(cnode != nullptr);
377 MS_ASSERT(outputs != nullptr);
378 std::vector<lite::DataInfo> data_infos;
379 if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
380 auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
381 if (tuple == nullptr) {
382 MS_LOG(ERROR) << "tuple is nullptr";
383 return lite::RET_ERROR;
384 }
385 auto elements = tuple->elements();
386 for (size_t i = 0; i < elements.size(); i++) {
387 lite::DataInfo data_info;
388 data_info.node_type_ = lite::NodeType_CNode;
389 if (train_flag_) {
390 data_infos.emplace_back(data_info);
391 if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || CheckPrimitiveType(cnode, prim::kPrimAdam)) {
392 break;
393 }
394 } else {
395 if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
396 MS_LOG(ERROR) << "abstract is not AbstractTensor";
397 return lite::RET_ERROR;
398 }
399 auto type = kNumberTypeFloat32;
400 if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
401 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
402 auto typePtr = abstract_tensor->element()->GetTypeTrack();
403 type = typePtr->type_id();
404 }
405 data_info.data_type_ = type;
406 data_infos.emplace_back(data_info);
407 if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
408 CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) {
409 break;
410 }
411 }
412 }
413 } else {
414 lite::DataInfo data_info;
415 auto type = kNumberTypeFloat32;
416 if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
417 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
418 auto typePtr = abstract_tensor->element()->GetTypeTrack();
419 type = typePtr->type_id();
420 }
421 data_info.data_type_ = type;
422 data_info.node_type_ = lite::NodeType_CNode;
423 data_infos.emplace_back(data_info);
424 }
425 return ConvertToLiteTensor(data_infos, outputs);
426 }
427
ConvertToLiteTensor(const std::vector<lite::DataInfo> & data_infos,std::vector<lite::Tensor * > * tensors)428 STATUS NodeInferShape::ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos,
429 std::vector<lite::Tensor *> *tensors) {
430 MS_ASSERT(tensors != nullptr);
431 for (auto &data_info : data_infos) {
432 auto tensor_category = lite::TensorCategory(lite::NodeType(data_info.node_type_), data_info.shape_.size(),
433 TypeId(data_info.data_type_), data_info.data_.size());
434 lite::Tensor *tensor = nullptr;
435 if (data_info.data_type_ != kObjectTypeTensorType) {
436 tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_,
437 (mindspore::Format)data_info.format_, tensor_category);
438 } else {
439 tensor = new (std::nothrow) lite::TensorList(data_info.shape_, std::vector<int>(), tensor_category);
440 }
441 if (tensor == nullptr) {
442 MS_LOG(ERROR) << "new a lite tensor failed";
443 return lite::RET_ERROR;
444 }
445 auto tensor_size = data_info.data_.size();
446 if (tensor_size > 0) {
447 if (data_info.data_type_ == kObjectTypeTensorType) {
448 auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
449 if (tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != RET_OK) {
450 MS_LOG(ERROR) << "Decode tensorlist data failed";
451 return RET_ERROR;
452 }
453 } else {
454 auto tensor_data = reinterpret_cast<char *>(malloc(tensor_size));
455 if (tensor_data == nullptr) {
456 MS_LOG(ERROR) << "tensor_data is nullptr";
457 delete tensor;
458 return lite::RET_ERROR;
459 }
460 if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) {
461 delete tensor;
462 free(tensor_data);
463 tensor_data = nullptr;
464 MS_LOG(ERROR) << "memcpy error: ";
465 return lite::RET_ERROR;
466 }
467 tensor->set_data(tensor_data);
468 }
469 }
470 tensors->emplace_back(tensor);
471 }
472 return lite::RET_OK;
473 }
474
SetCNodeAbstract(const std::shared_ptr<CNode> & cnode,const std::vector<lite::Tensor * > & outputs,int status)475 STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs,
476 int status) {
477 MS_ASSERT(cnode != nullptr);
478 if (outputs.size() == 0) {
479 MS_LOG(ERROR) << "empty output_tensors";
480 return RET_ERROR;
481 }
482 auto origin_abstract = cnode->abstract();
483 MS_ASSERT(origin_abstract != nullptr);
484 if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) {
485 auto tensor = outputs.front();
486 auto new_abstract = ConvertLiteTensorToAbstract(tensor);
487 if (new_abstract == nullptr) {
488 MS_LOG(ERROR) << "new abstract failed.";
489 return RET_ERROR;
490 }
491 if (status == lite::RET_INFER_INVALID) {
492 ShapeVector shape;
493 if (tensor->data_type() == kObjectTypeTensorType) {
494 shape = {0};
495 }
496 auto abstract_shape = std::make_shared<abstract::Shape>(shape);
497 CHECK_NULL_RETURN(abstract_shape);
498 new_abstract->set_shape(abstract_shape);
499 }
500 cnode->set_abstract(new_abstract);
501 } else {
502 AbstractBasePtrList abstract_list;
503 for (size_t i = 0; i < outputs.size(); i++) {
504 auto tensor = outputs.at(i);
505 auto new_abstract = ConvertLiteTensorToAbstract(tensor);
506 if (new_abstract == nullptr) {
507 MS_LOG(ERROR) << "new abstract failed.";
508 return RET_ERROR;
509 }
510 if (status == lite::RET_INFER_INVALID) {
511 ShapeVector shape;
512 if (tensor->data_type() == kObjectTypeTensorType) {
513 shape = {0};
514 }
515 auto abstract_shape = std::make_shared<abstract::Shape>(shape);
516 CHECK_NULL_RETURN(abstract_shape);
517 new_abstract->set_shape(abstract_shape);
518 }
519 abstract_list.emplace_back(new_abstract);
520 }
521 auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
522 CHECK_NULL_RETURN(new_abstract_list);
523 cnode->set_abstract(new_abstract_list);
524 }
525 return RET_OK;
526 }
527
ConvertLiteTensorToAbstract(lite::Tensor * tensor)528 abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
529 MS_ASSERT(tensor != nullptr);
530 if (tensor->data_type() == kObjectTypeTensorType) {
531 return ConvertTensorListToAbstract(tensor);
532 }
533 auto tensor_info = NewTensorInfo(tensor);
534 if (tensor_info == nullptr) {
535 MS_LOG(ERROR) << "new tensor::Tensor failed";
536 return nullptr;
537 }
538 return tensor_info->ToAbstract();
539 }
540
// The abstract stores the tensorlist's type and shape, while tensor_info stores the
// tensorlist's serialized data and data type; their shapes and types therefore differ.
ConvertTensorListToAbstract(lite::Tensor * tensor)543 abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) {
544 MS_ASSERT(tensor != nullptr);
545 auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
546 if (tensor_list == nullptr) {
547 MS_LOG(ERROR) << "cast tensor_list failed";
548 return nullptr;
549 }
550 std::vector<int> shape(tensor->shape());
551 std::vector<int64_t> shape_vector(shape.begin(), shape.end());
552 auto tensor_list_abstract =
553 std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector);
554 if (tensor_list_abstract == nullptr) {
555 MS_LOG(ERROR) << "new AbstractTensor failed";
556 return nullptr;
557 }
558 auto elememt_shape = tensor_list->element_shape();
559 std::vector<int> data_info;
560 data_info.push_back(tensor_list->tensors_data_type());
561 data_info.push_back(elememt_shape.size());
562 std::copy(elememt_shape.begin(), elememt_shape.end(), std::back_inserter(data_info));
563 data_info.push_back(tensor_list->tensors().size());
564 for (size_t i = 0; i < tensor_list->tensors().size(); ++i) {
565 auto tensor_mem = tensor_list->tensors()[i];
566 auto tensor_mem_shape = tensor_mem->shape();
567 data_info.push_back(tensor_mem_shape.size());
568 std::copy(tensor_mem_shape.begin(), tensor_mem_shape.end(), std::back_inserter(data_info));
569 }
570 std::vector<int64_t> data_shape;
571 data_shape.push_back(data_info.size());
572 auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32);
573 if (tensor_info == nullptr) {
574 MS_LOG(ERROR) << "new tensor::Tensor failed";
575 return nullptr;
576 }
577 tensor_list_abstract->set_value(tensor_info);
578 return tensor_list_abstract;
579 }
580 } // namespace opt
581 } // namespace mindspore
582