• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/onnx/onnx_node_parser.h"
18 #include <algorithm>
19 #include <vector>
20 #include <memory>
21 #include <unordered_map>
22 #include "tools/converter/parser/onnx/onnx_model_parser.h"
23 #include "nnacl/op_base.h"
24 
25 namespace mindspore {
26 namespace lite {
27 namespace {
28 static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
29   {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
30   {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
31   {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16},
32   {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32},
33   {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
34   {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
35   {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
36   {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
37   {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
38 }  // namespace
39 
40 int64_t OnnxNodeParser::opset_version_ = 0;
41 
GetOnnxPadMode(const onnx::AttributeProto & onnx_node_attr)42 mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
43   if (onnx_node_attr.s() == "NOTSET") {
44     return mindspore::PadMode::PAD;
45   } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") {
46     return mindspore::PadMode::SAME;
47   } else if (onnx_node_attr.s() == "VALID") {
48     return mindspore::PadMode::VALID;
49   } else {
50     MS_LOG(ERROR) << "unsupported padMode";
51     return mindspore::PadMode::PAD;
52   }
53 }
54 
CopyOnnxTensorData(const onnx::TensorProto & onnx_const_tensor,const tensor::TensorPtr & tensor_info)55 STATUS OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
56                                           const tensor::TensorPtr &tensor_info) {
57   if (tensor_info == nullptr) {
58     MS_LOG(ERROR) << "tensor_info is nullptr.";
59     return RET_NULL_PTR;
60   }
61   bool overflow = false;
62   auto data_count = GetOnnxElementNum(onnx_const_tensor, &overflow);
63   if (overflow) {
64     MS_LOG(ERROR) << "data count overflow";
65     return RET_ERROR;
66   }
67   size_t data_size = 0;
68   auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()));
69   const void *onnx_data = GetOnnxRawData(onnx_const_tensor, data_type, data_count, &data_size);
70   if (data_size == 0) {
71     return RET_OK;
72   }
73   if (onnx_data == nullptr) {
74     MS_LOG(ERROR) << "origin data in onnx model is nullptr";
75     return RET_MEMORY_FAILED;
76   }
77   auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
78   if (memcpy_s(tensor_data, tensor_info->data().nbytes(), onnx_data, data_size) != EOK) {
79     MS_LOG(ERROR) << "memcpy_s failed";
80     return RET_ERROR;
81   }
82   return RET_OK;
83 }
84 
GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type)85 TypeId OnnxNodeParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
86   auto iter = kOnnxTypeTransferMap.find(onnx_type);
87   if (iter == kOnnxTypeTransferMap.end()) {
88     MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type;
89     return kTypeUnknown;
90   }
91   return iter->second;
92 }
93 
GetTensorDataFromOnnx(const onnx::TensorProto & onnx_tensor,std::vector<float> * value,int * type)94 STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
95                                              int *type) {
96   if (value == nullptr || type == nullptr) {
97     MS_LOG(ERROR) << "Input value or type is nullptr";
98     return RET_INPUT_PARAM_INVALID;
99   }
100   bool overflow = false;
101   auto data_count = GetOnnxElementNum(onnx_tensor, &overflow);
102   if (overflow) {
103     MS_LOG(ERROR) << "data count overflow";
104     return RET_ERROR;
105   }
106   switch (onnx_tensor.data_type()) {
107     case onnx::TensorProto_DataType_FLOAT:
108       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
109       if (onnx_tensor.float_data_size() > 0) {
110         for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
111           value->push_back(onnx_tensor.float_data(i));
112         }
113       } else {
114         for (size_t i = 0; i < data_count; i++) {
115           value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
116         }
117       }
118       break;
119     case onnx::TensorProto_DataType_INT32:
120       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
121       if (onnx_tensor.int32_data_size() > 0) {
122         for (int i = 0; i < onnx_tensor.int32_data_size(); i++) {
123           value->push_back(onnx_tensor.int32_data(i));
124         }
125       } else {
126         for (size_t i = 0; i < data_count; i++) {
127           value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
128         }
129       }
130       break;
131     case onnx::TensorProto_DataType_INT64:
132       *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
133       if (onnx_tensor.int64_data_size() > 0) {
134         for (int i = 0; i < onnx_tensor.int64_data_size(); i++) {
135           value->push_back(onnx_tensor.int64_data(i));
136         }
137       } else {
138         for (size_t i = 0; i < data_count; i++) {
139           value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
140         }
141       }
142       break;
143     default:
144       MS_LOG(ERROR) << "The data type is not supported.";
145       return RET_ERROR;
146   }
147   return RET_OK;
148 }
149 
GetOnnxElementNum(const onnx::TensorProto & onnx_tensor,bool * overflowed)150 size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed) {
151   size_t data_count = 1;
152   bool is_overflow = false;
153   if (!onnx_tensor.dims().empty()) {
154     std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count, &is_overflow](int dim) {
155       if (is_overflow || dim < 0) {
156         is_overflow = true;
157         data_count = 0;
158         return;
159       }
160       auto udim = static_cast<size_t>(dim);
161       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, udim, SIZE_MAX)) {
162         is_overflow = true;
163         data_count = 0;
164         return;
165       }
166       data_count *= udim;
167     });
168   }
169   if (overflowed != nullptr) {
170     *overflowed = is_overflow;
171   }
172   return data_count;
173 }
174 
GetOnnxRawData(const onnx::TensorProto & onnx_const_tensor,TypeId data_type,size_t data_count,size_t * data_size)175 const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type,
176                                            size_t data_count, size_t *data_size) {
177   MS_ASSERT(data_size != nullptr);
178   const void *onnx_data = nullptr;
179   switch (data_type) {
180     case kNumberTypeFloat32:
181       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(float), SIZE_MAX)) {
182         MS_LOG(ERROR) << "data_size overflow";
183         return nullptr;
184       }
185       *data_size = data_count * sizeof(float);
186       if (onnx_const_tensor.float_data_size() == 0) {
187         onnx_data = onnx_const_tensor.raw_data().data();
188       } else {
189         onnx_data = onnx_const_tensor.float_data().data();
190       }
191       break;
192     case kNumberTypeFloat64:
193       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(double), SIZE_MAX)) {
194         MS_LOG(ERROR) << "data_size overflow";
195         return nullptr;
196       }
197       *data_size = data_count * sizeof(double);
198       if (onnx_const_tensor.double_data_size() == 0) {
199         onnx_data = onnx_const_tensor.raw_data().data();
200       } else {
201         onnx_data = onnx_const_tensor.double_data().data();
202       }
203       break;
204     case kNumberTypeInt32:
205       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(int), SIZE_MAX)) {
206         MS_LOG(ERROR) << "data_size overflow";
207         return nullptr;
208       }
209       *data_size = data_count * sizeof(int);
210       if (onnx_const_tensor.int32_data_size() == 0) {
211         onnx_data = onnx_const_tensor.raw_data().data();
212       } else {
213         onnx_data = onnx_const_tensor.int32_data().data();
214       }
215       break;
216     case kNumberTypeInt64:
217       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(int64_t), SIZE_MAX)) {
218         MS_LOG(ERROR) << "data_size overflow";
219         return nullptr;
220       }
221       *data_size = data_count * sizeof(int64_t);
222       if (onnx_const_tensor.int64_data_size() == 0) {
223         onnx_data = onnx_const_tensor.raw_data().data();
224       } else {
225         onnx_data = onnx_const_tensor.int64_data().data();
226       }
227       break;
228     case kNumberTypeUInt8:
229     case kNumberTypeInt8:
230     case kNumberTypeBool:
231       if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(uint8_t), SIZE_MAX)) {
232         MS_LOG(ERROR) << "data_size overflow";
233         return nullptr;
234       }
235       *data_size = data_count * sizeof(uint8_t);
236       onnx_data = onnx_const_tensor.raw_data().data();
237       break;
238     default:
239       MS_LOG(ERROR) << "unsupported data type " << data_type;
240       return nullptr;
241   }
242   return onnx_data;
243 }
244 }  // namespace lite
245 }  // namespace mindspore
246