/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <algorithm>
#include <vector>
#include <memory>
#include <unordered_map>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace lite {
namespace {
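// Maps ONNX tensor element types to MindSpore Lite type IDs. Types missing
// from this map (e.g. STRING) make GetDataTypeFromOnnx return kTypeUnknown.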
static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
  {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
  {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
  {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16},
  {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32},
  {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
  {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
  {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
  {onnx::TensorProto_DataType_DOUBLE, mindspore::kNumberTypeFloat64},  // needed so GetOnnxRawData's Float64 branch is reachable
  {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
}  // namespace

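// Opset version of the ONNX model being parsed; shared by all node parsers.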
int64_t OnnxNodeParser::opset_version_ = 0;

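// Translates the ONNX auto_pad string attribute into a MindSpore pad mode.
// Note that SAME_UPPER and SAME_LOWER are both folded into PadMode::SAME.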
mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
  if (onnx_node_attr.s() == "NOTSET") {
    return mindspore::PadMode::PAD;
  } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") {
    return mindspore::PadMode::SAME;
  } else if (onnx_node_attr.s() == "VALID") {
    return mindspore::PadMode::VALID;
  } else {
    MS_LOG(ERROR) << "unsupported pad mode: " << onnx_node_attr.s();
    return mindspore::PadMode::PAD;
  }
}

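// Copies the payload of an ONNX constant tensor into an allocated MindSpore
// tensor. Zero-sized tensors are accepted as-is and nothing is copied.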
STATUS OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
                                          const tensor::TensorPtr &tensor_info) {
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor_info is nullptr.";
    return RET_NULL_PTR;
  }
  bool overflow = false;
  auto data_count = GetOnnxElementNum(onnx_const_tensor, &overflow);
  if (overflow) {
    MS_LOG(ERROR) << "data count overflow";
    return RET_ERROR;
  }
  size_t data_size = 0;
  auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()));
  const void *onnx_data = GetOnnxRawData(onnx_const_tensor, data_type, data_count, &data_size);
  if (data_size == 0) {
    // An empty tensor carries no payload; there is nothing to copy.
    return RET_OK;
  }
  if (onnx_data == nullptr) {
    MS_LOG(ERROR) << "origin data in onnx model is nullptr";
    return RET_MEMORY_FAILED;
  }
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  if (memcpy_s(tensor_data, tensor_info->data().nbytes(), onnx_data, data_size) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_ERROR;
  }
  return RET_OK;
}

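// Looks up the MindSpore type ID for an ONNX element type; returns
// kTypeUnknown (and logs an error) for unmapped types.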
TypeId OnnxNodeParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
  auto iter = kOnnxTypeTransferMap.find(onnx_type);
  if (iter == kOnnxTypeTransferMap.end()) {
    MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type;
    return kTypeUnknown;
  }
  return iter->second;
}

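// Reads a float/int32/int64 constant tensor into a float vector, honouring
// both the typed repeated fields and the packed raw_data representation.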
STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
                                             int *type) {
  if (value == nullptr || type == nullptr) {
    MS_LOG(ERROR) << "Input value or type is nullptr";
    return RET_INPUT_PARAM_INVALID;
  }
  bool overflow = false;
  auto data_count = GetOnnxElementNum(onnx_tensor, &overflow);
  if (overflow) {
    MS_LOG(ERROR) << "data count overflow";
    return RET_ERROR;
  }
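  // ONNX stores tensor payloads either in a type-specific repeated field
  // (float_data, int32_data, ...) or packed little-endian in raw_data;
  // both layouts are handled below.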
  switch (onnx_tensor.data_type()) {
    case onnx::TensorProto_DataType_FLOAT:
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
      if (onnx_tensor.float_data_size() > 0) {
        for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
          value->push_back(onnx_tensor.float_data(i));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
        }
      }
      break;
    case onnx::TensorProto_DataType_INT32:
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
      if (onnx_tensor.int32_data_size() > 0) {
        for (int i = 0; i < onnx_tensor.int32_data_size(); i++) {
          value->push_back(static_cast<float>(onnx_tensor.int32_data(i)));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
        }
      }
      break;
    case onnx::TensorProto_DataType_INT64:
      // int64 constants are reported as int32, which appears intentional: the
      // lite converter uses 32-bit integers, and the values are cast to float below.
      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
      if (onnx_tensor.int64_data_size() > 0) {
        for (int i = 0; i < onnx_tensor.int64_data_size(); i++) {
          value->push_back(static_cast<float>(onnx_tensor.int64_data(i)));
        }
      } else {
        for (size_t i = 0; i < data_count; i++) {
          value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
        }
      }
      break;
    default:
      MS_LOG(ERROR) << "unsupported data type: " << onnx_tensor.data_type();
      return RET_ERROR;
  }
  return RET_OK;
}

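// Multiplies the tensor's dimensions to get its element count, reporting
// overflow (or a negative/dynamic dimension) through *overflowed.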
size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed) {
  size_t data_count = 1;
  bool is_overflow = false;
  if (!onnx_tensor.dims().empty()) {
    // dims are int64 in the ONNX proto; take them as int64_t to avoid narrowing.
    std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count, &is_overflow](int64_t dim) {
      if (is_overflow || dim < 0) {
        is_overflow = true;
        data_count = 0;
        return;
      }
      auto udim = static_cast<size_t>(dim);
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, udim, SIZE_MAX)) {
        is_overflow = true;
        data_count = 0;
        return;
      }
      data_count *= udim;
    });
  }
  if (overflowed != nullptr) {
    *overflowed = is_overflow;
  }
  return data_count;
}

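// Returns a pointer to the tensor's payload and writes its byte size to
// *data_size, preferring the typed repeated field over raw_data when present.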
const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type,
                                           size_t data_count, size_t *data_size) {
  MS_ASSERT(data_size != nullptr);
  const void *onnx_data = nullptr;
  switch (data_type) {
    case kNumberTypeFloat32:
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(float), SIZE_MAX)) {
        MS_LOG(ERROR) << "data_size overflow";
        return nullptr;
      }
      *data_size = data_count * sizeof(float);
      if (onnx_const_tensor.float_data_size() == 0) {
        onnx_data = onnx_const_tensor.raw_data().data();
      } else {
        onnx_data = onnx_const_tensor.float_data().data();
      }
      break;
    case kNumberTypeFloat64:
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(double), SIZE_MAX)) {
        MS_LOG(ERROR) << "data_size overflow";
        return nullptr;
      }
      *data_size = data_count * sizeof(double);
      if (onnx_const_tensor.double_data_size() == 0) {
        onnx_data = onnx_const_tensor.raw_data().data();
      } else {
        onnx_data = onnx_const_tensor.double_data().data();
      }
      break;
    case kNumberTypeInt32:
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(int32_t), SIZE_MAX)) {
        MS_LOG(ERROR) << "data_size overflow";
        return nullptr;
      }
      *data_size = data_count * sizeof(int32_t);
      if (onnx_const_tensor.int32_data_size() == 0) {
        onnx_data = onnx_const_tensor.raw_data().data();
      } else {
        onnx_data = onnx_const_tensor.int32_data().data();
      }
      break;
    case kNumberTypeInt64:
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(int64_t), SIZE_MAX)) {
        MS_LOG(ERROR) << "data_size overflow";
        return nullptr;
      }
      *data_size = data_count * sizeof(int64_t);
      if (onnx_const_tensor.int64_data_size() == 0) {
        onnx_data = onnx_const_tensor.raw_data().data();
      } else {
        onnx_data = onnx_const_tensor.int64_data().data();
      }
      break;
    case kNumberTypeUInt8:
    case kNumberTypeInt8:
    case kNumberTypeBool:
      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(uint8_t), SIZE_MAX)) {
        MS_LOG(ERROR) << "data_size overflow";
        return nullptr;
      }
      *data_size = data_count * sizeof(uint8_t);
      // 1-byte types are read from raw_data only; the typed int32_data fallback
      // is not handled here.
      onnx_data = onnx_const_tensor.raw_data().data();
      break;
    default:
      MS_LOG(ERROR) << "unsupported data type " << data_type;
      return nullptr;
  }
  return onnx_data;
}
}  // namespace lite
}  // namespace mindspore