• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <algorithm>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "nnacl/op_base.h"
#include "src/common/file_utils.h"
#include "utils/ms_utils_secure.h"
26 
27 namespace mindspore {
28 namespace lite {
29 namespace {
30 constexpr int kMaxValidCharacters = 10;
31 static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
32   {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
33   {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
34   {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16},
35   {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32},
36   {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
37   {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
38   {onnx::TensorProto_DataType_UINT64, mindspore::kNumberTypeUInt64},
39   {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
40   {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
41   {onnx::TensorProto_DataType_DOUBLE, mindspore::kNumberTypeFloat64},
42   {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
43 }  // namespace
44 
45 int64_t OnnxNodeParser::opset_version_ = 0;
46 
Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> & external_data,ExternalDataInfo * external_data_info)47 STATUS ExternalDataInfo::Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &external_data,
48                                 ExternalDataInfo *external_data_info) {
49   const int data_size = external_data.size();
50   for (int i = 0; i != data_size; ++i) {
51     onnx::StringStringEntryProto string_map = external_data[i];
52     if (!string_map.has_key()) {
53       MS_LOG(ERROR) << "No key is in external data.";
54       return RET_ERROR;
55     }
56     if (!string_map.has_value()) {
57       MS_LOG(ERROR) << "No value is in external data.";
58       return RET_ERROR;
59     }
60 
61     if (StringMapKeyIs("location", string_map)) {
62       external_data_info->relative_path_ = string_map.value();
63     } else if (StringMapKeyIs("offset", string_map)) {
64       external_data_info->offset_ = strtol(string_map.value().c_str(), nullptr, kMaxValidCharacters);
65       if (std::to_string(external_data_info->offset_).length() != string_map.value().length()) {
66         MS_LOG(ERROR) << "Failed to parse offset with size " << std::to_string(external_data_info->offset_).length()
67                       << ", expected size is " << string_map.value().length();
68         return RET_ERROR;
69       }
70     } else if (StringMapKeyIs("length", string_map)) {
71       external_data_info->length_ =
72         static_cast<size_t>(strtol(string_map.value().c_str(), nullptr, kMaxValidCharacters));
73       if (std::to_string(external_data_info->length_).length() != string_map.value().length()) {
74         MS_LOG(ERROR) << "Failed to parse length with size " << std::to_string(external_data_info->offset_).length()
75                       << ", expected size is " << string_map.value().length();
76         return RET_ERROR;
77       }
78     } else if (StringMapKeyIs("checksum", string_map)) {
79       external_data_info->checksum_ = string_map.value();
80     } else {
81       MS_LOG(ERROR) << "Invalid model format";
82       return RET_ERROR;
83     }
84   }
85   return RET_OK;
86 }
87 
StringMapKeyIs(const std::string & key,const onnx::StringStringEntryProto & string_map)88 bool ExternalDataInfo::StringMapKeyIs(const std::string &key, const onnx::StringStringEntryProto &string_map) {
89   return string_map.key() == key && !string_map.value().empty();
90 }
91 
GetOnnxPadMode(const onnx::AttributeProto & onnx_node_attr)92 mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
93   if (onnx_node_attr.s() == "NOTSET") {
94     return mindspore::PadMode::PAD;
95   } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") {
96     return mindspore::PadMode::SAME;
97   } else if (onnx_node_attr.s() == "VALID") {
98     return mindspore::PadMode::VALID;
99   } else {
100     MS_LOG(ERROR) << "unsupported padMode";
101     return mindspore::PadMode::PAD;
102   }
103 }
104 
CopyOnnxTensorData(const onnx::TensorProto & onnx_const_tensor)105 tensor::TensorPtr OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor) {
106   auto onnx_data_type = static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type());
107   auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_data_type);
108   if (data_type == kTypeUnknown) {
109     MS_LOG(ERROR) << "not support onnx data type " << onnx_data_type;
110     return nullptr;
111   }
112   std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
113   auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
114   if (tensor_info == nullptr) {
115     MS_LOG(ERROR) << "new a tensor::Tensor failed, data type: " << data_type << ", shape: " << shape_vector;
116     return nullptr;
117   }
118   bool overflow = false;
119   auto data_count = GetOnnxElementNum(onnx_const_tensor, &overflow);
120   if (overflow) {
121     MS_LOG(ERROR) << "data count overflow, tensor shape: " << shape_vector;
122     return nullptr;
123   }
124   if (data_count == 0) {
125     return tensor_info;
126   }
127   auto type_size = lite::DataTypeSize(data_type);
128   if (type_size == 0) {
129     MS_LOG(ERROR) << "Unsupported data type: " << data_type;
130     return nullptr;
131   }
132   if (INT_MUL_OVERFLOW_THRESHOLD(data_count, type_size, SIZE_MAX)) {
133     MS_LOG(ERROR) << "data_size overflow";
134     return nullptr;
135   }
136   auto data_size = data_count * type_size;
137   auto tensor_data = tensor_info->data_c();
138   if (tensor_data == nullptr) {
139     MS_LOG(ERROR) << "Dst tensor cannot be nullptr";
140     return nullptr;
141   }
142   auto dst_bytes_size = tensor_info->data().nbytes();
143   if (dst_bytes_size != SizeToLong(data_size)) {
144     MS_LOG(ERROR) << "Calculated data size " << data_size << " != tensor bytes size " << dst_bytes_size;
145     return nullptr;
146   }
147   if (onnx_const_tensor.raw_data().size() != 0) {
148     auto ret = GetOnnxRawData(onnx_const_tensor, data_count, tensor_info);
149     if (ret != RET_OK) {
150       MS_LOG(ERROR) << "Failed to get tensor data, data count " << data_count << ", data type " << data_type;
151       return nullptr;
152     }
153   } else {
154     auto ret = GetOnnxListData(onnx_const_tensor, data_count, tensor_info);
155     if (ret != RET_OK) {
156       MS_LOG(ERROR) << "Failed to get tensor data, data count " << data_count << ", data type " << data_type;
157       return nullptr;
158     }
159   }
160   return tensor_info;
161 }
162 
GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type)163 TypeId OnnxNodeParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
164   auto iter = kOnnxTypeTransferMap.find(onnx_type);
165   if (iter == kOnnxTypeTransferMap.end()) {
166     MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type;
167     return kTypeUnknown;
168   }
169   return iter->second;
170 }
171 
SetTypeAndValueForFloat(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,size_t data_count)172 void OnnxNodeParser::SetTypeAndValueForFloat(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
173                                              size_t data_count) {
174   for (size_t i = 0; i < data_count; i++) {
175     value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i]));
176   }
177 }
178 
SetTypeAndValueForBool(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,size_t data_count)179 void OnnxNodeParser::SetTypeAndValueForBool(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
180                                             size_t data_count) {
181   for (size_t i = 0; i < data_count; i++) {
182     value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i]));
183   }
184 }
185 
SetDataTypeAndValue(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,size_t data_count,int * type)186 STATUS OnnxNodeParser::SetDataTypeAndValue(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
187                                            size_t data_count, int *type) {
188   switch (onnx_tensor.data_type()) {
189     case onnx::TensorProto_DataType_FLOAT:
190       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
191       if (onnx_tensor.float_data_size() > 0) {
192         for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
193           value->push_back(onnx_tensor.float_data(i));
194         }
195       } else {
196         for (size_t i = 0; i < data_count; i++) {
197           value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
198         }
199       }
200       break;
201     case onnx::TensorProto_DataType_INT32:
202       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
203       if (onnx_tensor.int32_data_size() > 0) {
204         for (int i = 0; i < onnx_tensor.int32_data_size(); i++) {
205           value->push_back(onnx_tensor.int32_data(i));
206         }
207       } else {
208         for (size_t i = 0; i < data_count; i++) {
209           value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
210         }
211       }
212       break;
213     case onnx::TensorProto_DataType_INT64:
214       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT64);
215       if (onnx_tensor.int64_data_size() > 0) {
216         for (int i = 0; i < onnx_tensor.int64_data_size(); i++) {
217           value->push_back(onnx_tensor.int64_data(i));
218         }
219       } else {
220         for (size_t i = 0; i < data_count; i++) {
221           value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
222         }
223       }
224       break;
225     case onnx::TensorProto_DataType_FLOAT16:
226       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT16);
227       SetTypeAndValueForFloat(onnx_tensor, value, data_count);
228       break;
229     case onnx::TensorProto_DataType_BOOL:
230       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL);
231       SetTypeAndValueForBool(onnx_tensor, value, data_count);
232       break;
233     default:
234       MS_LOG(ERROR) << "The data type is not supported.";
235       return RET_ERROR;
236   }
237   return RET_OK;
238 }
239 
GetTensorDataFromOnnx(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,int * type)240 STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
241                                              int *type) {
242   if (value == nullptr || type == nullptr) {
243     MS_LOG(ERROR) << "Input value or type is nullptr";
244     return RET_INPUT_PARAM_INVALID;
245   }
246   bool overflow = false;
247   auto data_count = GetOnnxElementNum(onnx_tensor, &overflow);
248   if (overflow) {
249     MS_LOG(ERROR) << "data count overflow";
250     return RET_ERROR;
251   }
252   return SetDataTypeAndValue(onnx_tensor, value, data_count, type);
253 }
254 
GetOnnxElementNum(const onnx::TensorProto & onnx_tensor,bool * overflowed)255 size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed) {
256   size_t data_count = 1;
257   bool is_overflow = false;
258   if (!onnx_tensor.dims().empty()) {
259     std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count, &is_overflow](int dim) {
260       if (is_overflow || dim < 0) {
261         is_overflow = true;
262         data_count = 0;
263         return;
264       }
265       auto udim = static_cast<size_t>(dim);
266       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, udim, SIZE_MAX)) {
267         is_overflow = true;
268         data_count = 0;
269         return;
270       }
271       data_count *= udim;
272     });
273   }
274   if (overflowed != nullptr) {
275     *overflowed = is_overflow;
276   }
277   return data_count;
278 }
279 
LoadOnnxExternalTensorData(const onnx::TensorProto & onnx_const_tensor,const tensor::TensorPtr & tensor_info,const std::string & model_file,std::map<std::string,std::pair<size_t,uint8_t * >> * external_datas)280 STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
281                                                   const tensor::TensorPtr &tensor_info, const std::string &model_file,
282                                                   std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
283   if (tensor_info == nullptr) {
284     MS_LOG(ERROR) << "tensor_info is nullptr.";
285     return RET_NULL_PTR;
286   }
287   size_t data_size = 0;
288   const void *onnx_data = LoadOnnxRawData(onnx_const_tensor, &data_size, model_file, external_datas);
289   if (onnx_data == nullptr) {
290     MS_LOG(ERROR) << "origin data from external data is nullptr.";
291     return RET_MEMORY_FAILED;
292   }
293   auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
294   if (common::huge_memcpy(tensor_data, static_cast<size_t>(tensor_info->data().nbytes()),
295                           static_cast<const uint8_t *>(onnx_data), data_size) != EOK) {
296     MS_LOG(ERROR) << "memcpy_s from onnx tensor data to mindspore tensor data failed, dst size "
297                   << tensor_info->data().nbytes() << ", src size " << data_size;
298     return RET_ERROR;
299   }
300   return RET_OK;
301 }
302 
SetExternalTensorFile(const std::string & model_file,std::string * external_tensor_dir)303 STATUS OnnxNodeParser::SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir) {
304   CHECK_NULL_RETURN(external_tensor_dir);
305   auto i_end_index = model_file.find_last_of('/');
306   if (i_end_index == std::string::npos) {
307     i_end_index = model_file.find_last_of('\\');
308   }
309   if (i_end_index == std::string::npos) {
310     *external_tensor_dir = ".";
311   } else {
312     *external_tensor_dir = model_file.substr(0, i_end_index);
313   }
314   return RET_OK;
315 }
316 
317 template <class DstT, class SrcT>
CopyOnnxData(void * dst_v,const void * src_v,size_t data_count)318 static int CopyOnnxData(void *dst_v, const void *src_v, size_t data_count) {
319   if (dst_v == nullptr || src_v == nullptr) {
320     MS_LOG(ERROR) << "Dst or src data cannot be nullptr";
321     return RET_ERROR;
322   }
323   if (sizeof(DstT) == sizeof(SrcT)) {
324     if (memcpy_s(dst_v, data_count * sizeof(DstT), src_v, data_count * sizeof(SrcT)) != EOK) {
325       MS_LOG(ERROR) << "memcpy_s failed, data size " << data_count * sizeof(DstT);
326       return RET_ERROR;
327     }
328     return RET_OK;
329   }
330   auto src = reinterpret_cast<const SrcT *>(src_v);
331   auto dst = reinterpret_cast<DstT *>(dst_v);
332   for (size_t i = 0; i < data_count; i++) {
333     dst[i] = static_cast<DstT>(src[i]);
334   }
335   return RET_OK;
336 }
337 
GetOnnxRawData(const onnx::TensorProto & onnx_const_tensor,size_t data_count,const tensor::TensorPtr & tensor_info)338 int OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
339                                    const tensor::TensorPtr &tensor_info) {
340   auto data_size = LongToSize(tensor_info->data().nbytes());
341   auto tensor_data = tensor_info->data_c();
342   auto onnx_data = onnx_const_tensor.raw_data().data();
343   if (onnx_const_tensor.raw_data().size() != data_size) {
344     MS_LOG(ERROR) << "Tensor raw data size " << onnx_const_tensor.raw_data().size() << " != expected size "
345                   << data_size;
346     return RET_ERROR;
347   }
348   return CopyOnnxData<uint8_t, uint8_t>(tensor_data, onnx_data, data_size);
349 }
350 
GetOnnxListData(const onnx::TensorProto & onnx_const_tensor,size_t data_count,const tensor::TensorPtr & tensor_info)351 int OnnxNodeParser::GetOnnxListData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
352                                     const tensor::TensorPtr &tensor_info) {
353   const void *onnx_data = nullptr;
354   auto tensor_data = tensor_info->data_c();
355   TypeId data_type = tensor_info->Dtype()->type_id();
356   auto type_size = lite::DataTypeSize(data_type);
357   switch (data_type) {
358     case kNumberTypeFloat32:
359       MS_CHECK_EQ(onnx_const_tensor.float_data_size(), SizeToLong(data_count), RET_ERROR);
360       onnx_data = onnx_const_tensor.float_data().data();
361       return CopyOnnxData<float, float>(tensor_data, onnx_data, data_count);
362     case kNumberTypeFloat64:
363       MS_CHECK_EQ(onnx_const_tensor.double_data_size(), SizeToLong(data_count), RET_ERROR);
364       onnx_data = onnx_const_tensor.double_data().data();
365       return CopyOnnxData<double, double>(tensor_data, onnx_data, data_count);
366     case kNumberTypeInt64:
367       MS_CHECK_EQ(onnx_const_tensor.int64_data_size(), SizeToLong(data_count), RET_ERROR);
368       onnx_data = onnx_const_tensor.int64_data().data();
369       return CopyOnnxData<int64_t, int64_t>(tensor_data, onnx_data, data_count);
370     case kNumberTypeUInt64:
371     case kNumberTypeUInt32:
372       MS_CHECK_EQ(onnx_const_tensor.uint64_data_size(), SizeToLong(data_count), RET_ERROR);
373       onnx_data = onnx_const_tensor.uint64_data().data();
374       if (data_type == kNumberTypeUInt32) {
375         return CopyOnnxData<uint32_t, uint64_t>(tensor_data, onnx_data, data_count);
376       } else {
377         return CopyOnnxData<uint64_t, uint64_t>(tensor_data, onnx_data, data_count);
378       }
379     case kNumberTypeInt32:
380     case kNumberTypeInt16:
381     case kNumberTypeInt8:
382     case kNumberTypeUInt16:
383     case kNumberTypeUInt8:
384     case kNumberTypeBool:
385     case kNumberTypeFloat16:
386       MS_CHECK_EQ(onnx_const_tensor.int32_data_size(), SizeToLong(data_count), RET_ERROR);
387       onnx_data = onnx_const_tensor.int32_data().data();
388       if (type_size == sizeof(int32_t)) {
389         return CopyOnnxData<int32_t, int32_t>(tensor_data, onnx_data, data_count);
390       } else if (type_size == sizeof(uint16_t)) {
391         return CopyOnnxData<uint16_t, int32_t>(tensor_data, onnx_data, data_count);
392       } else if (type_size == sizeof(uint8_t)) {
393         return CopyOnnxData<uint8_t, int32_t>(tensor_data, onnx_data, data_count);
394       }
395       break;
396     default:
397       break;
398   }
399   MS_LOG(ERROR) << "unsupported data type " << data_type;
400   return RET_ERROR;
401 }
402 
LoadOnnxRawData(const onnx::TensorProto & onnx_const_tensor,size_t * data_size,const std::string & model_file,std::map<std::string,std::pair<size_t,uint8_t * >> * external_datas)403 const void *OnnxNodeParser::LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
404                                             const std::string &model_file,
405                                             std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
406   MS_ERROR_IF_NULL_W_RET_VAL(data_size, nullptr);
407   MS_ERROR_IF_NULL_W_RET_VAL(external_datas, nullptr);
408   ExternalDataInfo external_data_info;
409   if (ExternalDataInfo::Create(onnx_const_tensor.external_data(), &external_data_info) != RET_OK) {
410     MS_LOG(ERROR) << "Create ExternalDataInfo failed.";
411     return nullptr;
412   }
413   auto data_path = external_data_info.GetRelativePath();
414   auto it = external_datas->find(data_path);
415   size_t external_data_size = 0;
416   uint8_t *external_data = nullptr;
417   if (it == external_datas->end()) {
418     std::string external_tensor_dir;
419     if (SetExternalTensorFile(model_file, &external_tensor_dir) != RET_OK) {
420       MS_LOG(ERROR) << "Failed to set external tensor file.";
421       return nullptr;
422     }
423 #ifdef _WIN32
424     std::string external_data_file = external_tensor_dir + "\\" + data_path;
425 #else
426     std::string external_data_file = external_tensor_dir + "/" + data_path;
427 #endif
428     external_data = reinterpret_cast<uint8_t *>(ReadFile(external_data_file.c_str(), &external_data_size));
429     if (external_data == nullptr || external_data_size == 0) {
430       MS_LOG(ERROR) << "Failed to read external tensor file " << external_data_file;
431       return nullptr;
432     }
433     external_datas->emplace(data_path, std::make_pair(external_data_size, external_data));
434   } else {
435     external_data_size = it->second.first;
436     external_data = it->second.second;
437   }
438   auto offset = external_data_info.GetOffset();
439   auto length = external_data_info.GetLength();
440   if (length == 0 && offset == 0) {  // not set length and offset
441     *data_size = external_data_size;
442     return external_data;
443   }
444   if (length == 0 || external_data_size < offset || external_data_size - offset < length) {
445     MS_LOG(ERROR) << "Invalid external data info, data path " << data_path << ", offset " << offset << ", length "
446                   << length << ", file length " << external_data_size;
447     return nullptr;
448   }
449   *data_size = length;
450   return external_data + offset;
451 }
452 
/**
 * Finds the constant TensorProto that produces `input_name` in `onnx_graph`.
 *
 * Graph initializers are searched first. Failing that, the function looks for a "Constant" node with
 * exactly one output named `input_name`, and returns its "value" tensor attribute. The returned pointer
 * refers into `onnx_graph` and is only valid while the graph is alive.
 *
 * @return pointer to the tensor, or nullptr (after logging) when no such constant exists.
 */
