/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/lite_exporter/fetch_content.h"
#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindapi/base/format.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
#include "src/common/ops/anf_utils.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_utils_secure.h"

namespace mindspore {
namespace lite {
namespace {
constexpr int kNumWeightIndex = 2;
constexpr int kNumTransposePermSize = 4;
constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
static const std::unordered_map<int, int> TypeToTypeMap = {
  {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
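// Parse the shape header of a string tensor. The raw bytes are expected to begin with
// an ASCII header of the form "<dim_count>,<dim_0>,<dim_1>,...,". The parsed dims are
// written to shape_vector and *offset is left at the first byte after the header.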
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector,
                                      size_t *offset) {
  MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
  auto data_type = tensor_info->data_type();
  if (data_type != kObjectTypeString) {
    MS_LOG(ERROR) << "This function is only used for string tensors.";
    return RET_ERROR;
  }
  shape_vector->clear();
  MS_CHECK_TRUE_MSG(tensor_info->data_c() != nullptr, RET_ERROR, "tensor_info->data_c() is nullptr");
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  std::string shape_str;
  std::string shape_size_str;
  *offset = 0;
  size_t cnt = 0;
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      (*offset)++;
      break;
    }
    shape_size_str.push_back(tensor_data[*offset]);
  }
  if (*offset == 0) {
    MS_LOG(ERROR) << "string tensor's dim size not found.";
    return RET_ERROR;
  }
  constexpr int kBase = 10;
  size_t shape_size = static_cast<size_t>(std::strtol(shape_size_str.c_str(), nullptr, kBase));
  MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      cnt++;
      int64_t shape = 0;
      try {
        shape = std::stoi(shape_str);
      } catch (const std::exception &e) {
        MS_LOG(ERROR) << "Get shape failed: " << e.what();
        return RET_ERROR;
      }
      shape_vector->push_back(shape);
      shape_str.clear();
    } else {
      shape_str.push_back(tensor_data[*offset]);
    }
    if (cnt == shape_size) {
      (*offset)++;
      break;
    }
  }
  if (shape_vector->empty()) {
    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
    return RET_ERROR;
  }
  return RET_OK;
}

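// Read the data type and shape recorded on a Parameter's abstract, without touching
// its default tensor data.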
STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
  MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
    return RET_INPUT_TENSOR_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
  *data_type = typePtr->type_id();
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  return RET_OK;
}

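// Fill data_info from a ValueNode that holds a tensor::Tensor: data type, shape, a
// format sanity check, and either a copy of the raw bytes or a pointer into the
// original buffer depending on copy_data.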
int FetchFromTensorValue(const ValueNodePtr &value_node, converter::FmkType fmk_type, bool train_flag,
                         DataInfo *data_info, bool copy_data) {
  MS_ASSERT(value_node != nullptr && data_info != nullptr);
  auto valueAbstract = value_node->abstract();
  MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
  auto abstract_tensor = valueAbstract->cast<abstract::AbstractTensorPtr>();
  if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
    MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
    return RET_ERROR;
  }
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
  data_info->data_type_ = typePtr->type_id();
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->shape_ = dims;
  if (train_flag && dims.empty()) {
    data_info->shape_ = {1};
  }
  auto value = value_node->value();
  MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
  auto data = value->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
  if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
    MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
    return RET_ERROR;
  }

  // process weight tensor
  if (copy_data) {
    data_info->data_.resize(data->Size());
    if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
      MS_LOG(ERROR) << "memcpy_s error.";
      return RET_ERROR;
    }
  } else {
    data_info->data_ptr_ = data->data_c();
  }
  return RET_OK;
}

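// Convert a scalar immediate of type SrcImm into DstImm (used below as
// Int64Imm -> Int32Imm) and store the narrowed value into data_info as a
// single-element tensor.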
template <typename DstImm, typename SrcImm>
int FetchCastImmValue(const ValueNodePtr &value_node, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && data_info != nullptr);
  DstImm dst_imm;
  data_info->data_type_ = dst_imm.type()->number_type();
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(dst_imm.value()));
  auto value = value_node->value();
  MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
  auto data = value->cast<std::shared_ptr<SrcImm>>();
  MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
  auto data_value = data->value();
  decltype(dst_imm.value()) dst_data = static_cast<decltype(dst_imm.value())>(data_value);
  if (memcpy_s(data_info->data_.data(), sizeof(dst_imm.value()), &dst_data, sizeof(dst_imm.value())) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}

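// Store a scalar immediate, kept in its original type, into data_info as a
// single-element tensor.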
template <typename ImmType>
int FetchImmValue(const ValueNodePtr &value_node, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && data_info != nullptr);
  auto data = value_node->value()->cast<std::shared_ptr<ImmType>>();
  MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberImm failed");
  auto data_value = data->value();
  data_info->data_type_ = data->type()->number_type();
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(data_value));
  if (memcpy_s(data_info->data_.data(), sizeof(data_value), &data_value, sizeof(data_value)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}

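// Encode a Number (a type object) as a single int32 holding its type id, normalizing
// generic int/uint/float ids to their 32-bit counterparts via TypeToTypeMap.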
int FetchFromNumberValue(const ValueNodePtr &value_node, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && data_info != nullptr);
  data_info->data_type_ = kNumberTypeInt32;
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(int));
  auto data = value_node->value()->cast<NumberPtr>();
  MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
  int number_type = data->number_type();
  if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
    number_type = TypeToTypeMap.at(number_type);
  }
  if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}

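// Flatten an integer ValueSequence into an int32 buffer; int64 elements are
// truncated to int32, and non-integer sequences are rejected.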
int FetchFromSequenceValue(const ValueNodePtr &value_node, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && data_info != nullptr);
  auto value = value_node->value();
  MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
  std::vector<int32_t> shape;
  auto value_seq = value->cast<ValueSequencePtr>();
  MS_CHECK_TRUE_MSG(value_seq != nullptr, RET_ERROR, "value_seq is nullptr");
  if (!value_seq->value().empty()) {
    if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
        value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
      shape = GetValue<std::vector<int>>(value);
    } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
      auto origin_value = GetValue<std::vector<int64_t>>(value);
      std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
                     [](int64_t val) { return static_cast<int32_t>(val); });
    } else {
      MS_LOG(ERROR) << "Value type in ValueSequence is not integer.";
      return RET_ERROR;
    }
  }
  data_info->data_type_ = kNumberTypeInt32;
  data_info->shape_ = {static_cast<int32_t>(shape.size())};
  data_info->data_.resize(shape.size() * sizeof(int));
  if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
                                 shape.size() * sizeof(int32_t)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

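// Copy (or alias) tensor bytes starting at `offset` into data_info: tensor-list data
// is copied unconditionally, while ordinary const data is copied only when copy_data
// is true and otherwise exposed through data_ptr_ as an alias of the source buffer.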
int SetTensorData(const tensor::TensorPtr &tensor_info, DataInfo *data_info, TypeId data_type, size_t offset,
                  bool copy_data) {
  if (data_type == kObjectTypeTensorType && tensor_info->Size() >= kTensorListMinSize) {
    data_info->data_.resize(tensor_info->Size() - offset);
    if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(),
                                   static_cast<uint8_t *>(tensor_info->data_c()) + offset,
                                   tensor_info->Size() - offset)) {
      MS_LOG(ERROR) << "huge_memcpy failed.";
      return RET_ERROR;
    }
  }
  // common node with const data
  if (data_type != kObjectTypeTensorType) {
    if (copy_data) {
      data_info->data_.resize(tensor_info->Size() - offset);
      if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(),
                                     static_cast<uint8_t *>(tensor_info->data_c()) + offset,
                                     tensor_info->Size() - offset)) {
        MS_LOG(ERROR) << "huge_memcpy failed.";
        return RET_ERROR;
      }
    } else {
      data_info->data_ptr_ = static_cast<uint8_t *>(tensor_info->data_c()) + offset;
    }
  }
  return RET_OK;
}
}  // namespace

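// Fill data_info from a Parameter's default tensor: type, shape (decoded from the
// string header for string tensors), raw data, compression type, quant params, and a
// format derived from the source framework (NHWC for MSLite/TF/TFLite, NCHW otherwise).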
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info,
                          bool copy_data) {
  MS_ASSERT(param_node != nullptr && data_info != nullptr);
  ShapeVector shape_vector;
  TypeId data_type = kTypeUnknown;
  auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "get data type and shape from param node failed.";
    return RET_ERROR;
  }
  data_info->data_type_ = data_type;
  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
  size_t offset = 0;
  if (tensor_info != nullptr && !shape_vector.empty() && data_type == kObjectTypeString) {
    status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "get shape vector from string tensor failed.";
      return RET_ERROR;
    }
  }
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->shape_ = dims;
  if (tensor_info != nullptr && tensor_info->Size() != 0) {
    // tensor_list tensor
    status = SetTensorData(tensor_info, data_info, data_type, offset, copy_data);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "set tensor data failed.";
      return RET_ERROR;
    }
  }
  if (tensor_info != nullptr) {
    data_info->compress_type_ = tensor_info->compression_type();
    data_info->quant_params_ = tensor_info->quant_params();
  }

  // the const tensor format from onnx/caffe should be nchw in general
  auto const_format = (fmk_type == converter::kFmkTypeMsLite || fmk_type == converter::kFmkTypeTf ||
                       fmk_type == converter::kFmkTypeTflite)
                        ? NHWC
                        : NCHW;
  data_info->format_ = param_node->has_default() ? const_format : NHWC;
  return RET_OK;
}

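// Fetch const data for cnode's input at `index` when that input is a Parameter, then
// refine the format from the primitive's kFormat attribute or the graph's input-format
// attribute, and pick up Huffman-code quantization info when present.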
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, DataInfo *data_info,
                               bool copy_data) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto param_node = cnode->input(index)->cast<ParameterPtr>();
  MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
  if (FetchFromDefaultParam(param_node, fmk_type, data_info, copy_data) != RET_OK) {
    MS_LOG(ERROR) << "fetch information from default param failed.";
    return RET_ERROR;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
  if (prim->GetAttr(mindspore::ops::kFormat) == nullptr && !param_node->has_default()) {
    auto func_graph = cnode->func_graph();
    MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "The func graph is nullptr");
    auto input_format = func_graph->get_attr(kInputFormat);
    data_info->format_ = input_format != nullptr ? GetValue<int>(input_format) : static_cast<int>(Format::NHWC);
  }
  if (prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
    auto value = prim->GetAttr(mindspore::ops::kFormat);
    if (value->isa<mindspore::Int64Imm>()) {
      data_info->format_ = GetValue<int64_t>(value);
    }
  }
  QuantParamHolderPtr quant_param_holder =
    prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
  if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
      data_info->data_type_ == kNumberTypeInt8) {
    data_info->enable_huffman_code_ = true;
  }
  data_info->node_type_ = NodeType_ValueNode;
  return RET_OK;
}

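// Fetch const data for cnode's input at `index` when that input is a ValueNode,
// dispatching on the value's concrete type (tensor, scalar immediates, sequences,
// numbers); FuncGraph and Monad inputs are skipped with RET_NO_CHANGE.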
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                           DataInfo *data_info, bool copy_data) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto value_node = cnode->input(index)->cast<ValueNodePtr>();
  MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");

  auto value = value_node->value();
  int ret = RET_OK;
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
  if (value->isa<tensor::Tensor>()) {
    ret = FetchFromTensorValue(value_node, fmk_type, train_flag, data_info, copy_data);
    if (index == kNumWeightIndex && prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
      data_info->format_ = GetValue<int64_t>(prim->GetAttr(mindspore::ops::kFormat));
    }
  } else if (value->isa<mindspore::Int64Imm>()) {
    ret = FetchCastImmValue<mindspore::Int32Imm, mindspore::Int64Imm>(value_node, data_info);
  } else if (value->isa<mindspore::Int32Imm>()) {
    ret = FetchImmValue<mindspore::Int32Imm>(value_node, data_info);
  } else if (value->isa<mindspore::BoolImm>()) {
    ret = FetchImmValue<mindspore::BoolImm>(value_node, data_info);
  } else if (value->isa<mindspore::FP32Imm>()) {
    ret = FetchImmValue<mindspore::FP32Imm>(value_node, data_info);
  } else if (value->isa<mindspore::ValueSequence>()) {
    ret = FetchFromSequenceValue(value_node, data_info);
  } else if (value->isa<Number>()) {
    ret = FetchFromNumberValue(value_node, data_info);
  } else if (value->isa<FuncGraph>()) {
    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
    return RET_NO_CHANGE;
  } else if (value->isa<Monad>()) {
    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
    return RET_NO_CHANGE;
  } else {
    MS_LOG(ERROR) << "Unsupported value type, support needs to be added.";
    return RET_ERROR;
  }
  data_info->node_type_ = NodeType_ValueNode;
  return ret;
}

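// Fetch type, shape, and format for cnode's input at `index` when that input is the
// output of another CNode; tensor-list outputs additionally get their payload copied.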
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto abstract = opt::GetCNodeInputAbstract(cnode, index);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Abstract of cnode is nullptr.";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
    MS_LOG(ERROR) << "Abstract should be abstract tensor.";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
  MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
    return RET_ERROR;
  }
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  Format format{mindspore::NHWC};
  auto ret = opt::DetermineCertainVarInputFormat(cnode, index, &format);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "set format for cnode failed";
    return RET_ERROR;
  }
  data_info->format_ = format;
  data_info->data_type_ = type_ptr->type_id();
  data_info->shape_ = dims;
  data_info->node_type_ = NodeType_CNode;
  if (type_ptr->type_id() == kObjectTypeTensorType) {
    auto tensor_info = abstract_tensor->GetValueTrack();
    if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
      MS_LOG(ERROR) << "tensor info is invalid.";
      return RET_ERROR;
    }
    auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
    MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
    if (tensor_value->Size() >= kTensorListMinSize) {
      data_info->data_.resize(tensor_value->Size());
      if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
          EOK) {
        MS_LOG(ERROR) << "memcpy data failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

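// Fetch const data for cnode's input at `index`, dispatching to the Parameter or
// ValueNode path depending on the input's node kind.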
int FetchConstData(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, DataInfo *data_info,
                   bool copy_data) {
  auto node_name = cnode->fullname_with_scope();
  if (index >= cnode->size()) {
    MS_LOG(ERROR) << node_name << " index " << index << " >= input size " << cnode->size();
    return RET_ERROR;
  }
  int status;
  auto input = cnode->input(index);
  if (input->isa<Parameter>()) {
    status = FetchDataFromParameterNode(cnode, index, fmk_type, data_info, copy_data);
  } else if (input->isa<ValueNode>()) {
    status = FetchDataFromValueNode(cnode, index, fmk_type, false, data_info, copy_data);
  } else {
    MS_LOG(ERROR) << node_name << " index " << index << " is not Parameter or ValueNode";
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << node_name << " fetch data failed";
    return status;
  }
  return RET_OK;
}

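// Fill data_info (type, shape, node type, and tensor-list payload if any) directly
// from an abstract, for callers that have no concrete node at hand.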
int FetchDataFromAbstract(const AbstractBasePtr &abstract, DataInfo *data_info) {
  MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
  if (!utils::isa<abstract::AbstractTensor>(abstract)) {
    MS_LOG(ERROR) << "Abstract should be AbstractTensor.";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
  MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
    return RET_ERROR;
  }
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->data_type_ = static_cast<int>(type_ptr->type_id());
  data_info->shape_ = dims;
  data_info->node_type_ = static_cast<int>(NodeType_CNode);
  if (type_ptr->type_id() == kObjectTypeTensorType) {
    auto tensor_info = abstract_tensor->GetValueTrack();
    if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
      MS_LOG(ERROR) << "tensor info is invalid.";
      return RET_ERROR;
    }
    auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
    MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
    if (tensor_value->Size() >= kTensorListMinSize) {
      data_info->data_.resize(tensor_value->Size());
      if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
          EOK) {
        MS_LOG(ERROR) << "memcpy data failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

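// Rewrite cnode's inputs so that Depend nodes are bypassed: each Depend input is
// replaced by its real data input(s), dropping the pure control-dependency edge.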
int RemoveIfDepend(const CNodePtr &cnode) {
  MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
  bool has_depend = false;
  std::vector<AnfNodePtr> inputs;

  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->size(); ++i) {
    AnfNodePtr input_node = cnode->input(i);
    MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr");
    if (!input_node->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    if (opt::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
      auto depend_node = utils::cast<CNodePtr>(input_node);
      MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
      has_depend = true;
      bool mask_out = (depend_node->size() == opt::kInputSizeThree);
      for (size_t j = 1; j < depend_node->size(); ++j) {
        AnfNodePtr depend_input_node = depend_node->input(j);
        MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
        inputs.emplace_back(depend_input_node);
        if (mask_out) {
          break;
        }
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  if (has_depend) {
    cnode->set_inputs(inputs);
  }
  return RET_OK;
}

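// Recursively collect the real inputs of cnode, expanding MakeTuple/MakeTupleV2
// inputs in place so the result is a flat input list.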
int GetFlattenInputsIfMakeTuple(const CNodePtr &cnode, std::vector<AnfNodePtr> *inputs, bool *has_make_tuple) {
  MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "Cnode is nullptr.");
  MS_CHECK_TRUE_MSG(inputs != nullptr, RET_NULL_PTR, "Inputs is nullptr.");
  MS_CHECK_TRUE_MSG(has_make_tuple != nullptr, RET_NULL_PTR, "Has make tuple is nullptr.");
  for (size_t i = 1; i < cnode->size(); ++i) {
    AnfNodePtr input_node = cnode->input(i);
    MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "Input_node is nullptr");
    auto input_cnode = utils::cast<CNodePtr>(input_node);
    if (input_cnode && (opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple) ||
                        opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTupleV2))) {
      *has_make_tuple = true;
      GetFlattenInputsIfMakeTuple(input_cnode, inputs, has_make_tuple);
    } else {
      inputs->emplace_back(input_node);
    }
  }
  return RET_OK;
}

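// Rewrite cnode's inputs so MakeTuple/MakeTupleV2 wrappers are flattened into the
// direct input list, using GetFlattenInputsIfMakeTuple above.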
int RemoveIfMakeTuple(const CNodePtr &cnode) {
  MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
  bool has_make_tuple = false;
  std::vector<AnfNodePtr> inputs;

  inputs.emplace_back(cnode->input(0));
  if (GetFlattenInputsIfMakeTuple(cnode, &inputs, &has_make_tuple) != RET_OK) {
    MS_LOG(ERROR) << "Trace real input of make tuple failed, name: " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  if (has_make_tuple) {
    cnode->set_inputs(inputs);
  }
  return RET_OK;
}

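// Build a runtime OpParameter for the primitive held by `node` by converting its
// PrimitiveT into a flatbuffer primitive and invoking the parameter creator registered
// for the current schema version.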
int FetchOpParameterFromNode(const AnfNodePtr &node, OpParameter **op_parameter) {
  if (op_parameter == nullptr) {
    MS_LOG(ERROR) << "op_parameter is nullptr.";
    return RET_NULL_PTR;
  }
  CHECK_NULL_RETURN(GetValueNode<PrimitivePtr>(node));
  auto prim_t = lite::GetPrimitiveT(node);
  CHECK_NULL_RETURN(prim_t);
  size_t init_size = 1024;
  flatbuffers::FlatBufferBuilder fbb(init_size);
  auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb);
  if (prim == nullptr) {
    fbb.Clear();
    MS_LOG(ERROR) << "get primitive failed.";
    return RET_ERROR;
  }
  auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR);
  if (parameter_gen == nullptr) {
    fbb.Clear();
    MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
    return RET_ERROR;
  }
  *op_parameter = parameter_gen(prim);
  fbb.Clear();
  if (*op_parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr.";
    return RET_ERROR;
  }
  return RET_OK;
}

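// Build OpParameters for every CNode in the graph except those classified as special
// by opt::IsSpecialType, and index them by the node's full scope name.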
int FetchOpParameterFromFuncGraph(const FuncGraphPtr &func_graph, std::map<std::string, OpParameter *> *op_parameters) {
  MS_CHECK_TRUE_MSG(op_parameters != nullptr, RET_NULL_PTR, "op_parameters is nullptr.");
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    if (opt::IsSpecialType(cnode)) {
      continue;
    }
    auto primitive = cnode->input(0);
    OpParameter *parameter = nullptr;
    auto ret = lite::FetchOpParameterFromNode(primitive, &parameter);
    if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " FetchOpParameterFromNode failed.";
      return ret;
    }
    CHECK_NULL_RETURN(parameter);
    parameter->thread_num_ = 1;
    op_parameters->emplace(std::pair<std::string, OpParameter *>(cnode->fullname_with_scope(), parameter));
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore