1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/converter/config_parser/third_party_param_parser.h"
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
#include "tools/common/string_util.h"
25
26 namespace mindspore {
27 namespace lite {
28 namespace {
29 const std::map<std::string, TypeId> kDataTypeMap = {
30 {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32},
31 {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64},
32 {"int32", TypeId::kNumberTypeInt32}, {"int16", TypeId::kNumberTypeInt16},
33 {"int8", TypeId::kNumberTypeInt8}, {"uint8", TypeId::kNumberTypeUInt8},
34 {"bool", TypeId::kNumberTypeBool},
35 };
36
ConvertDataType(const std::string & type)37 TypeId ConvertDataType(const std::string &type) {
38 auto iter = kDataTypeMap.find(type);
39 if (iter == kDataTypeMap.end()) {
40 return TypeId::kTypeUnknown;
41 }
42 return iter->second;
43 }
44 } // namespace
45
46 /**
47 * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]].
48 */
DoParseShape(const std::string & src,std::vector<std::vector<int64_t>> * dst_shapes)49 int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
50 MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
51 dst_shapes->clear();
52
53 auto tmp_shapes = SplitStringToVector(src, ";");
54 for (auto tmp_shape : tmp_shapes) {
55 auto tmp = SplitStringToVector(tmp_shape, ",");
56 std::vector<int64_t> shape = {};
57 for (auto t : tmp) {
58 int value = 0;
59 if (!ConvertIntNum(t, &value)) {
60 MS_LOG(ERROR) << "Found error when convert shape string to integer";
61 return RET_ERROR;
62 }
63 if (value <= 0) { // Valid shape value should be greater than 0.
64 MS_LOG(ERROR) << "Only support fixed shapes in third party param";
65 return RET_ERROR;
66 }
67 shape.push_back(value);
68 }
69 dst_shapes->push_back(shape);
70 }
71 return RET_OK;
72 }
73
74 /**
75 * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
76 */
DoParseExtendedParameters(const std::string & src,std::map<std::string,std::vector<uint8_t>> * dst_ext_param)77 int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
78 std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
79 MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
80 constexpr size_t kKeyIndex = 0U;
81 constexpr size_t kValueIndex = 1U;
82 constexpr size_t kKeyValueSize = 2U;
83
84 if (src == "") { // Just return if 'extended_parameters' is configured.
85 return RET_OK;
86 }
87
88 auto tmp_list = SplitStringToVector(src, ";");
89 std::map<std::string, std::vector<uint8_t>> tmp_map = {};
90 for (auto tmp : tmp_list) {
91 auto key_and_value = SplitStringToVector(tmp, ":");
92 if (key_and_value.size() != kKeyValueSize) {
93 MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
94 return RET_ERROR;
95 }
96 auto key = key_and_value[kKeyIndex];
97 auto value = key_and_value[kValueIndex];
98 if (tmp_map.find(key) != tmp_map.end()) {
99 MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
100 return RET_ERROR;
101 }
102 tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));
103 }
104
105 *dst_ext_param = tmp_map;
106 return RET_OK;
107 }
108
109 /**
110 * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
111 */
DoParseDtypes(const std::string & src,std::vector<TypeId> * dst_dtypes)112 int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
113 MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
114 dst_dtypes->clear();
115 auto tmp_dtypes = SplitStringToVector(src, ";");
116 for (auto tmp_dtype : tmp_dtypes) {
117 TypeId type = ConvertDataType(tmp_dtype);
118 if (type == kTypeUnknown) {
119 MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
120 return RET_ERROR;
121 }
122 dst_dtypes->push_back(type);
123 }
124 return RET_OK;
125 }
126
127 /**
128 * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
129 * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n"
130 */
DoParseNames(const std::string & src,size_t num,const std::string & default_prefix,std::vector<std::string> * dst_names)131 int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
132 std::vector<std::string> *dst_names) {
133 MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
134 std::string tmp_names = src;
135 if (tmp_names.empty()) {
136 std::string tmp = "";
137 for (size_t i = 0; i < num; i++) {
138 tmp += default_prefix + "_" + std::to_string(i);
139 if (i + 1 < num) {
140 tmp += ";";
141 }
142 }
143 tmp_names = tmp;
144 }
145
146 *dst_names = SplitStringToVector(tmp_names, ";");
147 if (dst_names->size() != num) {
148 MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
149 return RET_ERROR;
150 }
151 return RET_OK;
152 }
153
/**
 * Format parsing: DoParseFormats parses strings like "NCHW;NHWC" into [NCHW, NHWC],
 * using StringToFormat below to convert each individual format name.
 */
157 namespace {
StringToFormat(const std::string & format_string,schema::Format * format)158 int StringToFormat(const std::string &format_string, schema::Format *format) {
159 static const std::unordered_map<std::string, schema::Format> kFormatTable = {
160 {"NCHW", schema::Format::Format_NCHW},
161 {"NHWC", schema::Format::Format_NHWC},
162 {"NHWC4", schema::Format::Format_NHWC4},
163 {"HWKC", schema::Format::Format_HWKC},
164 {"HWCK", schema::Format::Format_HWCK},
165 {"KCHW", schema::Format::Format_KCHW},
166 {"CKHW", schema::Format::Format_CKHW},
167 {"KHWC", schema::Format::Format_KHWC},
168 {"CHWK", schema::Format::Format_CHWK},
169 {"HW", schema::Format::Format_HW},
170 {"HW4", schema::Format::Format_HW4},
171 {"NC", schema::Format::Format_NC},
172 {"NC4", schema::Format::Format_NC4},
173 {"NC4HW4", schema::Format::Format_NC4HW4},
174 {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},
175 {"NCDHW", schema::Format::Format_NCDHW},
176 {"NWC", schema::Format::Format_NWC},
177 {"NCW", schema::Format::Format_NCW},
178 };
179
180 if (format == nullptr) {
181 return RET_NULL_PTR;
182 }
183
184 auto iter = kFormatTable.find(format_string);
185 if (iter == kFormatTable.end()) {
186 return RET_PARAM_INVALID;
187 }
188
189 *format = iter->second;
190 return RET_OK;
191 }
192 }
193
DoParseFormats(const std::string & src,size_t num,std::vector<schema::Format> * result_formats)194 int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
195 std::vector<schema::Format> *result_formats) {
196 MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
197 std::string tmp_names = src;
198 if (tmp_names.empty()) {
199 std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC);
200 *result_formats = default_formats;
201 return RET_OK;
202 }
203
204 auto format_strings = SplitStringToVector(tmp_names, ";");
205 if (format_strings.size() != num) {
206 MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal";
207 return RET_ERROR;
208 }
209
210 std::vector<schema::Format> result(num);
211 for (size_t i = 0; i < num; i++) {
212 if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
213 MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
214 return RET_PARAM_INVALID;
215 }
216 }
217 *result_formats = result;
218 return RET_OK;
219 }
220
Parse(const ThirdPartyModelString & param_string,ThirdPartyModelParam * param)221 int ThirdPartyParamParser::Parse(const ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param) {
222 MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
223
224 auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
225 if (ret != RET_OK) {
226 MS_LOG(ERROR) << "Parse input shapes of third party param failed";
227 return RET_ERROR;
228 }
229
230 ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
231 if (ret != RET_OK) {
232 MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
233 return RET_ERROR;
234 }
235
236 auto input_shape_num = param->input_shapes.size();
237 auto input_dtype_num = param->input_dtypes.size();
238 if (input_shape_num != input_dtype_num) {
239 MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
240 << " are not equal";
241 return RET_ERROR;
242 }
243
244 ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
245 if (ret != RET_OK) {
246 MS_LOG(ERROR) << "Parse input formats of third party param failed";
247 return RET_ERROR;
248 }
249
250 const std::string kInputNamePrefix = "in";
251 ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
252 if (ret != RET_OK) {
253 MS_LOG(ERROR) << "Parse input names of third party param failed";
254 return RET_ERROR;
255 }
256
257 ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
258 if (ret != RET_OK) {
259 MS_LOG(ERROR) << "Parse output shaped of third party param failed";
260 return RET_ERROR;
261 }
262
263 ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
264 if (ret != RET_OK) {
265 MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
266 return RET_ERROR;
267 }
268
269 auto output_shape_num = param->output_shapes.size();
270 auto output_dtype_num = param->output_dtypes.size();
271 if (output_shape_num != output_dtype_num) {
272 MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num
273 << " are not equal";
274 return RET_ERROR;
275 }
276
277 ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
278 if (ret != RET_OK) {
279 MS_LOG(ERROR) << "Parse output formats of third party param failed";
280 return RET_ERROR;
281 }
282
283 const std::string kOutputNamePrefix = "out";
284 ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
285 if (ret != RET_OK) {
286 MS_LOG(ERROR) << "Parse output names of third party param failed";
287 return RET_ERROR;
288 }
289
290 ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
291 if (ret != RET_OK) {
292 MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
293 return RET_ERROR;
294 }
295
296 return RET_OK;
297 }
298 } // namespace lite
299 } // namespace mindspore
300