• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/converter/config_parser/third_party_param_parser.h"
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
#include "tools/common/string_util.h"
25 
26 namespace mindspore {
27 namespace lite {
28 namespace {
29 const std::map<std::string, TypeId> kDataTypeMap = {
30   {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32},
31   {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64},
32   {"int32", TypeId::kNumberTypeInt32},     {"int16", TypeId::kNumberTypeInt16},
33   {"int8", TypeId::kNumberTypeInt8},       {"uint8", TypeId::kNumberTypeUInt8},
34   {"bool", TypeId::kNumberTypeBool},
35 };
36 
ConvertDataType(const std::string & type)37 TypeId ConvertDataType(const std::string &type) {
38   auto iter = kDataTypeMap.find(type);
39   if (iter == kDataTypeMap.end()) {
40     return TypeId::kTypeUnknown;
41   }
42   return iter->second;
43 }
44 }  // namespace
45 
46 /**
47  * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]].
48  */
DoParseShape(const std::string & src,std::vector<std::vector<int64_t>> * dst_shapes)49 int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
50   MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
51   dst_shapes->clear();
52 
53   auto tmp_shapes = SplitStringToVector(src, ";");
54   for (auto tmp_shape : tmp_shapes) {
55     auto tmp = SplitStringToVector(tmp_shape, ",");
56     std::vector<int64_t> shape = {};
57     for (auto t : tmp) {
58       int value = 0;
59       if (!ConvertIntNum(t, &value)) {
60         MS_LOG(ERROR) << "Found error when convert shape string to integer";
61         return RET_ERROR;
62       }
63       if (value <= 0) {  // Valid shape value should be greater than 0.
64         MS_LOG(ERROR) << "Only support fixed shapes in third party param";
65         return RET_ERROR;
66       }
67       shape.push_back(value);
68     }
69     dst_shapes->push_back(shape);
70   }
71   return RET_OK;
72 }
73 
74 /**
75  * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
76  */
DoParseExtendedParameters(const std::string & src,std::map<std::string,std::vector<uint8_t>> * dst_ext_param)77 int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
78                                                      std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
79   MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
80   constexpr size_t kKeyIndex = 0U;
81   constexpr size_t kValueIndex = 1U;
82   constexpr size_t kKeyValueSize = 2U;
83 
84   if (src == "") {  // Just return if 'extended_parameters' is configured.
85     return RET_OK;
86   }
87 
88   auto tmp_list = SplitStringToVector(src, ";");
89   std::map<std::string, std::vector<uint8_t>> tmp_map = {};
90   for (auto tmp : tmp_list) {
91     auto key_and_value = SplitStringToVector(tmp, ":");
92     if (key_and_value.size() != kKeyValueSize) {
93       MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
94       return RET_ERROR;
95     }
96     auto key = key_and_value[kKeyIndex];
97     auto value = key_and_value[kValueIndex];
98     if (tmp_map.find(key) != tmp_map.end()) {
99       MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
100       return RET_ERROR;
101     }
102     tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));
103   }
104 
105   *dst_ext_param = tmp_map;
106   return RET_OK;
107 }
108 
109 /**
110  * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
111  */
DoParseDtypes(const std::string & src,std::vector<TypeId> * dst_dtypes)112 int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
113   MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
114   dst_dtypes->clear();
115   auto tmp_dtypes = SplitStringToVector(src, ";");
116   for (auto tmp_dtype : tmp_dtypes) {
117     TypeId type = ConvertDataType(tmp_dtype);
118     if (type == kTypeUnknown) {
119       MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
120       return RET_ERROR;
121     }
122     dst_dtypes->push_back(type);
123   }
124   return RET_OK;
125 }
126 
127 /**
128  * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
129  * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n"
130  */
DoParseNames(const std::string & src,size_t num,const std::string & default_prefix,std::vector<std::string> * dst_names)131 int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
132                                         std::vector<std::string> *dst_names) {
133   MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
134   std::string tmp_names = src;
135   if (tmp_names.empty()) {
136     std::string tmp = "";
137     for (size_t i = 0; i < num; i++) {
138       tmp += default_prefix + "_" + std::to_string(i);
139       if (i + 1 < num) {
140         tmp += ";";
141       }
142     }
143     tmp_names = tmp;
144   }
145 
146   *dst_names = SplitStringToVector(tmp_names, ";");
147   if (dst_names->size() != num) {
148     MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
149     return RET_ERROR;
150   }
151   return RET_OK;
152 }
153 
154 /**
155  * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC]
156  */
157 namespace {
StringToFormat(const std::string & format_string,schema::Format * format)158   int StringToFormat(const std::string &format_string, schema::Format *format) {
159     static const std::unordered_map<std::string, schema::Format> kFormatTable = {
160       {"NCHW", schema::Format::Format_NCHW},
161       {"NHWC", schema::Format::Format_NHWC},
162       {"NHWC4", schema::Format::Format_NHWC4},
163       {"HWKC", schema::Format::Format_HWKC},
164       {"HWCK", schema::Format::Format_HWCK},
165       {"KCHW", schema::Format::Format_KCHW},
166       {"CKHW", schema::Format::Format_CKHW},
167       {"KHWC", schema::Format::Format_KHWC},
168       {"CHWK", schema::Format::Format_CHWK},
169       {"HW", schema::Format::Format_HW},
170       {"HW4", schema::Format::Format_HW4},
171       {"NC", schema::Format::Format_NC},
172       {"NC4", schema::Format::Format_NC4},
173       {"NC4HW4", schema::Format::Format_NC4HW4},
174       {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},
175       {"NCDHW", schema::Format::Format_NCDHW},
176       {"NWC", schema::Format::Format_NWC},
177       {"NCW", schema::Format::Format_NCW},
178     };
179 
180     if (format == nullptr) {
181       return RET_NULL_PTR;
182     }
183 
184     auto iter = kFormatTable.find(format_string);
185     if (iter == kFormatTable.end()) {
186       return RET_PARAM_INVALID;
187     }
188 
189     *format = iter->second;
190     return RET_OK;
191   }
192 }
193 
DoParseFormats(const std::string & src,size_t num,std::vector<schema::Format> * result_formats)194 int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
195                                           std::vector<schema::Format> *result_formats) {
196   MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
197   std::string tmp_names = src;
198   if (tmp_names.empty()) {
199     std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC);
200     *result_formats = default_formats;
201     return RET_OK;
202   }
203 
204   auto format_strings = SplitStringToVector(tmp_names, ";");
205   if (format_strings.size() != num) {
206     MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal";
207     return RET_ERROR;
208   }
209 
210   std::vector<schema::Format> result(num);
211   for (size_t i = 0; i < num; i++) {
212     if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
213       MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
214       return RET_PARAM_INVALID;
215     }
216   }
217   *result_formats = result;
218   return RET_OK;
219 }
220 
Parse(const ThirdPartyModelString & param_string,ThirdPartyModelParam * param)221 int ThirdPartyParamParser::Parse(const ThirdPartyModelString &param_string, ThirdPartyModelParam *param) {
222   MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
223 
224   auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
225   if (ret != RET_OK) {
226     MS_LOG(ERROR) << "Parse input shapes of third party param failed";
227     return RET_ERROR;
228   }
229 
230   ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
231   if (ret != RET_OK) {
232     MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
233     return RET_ERROR;
234   }
235 
236   auto input_shape_num = param->input_shapes.size();
237   auto input_dtype_num = param->input_dtypes.size();
238   if (input_shape_num != input_dtype_num) {
239     MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
240                   << " are not equal";
241     return RET_ERROR;
242   }
243 
244   ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
245   if (ret != RET_OK) {
246     MS_LOG(ERROR) << "Parse input formats of third party param failed";
247     return RET_ERROR;
248   }
249 
250   const std::string kInputNamePrefix = "in";
251   ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
252   if (ret != RET_OK) {
253     MS_LOG(ERROR) << "Parse input names of third party param failed";
254     return RET_ERROR;
255   }
256 
257   ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
258   if (ret != RET_OK) {
259     MS_LOG(ERROR) << "Parse output shaped of third party param failed";
260     return RET_ERROR;
261   }
262 
263   ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
264   if (ret != RET_OK) {
265     MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
266     return RET_ERROR;
267   }
268 
269   auto output_shape_num = param->output_shapes.size();
270   auto output_dtype_num = param->output_dtypes.size();
271   if (output_shape_num != output_dtype_num) {
272     MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num
273                   << " are not equal";
274     return RET_ERROR;
275   }
276 
277   ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
278   if (ret != RET_OK) {
279     MS_LOG(ERROR) << "Parse output formats of third party param failed";
280     return RET_ERROR;
281   }
282 
283   const std::string kOutputNamePrefix = "out";
284   ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
285   if (ret != RET_OK) {
286     MS_LOG(ERROR) << "Parse output names of third party param failed";
287     return RET_ERROR;
288   }
289 
290   ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
291   if (ret != RET_OK) {
292     MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
293     return RET_ERROR;
294   }
295 
296   return RET_OK;
297 }
298 }  // namespace lite
299 }  // namespace mindspore
300