/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/lite_exporter/fetch_content.h"
19 #include <algorithm>
20 #include <map>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "mindapi/base/format.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "nnacl/op_base.h"
29 #include "ops/op_utils.h"
30 #include "src/common/ops/anf_utils.h"
31 #include "src/common/ops/populate/populate_register.h"
32 #include "src/common/primitive_t_utils.h"
33 #include "tools/common/node_util.h"
34 #include "tools/converter/quantizer/quant_param_holder.h"
35 #include "tools/optimizer/common/format_utils.h"
36 #include "tools/optimizer/common/gllo_utils.h"
37 #include "tools/optimizer/graph/specify_graph_input_format.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/ms_utils_secure.h"
40
41 namespace mindspore {
42 namespace lite {
43 namespace {
44 constexpr int kNumWeightIndex = 2;
45 constexpr int kNumTransposePermSize = 4;
46 constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
47 static const std::unordered_map<int, int> TypeToTypeMap = {
48 {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
GetShapeVectorFromStringTensor(const tensor::TensorPtr & tensor_info,ShapeVector * shape_vector,size_t * offset)49 STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
50 MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
51 auto data_type = tensor_info->data_type();
52 if (data_type != kObjectTypeString) {
53 MS_LOG(ERROR) << "This function only used for string tensor.";
54 return RET_ERROR;
55 }
56 shape_vector->clear();
57 MS_CHECK_TRUE_MSG(tensor_info->data_c() != nullptr, RET_ERROR, "tensor_info->data_c() is nullptr");
58 auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
59 std::string shape_str;
60 std::string shape_size_str;
61 *offset = 0;
62 size_t cnt = 0;
63 for (; *offset < tensor_info->Size(); (*offset)++) {
64 if (tensor_data[*offset] == ',') {
65 (*offset)++;
66 break;
67 }
68 shape_size_str.push_back(tensor_data[*offset]);
69 }
70 if (*offset == 0) {
71 MS_LOG(ERROR) << "string tensor's dim size not found.";
72 return RET_ERROR;
73 }
74 constexpr int kBase = 10;
75 size_t shape_size = static_cast<size_t>(std::strtol(shape_size_str.c_str(), nullptr, kBase));
76 MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
77 for (; *offset < tensor_info->Size(); (*offset)++) {
78 if (tensor_data[*offset] == ',') {
79 cnt++;
80 int64_t shape = 0;
81 try {
82 shape = std::stoi(shape_str);
83 } catch (const std::exception &e) {
84 MS_LOG(ERROR) << "Get shape failed: " << e.what();
85 return RET_ERROR;
86 }
87 shape_vector->push_back(shape);
88 shape_str.clear();
89 } else {
90 shape_str.push_back(tensor_data[*offset]);
91 }
92 if (cnt == shape_size) {
93 (*offset)++;
94 break;
95 }
96 }
97 if (shape_vector->empty()) {
98 MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
99 return RET_ERROR;
100 }
101 return RET_OK;
102 }
103
GetDataTypeAndShape(const ParameterPtr & param_node,TypeId * data_type,ShapeVector * shape_vector)104 STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector) {
105 MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
106 auto abstract_base = param_node->abstract();
107 if (abstract_base == nullptr) {
108 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
109 return RET_PARAM_INVALID;
110 }
111 if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
112 MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
113 return RET_INPUT_TENSOR_ERROR;
114 }
115 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
116 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
117 auto typePtr = abstract_tensor->element()->GetTypeTrack();
118 MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
119 *data_type = typePtr->type_id();
120 if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
121 MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
122 return RET_PARAM_INVALID;
123 }
124 *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
125 return RET_OK;
126 }
127
FetchFromTensorValue(const ValueNodePtr & value_node,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info,bool copy_data)128 int FetchFromTensorValue(const ValueNodePtr &value_node, converter::FmkType fmk_type, bool train_flag,
129 DataInfo *data_info, bool copy_data) {
130 MS_ASSERT(value_node != nullptr && data_info != nullptr);
131 auto valueAbstract = value_node->abstract();
132 MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
133 auto abstract_tensor = valueAbstract->cast<abstract::AbstractTensorPtr>();
134 if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
135 MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
136 return RET_ERROR;
137 }
138 auto typePtr = abstract_tensor->element()->GetTypeTrack();
139 MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
140 data_info->data_type_ = typePtr->type_id();
141 auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
142 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
143 data_info->shape_ = dims;
144 if (train_flag && dims.empty()) {
145 data_info->shape_ = {1};
146 }
147 auto value = value_node->value();
148 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
149 auto data = value->cast<tensor::TensorPtr>();
150 MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
151 if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
152 MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
153 return RET_ERROR;
154 }
155
156 // process weight tensor
157 if (copy_data) {
158 data_info->data_.resize(data->Size());
159 if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
160 MS_LOG(ERROR) << "memcpy_s error.";
161 return RET_ERROR;
162 }
163 } else {
164 data_info->data_ptr_ = data->data_c();
165 }
166 return RET_OK;
167 }
168
169 template <typename DstImm, typename SrcImm>
FetchCastImmValue(const ValueNodePtr & value_node,DataInfo * data_info)170 int FetchCastImmValue(const ValueNodePtr &value_node, DataInfo *data_info) {
171 MS_ASSERT(value_node != nullptr && data_info != nullptr);
172 DstImm dst_imm;
173 data_info->data_type_ = dst_imm.type()->number_type();
174 data_info->shape_ = {1};
175 data_info->data_.resize(sizeof(dst_imm.value()));
176 auto value = value_node->value();
177 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
178 auto data = value->cast<std::shared_ptr<SrcImm>>();
179 MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
180 auto data_value = data->value();
181 decltype(dst_imm.value()) dst_data = static_cast<decltype(dst_imm.value())>(data_value);
182 if (memcpy_s(data_info->data_.data(), sizeof(dst_imm.value()), &dst_data, sizeof(dst_imm.value())) != EOK) {
183 MS_LOG(ERROR) << "memcpy_s failed";
184 return RET_MEMORY_FAILED;
185 }
186 return RET_OK;
187 }
188
189 template <typename ImmType>
FetchImmValue(const ValueNodePtr & value_node,DataInfo * data_info)190 int FetchImmValue(const ValueNodePtr &value_node, DataInfo *data_info) {
191 MS_ASSERT(value_node != nullptr && data_info != nullptr);
192 auto data = value_node->value()->cast<std::shared_ptr<ImmType>>();
193 MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberImm failed");
194 auto data_value = data->value();
195 data_info->data_type_ = data->type()->number_type();
196 data_info->shape_ = {1};
197 data_info->data_.resize(sizeof(data_value));
198 MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberImm failed");
199 if (memcpy_s(data_info->data_.data(), sizeof(data_value), &data_value, sizeof(data_value)) != EOK) {
200 MS_LOG(ERROR) << "memcpy_s failed";
201 return RET_MEMORY_FAILED;
202 }
203 return RET_OK;
204 }
205
FetchFromNumberValue(const ValueNodePtr & value_node,DataInfo * data_info)206 int FetchFromNumberValue(const ValueNodePtr &value_node, DataInfo *data_info) {
207 MS_ASSERT(value_node != nullptr && data_info != nullptr);
208 data_info->data_type_ = kNumberTypeInt32;
209 data_info->shape_ = {1};
210 data_info->data_.resize(sizeof(int));
211 auto data = value_node->value()->cast<NumberPtr>();
212 MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
213 int number_type = data->number_type();
214 if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
215 number_type = TypeToTypeMap.at(number_type);
216 }
217 if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
218 MS_LOG(ERROR) << "memcpy_s failed";
219 return RET_MEMORY_FAILED;
220 }
221 return RET_OK;
222 }
223
FetchFromSequenceValue(const ValueNodePtr & value_node,DataInfo * data_info)224 int FetchFromSequenceValue(const ValueNodePtr &value_node, DataInfo *data_info) {
225 MS_ASSERT(value_node != nullptr && data_info != nullptr);
226 auto value = value_node->value();
227 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
228 std::vector<int32_t> shape;
229 auto value_seq = value->cast<ValueSequencePtr>();
230 MS_CHECK_TRUE_MSG(value_seq != nullptr, RET_ERROR, "value_seq is nullptr");
231 if (!value_seq->value().empty()) {
232 if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
233 value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
234 shape = GetValue<std::vector<int>>(value);
235 } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
236 auto origin_value = GetValue<std::vector<int64_t>>(value);
237 std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
238 [](int64_t val) { return static_cast<int32_t>(val); });
239 } else {
240 MS_LOG(ERROR) << "Value type is ValueSequence is not integer.";
241 return RET_ERROR;
242 }
243 }
244 data_info->data_type_ = kNumberTypeInt32;
245 data_info->shape_ = {static_cast<int32_t>(shape.size())};
246 data_info->data_.resize(shape.size() * sizeof(int));
247 if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
248 shape.size() * sizeof(int32_t)) != EOK) {
249 MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
250 return RET_ERROR;
251 }
252 return RET_OK;
253 }
254
SetTensorData(const tensor::TensorPtr & tensor_info,DataInfo * data_info,TypeId data_type,size_t offset,bool copy_data)255 int SetTensorData(const tensor::TensorPtr &tensor_info, DataInfo *data_info, TypeId data_type, size_t offset,
256 bool copy_data) {
257 if (data_type == kObjectTypeTensorType && tensor_info->Size() >= kTensorListMinSize) {
258 data_info->data_.resize(tensor_info->Size() - offset);
259 if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(),
260 static_cast<uint8_t *>(tensor_info->data_c()) + offset,
261 tensor_info->Size() - offset)) {
262 MS_LOG(ERROR) << "memcpy_s failed.";
263 return RET_ERROR;
264 }
265 }
266 // common node with const data
267 if (data_type != kObjectTypeTensorType) {
268 if (copy_data) {
269 data_info->data_.resize(tensor_info->Size() - offset);
270 if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(),
271 static_cast<uint8_t *>(tensor_info->data_c()) + offset,
272 tensor_info->Size() - offset)) {
273 MS_LOG(ERROR) << "memcpy_s failed.";
274 return RET_ERROR;
275 }
276 } else {
277 data_info->data_ptr_ = static_cast<uint8_t *>(tensor_info->data_c()) + offset;
278 }
279 }
280 return RET_OK;
281 }
282 } // namespace
283
FetchFromDefaultParam(const ParameterPtr & param_node,const converter::FmkType & fmk_type,DataInfo * data_info,bool copy_data)284 int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkType &fmk_type, DataInfo *data_info,
285 bool copy_data) {
286 MS_ASSERT(param_node != nullptr && data_info != nullptr);
287 ShapeVector shape_vector;
288 TypeId data_type = kTypeUnknown;
289 auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
290 if (status != RET_OK) {
291 MS_LOG(ERROR) << "get data type and shape from param node failed.";
292 return RET_ERROR;
293 }
294 data_info->data_type_ = data_type;
295 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
296 size_t offset = 0;
297 if (tensor_info != nullptr && !shape_vector.empty() && data_type == kObjectTypeString) {
298 status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
299 if (status != RET_OK) {
300 MS_LOG(ERROR) << "get shape vector from string tensor failed.";
301 return RET_ERROR;
302 }
303 }
304 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
305 data_info->shape_ = dims;
306 if (tensor_info != nullptr && tensor_info->Size() != 0) {
307 // tensor_list tensor
308 status = SetTensorData(tensor_info, data_info, data_type, offset, copy_data);
309 if (status != RET_OK) {
310 MS_LOG(ERROR) << "set tensor data failed.";
311 return RET_ERROR;
312 }
313 }
314 if (tensor_info != nullptr) {
315 data_info->compress_type_ = tensor_info->compression_type();
316 data_info->quant_params_ = tensor_info->quant_params();
317 }
318
319 // the const tensor format from onnx/caffe should be nchw in general
320 auto const_format = (fmk_type == converter::kFmkTypeMsLite || fmk_type == converter::kFmkTypeTf ||
321 fmk_type == converter::kFmkTypeTflite)
322 ? NHWC
323 : NCHW;
324 data_info->format_ = param_node->has_default() ? const_format : NHWC;
325 return RET_OK;
326 }
327
FetchDataFromParameterNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,DataInfo * data_info,bool copy_data)328 int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, DataInfo *data_info,
329 bool copy_data) {
330 MS_ASSERT(cnode != nullptr && data_info != nullptr);
331 auto param_node = cnode->input(index)->cast<ParameterPtr>();
332 MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
333 if (FetchFromDefaultParam(param_node, fmk_type, data_info, copy_data) != RET_OK) {
334 MS_LOG(ERROR) << "fetch information from default param failed.";
335 return RET_ERROR;
336 }
337 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
338 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
339 if (prim->GetAttr(mindspore::ops::kFormat) == nullptr && !param_node->has_default()) {
340 auto func_graph = cnode->func_graph();
341 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "The func graph is nullptr");
342 auto input_format = func_graph->get_attr(kInputFormat);
343 data_info->format_ = input_format != nullptr ? GetValue<int>(input_format) : static_cast<int>(Format::NHWC);
344 }
345 if (prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
346 auto value = prim->GetAttr(mindspore::ops::kFormat);
347 if (value->isa<mindspore::Int64Imm>()) {
348 data_info->format_ = GetValue<int64_t>(value);
349 }
350 }
351 QuantParamHolderPtr quant_param_holder =
352 prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
353 if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
354 data_info->data_type_ == kNumberTypeInt8) {
355 data_info->enable_huffman_code_ = true;
356 }
357 data_info->node_type_ = NodeType_ValueNode;
358 return RET_OK;
359 }
360
FetchDataFromValueNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info,bool copy_data)361 int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
362 DataInfo *data_info, bool copy_data) {
363 MS_ASSERT(cnode != nullptr && data_info != nullptr);
364 auto value_node = cnode->input(index)->cast<ValueNodePtr>();
365 MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
366
367 auto value = value_node->value();
368 int ret = RET_OK;
369 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
370 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
371 if (value->isa<tensor::Tensor>()) {
372 ret = FetchFromTensorValue(value_node, fmk_type, train_flag, data_info, copy_data);
373 if (index == kNumWeightIndex && prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
374 data_info->format_ = GetValue<int64_t>(prim->GetAttr(mindspore::ops::kFormat));
375 }
376 } else if (value->isa<mindspore::Int64Imm>()) {
377 ret = FetchCastImmValue<mindspore::Int32Imm, mindspore::Int64Imm>(value_node, data_info);
378 } else if (value->isa<mindspore::Int32Imm>()) {
379 ret = FetchImmValue<mindspore::Int32Imm>(value_node, data_info);
380 } else if (value->isa<mindspore::BoolImm>()) {
381 ret = FetchImmValue<mindspore::BoolImm>(value_node, data_info);
382 } else if (value->isa<mindspore::FP32Imm>()) {
383 ret = FetchImmValue<mindspore::FP32Imm>(value_node, data_info);
384 } else if (value->isa<mindspore::ValueSequence>()) {
385 ret = FetchFromSequenceValue(value_node, data_info);
386 } else if (value->isa<Number>()) {
387 ret = FetchFromNumberValue(value_node, data_info);
388 } else if (value->isa<FuncGraph>()) {
389 MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
390 return RET_NO_CHANGE;
391 } else if (value->isa<Monad>()) {
392 MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
393 return RET_NO_CHANGE;
394 } else {
395 MS_LOG(ERROR) << "Not support value type , need add support.";
396 return RET_ERROR;
397 }
398 data_info->node_type_ = NodeType_ValueNode;
399 return ret;
400 }
401
FetchDataFromCNode(const CNodePtr & cnode,size_t index,DataInfo * data_info)402 int FetchDataFromCNode(const CNodePtr &cnode, size_t index, DataInfo *data_info) {
403 MS_ASSERT(cnode != nullptr && data_info != nullptr);
404 auto abstract = opt::GetCNodeInputAbstract(cnode, index);
405 if (abstract == nullptr) {
406 MS_LOG(ERROR) << "Abstract cnode is nullptr.";
407 return RET_ERROR;
408 }
409 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
410 MS_LOG(ERROR) << "Abstract should be anstract tensor.";
411 return RET_ERROR;
412 }
413 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
414 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
415 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
416 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
417 if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
418 MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
419 return RET_ERROR;
420 }
421 auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
422 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
423 Format format{mindspore::NHWC};
424 auto ret = opt::DetermineCertainVarInputFormat(cnode, index, &format);
425 if (ret != RET_OK) {
426 MS_LOG(ERROR) << "set format for cnode failed";
427 return RET_ERROR;
428 }
429 data_info->format_ = format;
430 data_info->data_type_ = type_ptr->type_id();
431 data_info->shape_ = dims;
432 data_info->node_type_ = NodeType_CNode;
433 if (type_ptr->type_id() == kObjectTypeTensorType) {
434 auto tensor_info = abstract_tensor->GetValueTrack();
435 if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
436 MS_LOG(ERROR) << "tensor info is invalid.";
437 return RET_ERROR;
438 }
439 auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
440 MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
441 if (tensor_value->Size() >= kTensorListMinSize) {
442 data_info->data_.resize(tensor_value->Size());
443 if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
444 EOK) {
445 MS_LOG(ERROR) << "memcpy data failed.";
446 return RET_ERROR;
447 }
448 }
449 }
450 return RET_OK;
451 }
452
FetchConstData(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,DataInfo * data_info,bool copy_data)453 int FetchConstData(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, DataInfo *data_info,
454 bool copy_data) {
455 auto node_name = cnode->fullname_with_scope();
456 if (index > cnode->size()) {
457 MS_LOG(ERROR) << node_name << index << " > " << cnode->size();
458 return RET_ERROR;
459 }
460 int status;
461 auto input = cnode->input(index);
462 if (input->isa<Parameter>()) {
463 status = FetchDataFromParameterNode(cnode, index, fmk_type, data_info, copy_data);
464 } else if (input->isa<ValueNode>()) {
465 status = FetchDataFromValueNode(cnode, index, fmk_type, false, data_info, copy_data);
466 } else {
467 MS_LOG(ERROR) << node_name << " index " << index << " is not Parameter or ValueNode";
468 return RET_ERROR;
469 }
470 if (status != RET_OK) {
471 MS_LOG(ERROR) << node_name << " fetch data failed";
472 return status;
473 }
474 return RET_OK;
475 }
476
FetchDataFromAbstract(const AbstractBasePtr & abstract,DataInfo * data_info)477 int FetchDataFromAbstract(const AbstractBasePtr &abstract, DataInfo *data_info) {
478 MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
479 if (!utils::isa<abstract::AbstractTensor>(abstract)) {
480 MS_LOG(ERROR) << "Abstract should be AbstractTensor.";
481 return RET_ERROR;
482 }
483 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
484 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
485 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
486 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
487 if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
488 MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
489 return RET_ERROR;
490 }
491 auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
492 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
493 data_info->data_type_ = static_cast<int>(type_ptr->type_id());
494 data_info->shape_ = dims;
495 data_info->node_type_ = static_cast<int>(NodeType_CNode);
496 if (type_ptr->type_id() == kObjectTypeTensorType) {
497 auto tensor_info = abstract_tensor->GetValueTrack();
498 if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
499 MS_LOG(ERROR) << "tensor info is invalid.";
500 return RET_ERROR;
501 }
502 auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
503 MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
504 if (tensor_value->Size() >= kTensorListMinSize) {
505 data_info->data_.resize(tensor_value->Size());
506 if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
507 EOK) {
508 MS_LOG(ERROR) << "memcpy data failed.";
509 return RET_ERROR;
510 }
511 }
512 }
513 return RET_OK;
514 }
515
RemoveIfDepend(const CNodePtr & cnode)516 int RemoveIfDepend(const CNodePtr &cnode) {
517 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
518 bool has_depend = false;
519 std::vector<AnfNodePtr> inputs;
520 inputs.clear();
521
522 inputs.emplace_back(cnode->input(0));
523 for (size_t i = 1; i < cnode->size(); ++i) {
524 AnfNodePtr input_node = cnode->input(i);
525 MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "inputNode is nullptr");
526 if (!input_node->isa<CNode>()) {
527 inputs.emplace_back(cnode->input(i));
528 continue;
529 }
530 if (opt::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
531 auto depend_node = utils::cast<CNodePtr>(input_node);
532 MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
533 has_depend = true;
534 bool mask_out = (depend_node->size() == opt::kInputSizeThree);
535 for (size_t j = 1; j < depend_node->size(); ++j) {
536 AnfNodePtr depend_input_node = depend_node->input(j);
537 MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
538 inputs.emplace_back(depend_input_node);
539 if (mask_out) {
540 break;
541 }
542 }
543 } else {
544 inputs.emplace_back(cnode->input(i));
545 }
546 }
547 if (has_depend) {
548 cnode->set_inputs(inputs);
549 }
550 return RET_OK;
551 }
552
GetFlattenInputsIfMakeTuple(const CNodePtr & cnode,std::vector<AnfNodePtr> * inputs,bool * has_make_tuple)553 int GetFlattenInputsIfMakeTuple(const CNodePtr &cnode, std::vector<AnfNodePtr> *inputs, bool *has_make_tuple) {
554 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "Cnode is nullptr.");
555 MS_CHECK_TRUE_MSG(inputs != nullptr, RET_NULL_PTR, "Inputs is nullptr.");
556 MS_CHECK_TRUE_MSG(has_make_tuple != nullptr, RET_NULL_PTR, "Has make tuple is nullptr.");
557 for (size_t i = 1; i < cnode->size(); ++i) {
558 AnfNodePtr input_node = cnode->input(i);
559 MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "Input_node is nullptr");
560 auto input_cnode = utils::cast<CNodePtr>(input_node);
561 if (input_cnode && (opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple) ||
562 opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTupleV2))) {
563 *has_make_tuple = true;
564 GetFlattenInputsIfMakeTuple(input_cnode, inputs, has_make_tuple);
565 } else {
566 inputs->emplace_back(input_node);
567 }
568 }
569 return RET_OK;
570 }
571
RemoveIfMakeTuple(const CNodePtr & cnode)572 int RemoveIfMakeTuple(const CNodePtr &cnode) {
573 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
574 bool has_make_tuple = false;
575 std::vector<AnfNodePtr> inputs;
576 inputs.clear();
577
578 inputs.emplace_back(cnode->input(0));
579 if (GetFlattenInputsIfMakeTuple(cnode, &inputs, &has_make_tuple) != RET_OK) {
580 MS_LOG(ERROR) << "Trace real input of make tuple failed, name: " << cnode->fullname_with_scope();
581 return RET_ERROR;
582 }
583 if (has_make_tuple) {
584 cnode->set_inputs(inputs);
585 }
586 return RET_OK;
587 }
588
FetchOpParameterFromNode(const AnfNodePtr & node,OpParameter ** op_parameter)589 int FetchOpParameterFromNode(const AnfNodePtr &node, OpParameter **op_parameter) {
590 if (op_parameter == nullptr) {
591 MS_LOG(ERROR) << "op_parameter is nullptr.";
592 return RET_NULL_PTR;
593 }
594 CHECK_NULL_RETURN(GetValueNode<PrimitivePtr>(node));
595 auto prim_t = lite::GetPrimitiveT(node);
596 CHECK_NULL_RETURN(prim_t);
597 size_t init_size = 1024;
598 flatbuffers::FlatBufferBuilder fbb(init_size);
599 auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
600 if (prim == nullptr) {
601 fbb.Clear();
602 MS_LOG(ERROR) << "get primitive failed.";
603 return RET_ERROR;
604 }
605 auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR);
606 if (parameter_gen == nullptr) {
607 fbb.Clear();
608 MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
609 return RET_ERROR;
610 }
611 *op_parameter = parameter_gen(prim);
612 fbb.Clear();
613 if (*op_parameter == nullptr) {
614 MS_LOG(ERROR) << "parameter is nullptr.";
615 return RET_ERROR;
616 }
617 return RET_OK;
618 }
619
FetchOpParameterFromFuncGraph(const FuncGraphPtr & func_graph,std::map<std::string,OpParameter * > * op_parameters)620 int FetchOpParameterFromFuncGraph(const FuncGraphPtr &func_graph, std::map<std::string, OpParameter *> *op_parameters) {
621 MS_CHECK_TRUE_MSG(op_parameters != nullptr, RET_NULL_PTR, "op_parameters is nullptr.");
622 auto cnodes = func_graph->GetOrderedCnodes();
623 for (auto &cnode : cnodes) {
624 if (opt::IsSpecialType(cnode)) {
625 continue;
626 }
627 auto primitive = cnode->input(0);
628 OpParameter *parameter = nullptr;
629 auto ret = lite::FetchOpParameterFromNode(primitive, ¶meter);
630 if (ret != lite::RET_OK) {
631 MS_LOG(ERROR) << cnode->fullname_with_scope() << " FetchOpParameterFromNode failed. ";
632 return ret;
633 }
634 CHECK_NULL_RETURN(parameter);
635 parameter->thread_num_ = 1;
636 op_parameters->emplace(std::pair<std::string, OpParameter *>(cnode->fullname_with_scope(), parameter));
637 }
638 return RET_OK;
639 }
640 } // namespace lite
641 } // namespace mindspore
642