1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/onnx/onnx_node_parser.h"
18 #include <algorithm>
19 #include <vector>
20 #include <memory>
21 #include <unordered_map>
22 #include "tools/converter/parser/onnx/onnx_model_parser.h"
23 #include "nnacl/op_base.h"
24 #include "src/common/file_utils.h"
25 #include "utils/ms_utils_secure.h"
26
27 namespace mindspore {
28 namespace lite {
namespace {
// NOTE(review): despite the name, this constant is passed to strtol as the
// numeric *radix* (base 10) when parsing external-data "offset"/"length"
// values -- it is not a character-count limit. Confirm intended semantics.
constexpr int kMaxValidCharacters = 10;
// Maps ONNX TensorProto element types to the corresponding MindSpore type ids.
// Types absent from this map are reported as unsupported by
// OnnxNodeParser::GetDataTypeFromOnnx.
static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
  {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
  {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
  {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16},
  {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32},
  {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
  {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
  {onnx::TensorProto_DataType_UINT64, mindspore::kNumberTypeUInt64},
  {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
  {onnx::TensorProto_DataType_DOUBLE, mindspore::kNumberTypeFloat64},
  {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
}  // namespace
44
45 int64_t OnnxNodeParser::opset_version_ = 0;
46
Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> & external_data,ExternalDataInfo * external_data_info)47 STATUS ExternalDataInfo::Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &external_data,
48 ExternalDataInfo *external_data_info) {
49 const int data_size = external_data.size();
50 for (int i = 0; i != data_size; ++i) {
51 onnx::StringStringEntryProto string_map = external_data[i];
52 if (!string_map.has_key()) {
53 MS_LOG(ERROR) << "No key is in external data.";
54 return RET_ERROR;
55 }
56 if (!string_map.has_value()) {
57 MS_LOG(ERROR) << "No value is in external data.";
58 return RET_ERROR;
59 }
60
61 if (StringMapKeyIs("location", string_map)) {
62 external_data_info->relative_path_ = string_map.value();
63 } else if (StringMapKeyIs("offset", string_map)) {
64 external_data_info->offset_ = strtol(string_map.value().c_str(), nullptr, kMaxValidCharacters);
65 if (std::to_string(external_data_info->offset_).length() != string_map.value().length()) {
66 MS_LOG(ERROR) << "Failed to parse offset with size " << std::to_string(external_data_info->offset_).length()
67 << ", expected size is " << string_map.value().length();
68 return RET_ERROR;
69 }
70 } else if (StringMapKeyIs("length", string_map)) {
71 external_data_info->length_ =
72 static_cast<size_t>(strtol(string_map.value().c_str(), nullptr, kMaxValidCharacters));
73 if (std::to_string(external_data_info->length_).length() != string_map.value().length()) {
74 MS_LOG(ERROR) << "Failed to parse length with size " << std::to_string(external_data_info->offset_).length()
75 << ", expected size is " << string_map.value().length();
76 return RET_ERROR;
77 }
78 } else if (StringMapKeyIs("checksum", string_map)) {
79 external_data_info->checksum_ = string_map.value();
80 } else {
81 MS_LOG(ERROR) << "Invalid model format";
82 return RET_ERROR;
83 }
84 }
85 return RET_OK;
86 }
87
StringMapKeyIs(const std::string & key,const onnx::StringStringEntryProto & string_map)88 bool ExternalDataInfo::StringMapKeyIs(const std::string &key, const onnx::StringStringEntryProto &string_map) {
89 return string_map.key() == key && !string_map.value().empty();
90 }
91
GetOnnxPadMode(const onnx::AttributeProto & onnx_node_attr)92 mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
93 if (onnx_node_attr.s() == "NOTSET") {
94 return mindspore::PadMode::PAD;
95 } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") {
96 return mindspore::PadMode::SAME;
97 } else if (onnx_node_attr.s() == "VALID") {
98 return mindspore::PadMode::VALID;
99 } else {
100 MS_LOG(ERROR) << "unsupported padMode";
101 return mindspore::PadMode::PAD;
102 }
103 }
104
// Creates a mindspore tensor and fills it with the constant data carried
// inline in |onnx_const_tensor| (either packed raw_data bytes or the typed
// repeated list fields). Returns nullptr on unsupported type, element-count
// or byte-size overflow, or a size mismatch. External (file-backed) data is
// NOT handled here -- see LoadOnnxExternalTensorData.
tensor::TensorPtr OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor) {
  auto onnx_data_type = static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type());
  auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_data_type);
  if (data_type == kTypeUnknown) {
    MS_LOG(ERROR) << "not support onnx data type " << onnx_data_type;
    return nullptr;
  }
  std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
  auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "new a tensor::Tensor failed, data type: " << data_type << ", shape: " << shape_vector;
    return nullptr;
  }
  // Element count from dims; negative dims or size_t overflow are rejected.
  bool overflow = false;
  auto data_count = GetOnnxElementNum(onnx_const_tensor, &overflow);
  if (overflow) {
    MS_LOG(ERROR) << "data count overflow, tensor shape: " << shape_vector;
    return nullptr;
  }
  if (data_count == 0) {
    // Empty tensor: nothing to copy, return the zero-sized tensor as-is.
    return tensor_info;
  }
  auto type_size = lite::DataTypeSize(data_type);
  if (type_size == 0) {
    MS_LOG(ERROR) << "Unsupported data type: " << data_type;
    return nullptr;
  }
  // Guard the count * element-size multiplication before computing bytes.
  if (INT_MUL_OVERFLOW_THRESHOLD(data_count, type_size, SIZE_MAX)) {
    MS_LOG(ERROR) << "data_size overflow";
    return nullptr;
  }
  auto data_size = data_count * type_size;
  auto tensor_data = tensor_info->data_c();
  if (tensor_data == nullptr) {
    MS_LOG(ERROR) << "Dst tensor cannot be nullptr";
    return nullptr;
  }
  // The destination buffer must match the size derived from dims + type.
  auto dst_bytes_size = tensor_info->data().nbytes();
  if (dst_bytes_size != SizeToLong(data_size)) {
    MS_LOG(ERROR) << "Calculated data size " << data_size << " != tensor bytes size " << dst_bytes_size;
    return nullptr;
  }
  // ONNX stores constants either as packed raw bytes or as typed list fields;
  // dispatch on which representation is present.
  if (onnx_const_tensor.raw_data().size() != 0) {
    auto ret = GetOnnxRawData(onnx_const_tensor, data_count, tensor_info);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to get tensor data, data count " << data_count << ", data type " << data_type;
      return nullptr;
    }
  } else {
    auto ret = GetOnnxListData(onnx_const_tensor, data_count, tensor_info);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to get tensor data, data count " << data_count << ", data type " << data_type;
      return nullptr;
    }
  }
  return tensor_info;
}
162
GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type)163 TypeId OnnxNodeParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
164 auto iter = kOnnxTypeTransferMap.find(onnx_type);
165 if (iter == kOnnxTypeTransferMap.end()) {
166 MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type;
167 return kTypeUnknown;
168 }
169 return iter->second;
170 }
171
SetTypeAndValueForFloat(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,size_t data_count)172 void OnnxNodeParser::SetTypeAndValueForFloat(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
173 size_t data_count) {
174 for (size_t i = 0; i < data_count; i++) {
175 value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i]));
176 }
177 }
178
SetTypeAndValueForBool(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,size_t data_count)179 void OnnxNodeParser::SetTypeAndValueForBool(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
180 size_t data_count) {
181 for (size_t i = 0; i < data_count; i++) {
182 value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i]));
183 }
184 }
185
// Reads |onnx_tensor|'s constant payload as floats into |value| and reports
// the mapped MindSpore type id through |type|. For FLOAT/INT32/INT64 the
// typed list field is preferred when populated, otherwise raw_data is
// reinterpreted element-wise; FLOAT16 and BOOL are read from raw_data only.
// All values are widened/converted to float regardless of source type.
//
// @param onnx_tensor  constant tensor to read.
// @param value        output vector; values are appended.
// @param data_count   element count (from GetOnnxElementNum), used for the
//                     raw_data paths.
// @param type         receives the MindSpore type id.
// @return RET_OK, or RET_ERROR for an unsupported data type.
STATUS OnnxNodeParser::SetDataTypeAndValue(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
                                           size_t data_count, int *type) {
  switch (onnx_tensor.data_type()) {
    case onnx::TensorProto_DataType_FLOAT:
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
      if (onnx_tensor.float_data_size() > 0) {
        // Typed list storage takes precedence over raw bytes.
        for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
          value->push_back(onnx_tensor.float_data(i));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
        }
      }
      break;
    case onnx::TensorProto_DataType_INT32:
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
      if (onnx_tensor.int32_data_size() > 0) {
        for (int i = 0; i < onnx_tensor.int32_data_size(); i++) {
          value->push_back(onnx_tensor.int32_data(i));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
        }
      }
      break;
    case onnx::TensorProto_DataType_INT64:
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT64);
      if (onnx_tensor.int64_data_size() > 0) {
        for (int i = 0; i < onnx_tensor.int64_data_size(); i++) {
          // NOTE(review): int64 -> float conversion may lose precision for
          // values beyond 2^24; presumably acceptable for shape-like constants.
          value->push_back(onnx_tensor.int64_data(i));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
        }
      }
      break;
    case onnx::TensorProto_DataType_FLOAT16:
      // FLOAT16 is only read from raw_data (see SetTypeAndValueForFloat).
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT16);
      SetTypeAndValueForFloat(onnx_tensor, value, data_count);
      break;
    case onnx::TensorProto_DataType_BOOL:
      // BOOL is only read from raw_data (see SetTypeAndValueForBool).
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL);
      SetTypeAndValueForBool(onnx_tensor, value, data_count);
      break;
    default:
      MS_LOG(ERROR) << "The data type is not supported.";
      return RET_ERROR;
  }
  return RET_OK;
}
239
GetTensorDataFromOnnx(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,int * type)240 STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
241 int *type) {
242 if (value == nullptr || type == nullptr) {
243 MS_LOG(ERROR) << "Input value or type is nullptr";
244 return RET_INPUT_PARAM_INVALID;
245 }
246 bool overflow = false;
247 auto data_count = GetOnnxElementNum(onnx_tensor, &overflow);
248 if (overflow) {
249 MS_LOG(ERROR) << "data count overflow";
250 return RET_ERROR;
251 }
252 return SetDataTypeAndValue(onnx_tensor, value, data_count, type);
253 }
254
GetOnnxElementNum(const onnx::TensorProto & onnx_tensor,bool * overflowed)255 size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed) {
256 size_t data_count = 1;
257 bool is_overflow = false;
258 if (!onnx_tensor.dims().empty()) {
259 std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count, &is_overflow](int dim) {
260 if (is_overflow || dim < 0) {
261 is_overflow = true;
262 data_count = 0;
263 return;
264 }
265 auto udim = static_cast<size_t>(dim);
266 if (INT_MUL_OVERFLOW_THRESHOLD(data_count, udim, SIZE_MAX)) {
267 is_overflow = true;
268 data_count = 0;
269 return;
270 }
271 data_count *= udim;
272 });
273 }
274 if (overflowed != nullptr) {
275 *overflowed = is_overflow;
276 }
277 return data_count;
278 }
279
// Copies a tensor's externally-stored (file-backed) data into |tensor_info|.
// The file bytes are resolved via LoadOnnxRawData, which caches whole files
// in |external_datas| keyed by relative path, so repeated tensors from the
// same file read it only once.
//
// @param onnx_const_tensor  tensor whose external_data entries locate the bytes.
// @param tensor_info        destination tensor; must be non-null and pre-sized.
// @param model_file         path of the .onnx file, used to resolve the data dir.
// @param external_datas     cache of already-loaded external files.
// @return RET_OK, RET_NULL_PTR, RET_MEMORY_FAILED or RET_ERROR.
STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
                                                  const tensor::TensorPtr &tensor_info, const std::string &model_file,
                                                  std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor_info is nullptr.";
    return RET_NULL_PTR;
  }
  size_t data_size = 0;
  // |onnx_data| points into the cached file buffer (plus any offset); it is
  // NOT owned here.
  const void *onnx_data = LoadOnnxRawData(onnx_const_tensor, &data_size, model_file, external_datas);
  if (onnx_data == nullptr) {
    MS_LOG(ERROR) << "origin data from external data is nullptr.";
    return RET_MEMORY_FAILED;
  }
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  // huge_memcpy handles copies larger than a single memcpy_s chunk.
  if (common::huge_memcpy(tensor_data, static_cast<size_t>(tensor_info->data().nbytes()),
                          static_cast<const uint8_t *>(onnx_data), data_size) != EOK) {
    MS_LOG(ERROR) << "memcpy_s from onnx tensor data to mindspore tensor data failed, dst size "
                  << tensor_info->data().nbytes() << ", src size " << data_size;
    return RET_ERROR;
  }
  return RET_OK;
}
302
SetExternalTensorFile(const std::string & model_file,std::string * external_tensor_dir)303 STATUS OnnxNodeParser::SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir) {
304 CHECK_NULL_RETURN(external_tensor_dir);
305 auto i_end_index = model_file.find_last_of('/');
306 if (i_end_index == std::string::npos) {
307 i_end_index = model_file.find_last_of('\\');
308 }
309 if (i_end_index == std::string::npos) {
310 *external_tensor_dir = ".";
311 } else {
312 *external_tensor_dir = model_file.substr(0, i_end_index);
313 }
314 return RET_OK;
315 }
316
317 template <class DstT, class SrcT>
CopyOnnxData(void * dst_v,const void * src_v,size_t data_count)318 static int CopyOnnxData(void *dst_v, const void *src_v, size_t data_count) {
319 if (dst_v == nullptr || src_v == nullptr) {
320 MS_LOG(ERROR) << "Dst or src data cannot be nullptr";
321 return RET_ERROR;
322 }
323 if (sizeof(DstT) == sizeof(SrcT)) {
324 if (memcpy_s(dst_v, data_count * sizeof(DstT), src_v, data_count * sizeof(SrcT)) != EOK) {
325 MS_LOG(ERROR) << "memcpy_s failed, data size " << data_count * sizeof(DstT);
326 return RET_ERROR;
327 }
328 return RET_OK;
329 }
330 auto src = reinterpret_cast<const SrcT *>(src_v);
331 auto dst = reinterpret_cast<DstT *>(dst_v);
332 for (size_t i = 0; i < data_count; i++) {
333 dst[i] = static_cast<DstT>(src[i]);
334 }
335 return RET_OK;
336 }
337
GetOnnxRawData(const onnx::TensorProto & onnx_const_tensor,size_t data_count,const tensor::TensorPtr & tensor_info)338 int OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
339 const tensor::TensorPtr &tensor_info) {
340 auto data_size = LongToSize(tensor_info->data().nbytes());
341 auto tensor_data = tensor_info->data_c();
342 auto onnx_data = onnx_const_tensor.raw_data().data();
343 if (onnx_const_tensor.raw_data().size() != data_size) {
344 MS_LOG(ERROR) << "Tensor raw data size " << onnx_const_tensor.raw_data().size() << " != expected size "
345 << data_size;
346 return RET_ERROR;
347 }
348 return CopyOnnxData<uint8_t, uint8_t>(tensor_data, onnx_data, data_size);
349 }
350
// Copies the tensor's typed list-field data into |tensor_info|, selecting
// the ONNX storage field by the destination's MindSpore dtype. Small integer
// types, bool and float16 are stored widened in int32_data and are narrowed
// element-wise here; uint32 is likewise narrowed from uint64_data.
//
// @param onnx_const_tensor  source tensor (list storage, not raw_data).
// @param data_count         expected element count; each case checks the
//                           corresponding list field has exactly this size.
// @param tensor_info        destination tensor, already allocated.
// @return RET_OK on success, RET_ERROR on size mismatch or unsupported type.
int OnnxNodeParser::GetOnnxListData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
                                    const tensor::TensorPtr &tensor_info) {
  const void *onnx_data = nullptr;
  auto tensor_data = tensor_info->data_c();
  TypeId data_type = tensor_info->Dtype()->type_id();
  auto type_size = lite::DataTypeSize(data_type);
  switch (data_type) {
    case kNumberTypeFloat32:
      MS_CHECK_EQ(onnx_const_tensor.float_data_size(), SizeToLong(data_count), RET_ERROR);
      onnx_data = onnx_const_tensor.float_data().data();
      return CopyOnnxData<float, float>(tensor_data, onnx_data, data_count);
    case kNumberTypeFloat64:
      MS_CHECK_EQ(onnx_const_tensor.double_data_size(), SizeToLong(data_count), RET_ERROR);
      onnx_data = onnx_const_tensor.double_data().data();
      return CopyOnnxData<double, double>(tensor_data, onnx_data, data_count);
    case kNumberTypeInt64:
      MS_CHECK_EQ(onnx_const_tensor.int64_data_size(), SizeToLong(data_count), RET_ERROR);
      onnx_data = onnx_const_tensor.int64_data().data();
      return CopyOnnxData<int64_t, int64_t>(tensor_data, onnx_data, data_count);
    case kNumberTypeUInt64:
    case kNumberTypeUInt32:
      // Both unsigned widths live in uint64_data; uint32 is narrowed per element.
      MS_CHECK_EQ(onnx_const_tensor.uint64_data_size(), SizeToLong(data_count), RET_ERROR);
      onnx_data = onnx_const_tensor.uint64_data().data();
      if (data_type == kNumberTypeUInt32) {
        return CopyOnnxData<uint32_t, uint64_t>(tensor_data, onnx_data, data_count);
      } else {
        return CopyOnnxData<uint64_t, uint64_t>(tensor_data, onnx_data, data_count);
      }
    case kNumberTypeInt32:
    case kNumberTypeInt16:
    case kNumberTypeInt8:
    case kNumberTypeUInt16:
    case kNumberTypeUInt8:
    case kNumberTypeBool:
    case kNumberTypeFloat16:
      // All of these share int32_data storage; the destination element size
      // decides how far each value is narrowed.
      MS_CHECK_EQ(onnx_const_tensor.int32_data_size(), SizeToLong(data_count), RET_ERROR);
      onnx_data = onnx_const_tensor.int32_data().data();
      if (type_size == sizeof(int32_t)) {
        return CopyOnnxData<int32_t, int32_t>(tensor_data, onnx_data, data_count);
      } else if (type_size == sizeof(uint16_t)) {
        // NOTE(review): float16 takes this path too -- the int32 value is
        // reinterpreted as raw 16-bit storage, which assumes the producer
        // wrote the fp16 bit pattern; confirm against the ONNX spec.
        return CopyOnnxData<uint16_t, int32_t>(tensor_data, onnx_data, data_count);
      } else if (type_size == sizeof(uint8_t)) {
        return CopyOnnxData<uint8_t, int32_t>(tensor_data, onnx_data, data_count);
      }
      break;
    default:
      break;
  }
  MS_LOG(ERROR) << "unsupported data type " << data_type;
  return RET_ERROR;
}
402
// Resolves a tensor's external-data reference to a pointer into the backing
// file's bytes. The whole file is read once per relative path and cached in
// |external_datas| (size, buffer); subsequent tensors referencing the same
// file reuse the cached buffer. The returned pointer is |offset| bytes into
// that buffer and |*data_size| is the usable length -- the caller must not
// free it. NOTE(review): the cached buffers from ReadFile are presumably
// released by the owner of |external_datas|; confirm upstream.
//
// @param onnx_const_tensor  tensor carrying external_data entries.
// @param data_size          receives the byte length of the returned region.
// @param model_file         model path used to locate the data directory.
// @param external_datas     cache of loaded files, keyed by relative path.
// @return pointer into the cached file data, or nullptr on any failure.
const void *OnnxNodeParser::LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
                                            const std::string &model_file,
                                            std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
  MS_ERROR_IF_NULL_W_RET_VAL(data_size, nullptr);
  MS_ERROR_IF_NULL_W_RET_VAL(external_datas, nullptr);
  ExternalDataInfo external_data_info;
  if (ExternalDataInfo::Create(onnx_const_tensor.external_data(), &external_data_info) != RET_OK) {
    MS_LOG(ERROR) << "Create ExternalDataInfo failed.";
    return nullptr;
  }
  auto data_path = external_data_info.GetRelativePath();
  auto it = external_datas->find(data_path);
  size_t external_data_size = 0;
  uint8_t *external_data = nullptr;
  if (it == external_datas->end()) {
    // Cache miss: read the whole data file and remember it for later tensors.
    std::string external_tensor_dir;
    if (SetExternalTensorFile(model_file, &external_tensor_dir) != RET_OK) {
      MS_LOG(ERROR) << "Failed to set external tensor file.";
      return nullptr;
    }
#ifdef _WIN32
    std::string external_data_file = external_tensor_dir + "\\" + data_path;
#else
    std::string external_data_file = external_tensor_dir + "/" + data_path;
#endif
    external_data = reinterpret_cast<uint8_t *>(ReadFile(external_data_file.c_str(), &external_data_size));
    if (external_data == nullptr || external_data_size == 0) {
      MS_LOG(ERROR) << "Failed to read external tensor file " << external_data_file;
      return nullptr;
    }
    external_datas->emplace(data_path, std::make_pair(external_data_size, external_data));
  } else {
    external_data_size = it->second.first;
    external_data = it->second.second;
  }
  auto offset = external_data_info.GetOffset();
  auto length = external_data_info.GetLength();
  if (length == 0 && offset == 0) {  // not set length and offset
    // Neither field present: the tensor spans the entire file.
    *data_size = external_data_size;
    return external_data;
  }
  // Reject a zero length with a non-zero offset, and any [offset, offset+length)
  // window that falls outside the file.
  if (length == 0 || external_data_size < offset || external_data_size - offset < length) {
    MS_LOG(ERROR) << "Invalid external data info, data path " << data_path << ", offset " << offset << ", length "
                  << length << ", file length " << external_data_size;
    return nullptr;
  }
  *data_size = length;
  return external_data + offset;
}
452
// Locates the constant TensorProto feeding |input_name|: first among the
// graph's initializers, then as the single output of a Constant node whose
// "value" attribute holds a tensor. Returns nullptr (with an error log) if
// neither is found.
//
// @param onnx_graph  graph to search.
// @param input_name  name of the constant input to resolve.
// @return pointer into |onnx_graph|'s storage (not owned), or nullptr.
const onnx::TensorProto *OnnxNodeParser::GetConstantTensorData(const onnx::GraphProto &onnx_graph,
                                                               const std::string &input_name) {
  auto &initializer = onnx_graph.initializer();
  // Fixed: capture input_name by reference -- the by-value captures copied the
  // string into each closure for no benefit.
  auto init_iter = std::find_if(initializer.begin(), initializer.end(),
                                [&input_name](const onnx::TensorProto &proto) { return proto.name() == input_name; });
  if (init_iter != initializer.end()) {
    return &(*init_iter);
  }
  auto &nodes = onnx_graph.node();
  auto node_iter = std::find_if(nodes.begin(), nodes.end(), [&input_name](const onnx::NodeProto &proto) {
    if (proto.op_type() != "Constant" || proto.output_size() != 1) {
      return false;
    }
    return proto.output(0) == input_name;
  });
  if (node_iter == nodes.end()) {
    MS_LOG(ERROR) << "Cannot find const input " << input_name;
    return nullptr;
  }
  auto &onnx_node = *node_iter;
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "value") {
      if (onnx_node_attr.has_t()) {
        return &onnx_node_attr.t();
      }
      break;
    }
  }
  MS_LOG(ERROR) << "Failed to find const value from input " << input_name;
  return nullptr;
}
485 } // namespace lite
486 } // namespace mindspore
487