const onnx::TensorProto *OnnxNodeParser::GetConstantTensorData(const onnx::GraphProto &onnx_graph,
                                                               const std::string &input_name) {
  const auto &initializer = onnx_graph.initializer();
  // Capture by reference: the lambdas are used synchronously, so copying the name is unnecessary.
  auto init_iter = std::find_if(initializer.begin(), initializer.end(),
                                [&input_name](const onnx::TensorProto &proto) { return proto.name() == input_name; });
  if (init_iter != initializer.end()) {
    return &(*init_iter);
  }
  const auto &nodes = onnx_graph.node();
  auto node_iter = std::find_if(nodes.begin(), nodes.end(), [&input_name](const onnx::NodeProto &proto) {
    if (proto.op_type() != "Constant" || proto.output_size() != 1) {
      return false;
    }
    return proto.output(0) == input_name;
  });
  if (node_iter == nodes.end()) {
    MS_LOG(ERROR) << "Cannot find const input " << input_name;
    return nullptr;
  }
  const auto &onnx_node = *node_iter;
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "value") {
      if (onnx_node_attr.has_t()) {
        return &onnx_node_attr.t();
      }
      break;
    }
  }
  MS_LOG(ERROR) << "Failed to find const value from input " << input_name;
  return nullptr;
}
485 }  // namespace lite
486 }  // namespace mindspore
487