1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/converter_flags.h"
18 #include <climits>
19 #include <cstdlib>
20 #include <string>
21 #include <fstream>
22 #include <vector>
23 #include <memory>
24 #include <algorithm>
25 #include "ir/dtype/type_id.h"
26 #include "common/file_utils.h"
27 #include "tools/common/string_util.h"
28 #include "common/log_util.h"
29 #include "tools/converter/converter_context.h"
30 #include "tools/converter/config_parser/config_file_parser.h"
31 #include "tools/converter/config_parser/preprocess_parser.h"
32 #include "tools/converter/config_parser/quant_param_parser.h"
33
34 namespace mindspore {
35 namespace converter {
36 using mindspore::lite::RET_INPUT_PARAM_INVALID;
37 using mindspore::lite::RET_OK;
38 namespace {
39 constexpr size_t kPluginPathMaxNum = 10;
40 constexpr int kQuantBitNumInt16 = 16;
41 constexpr int kPathLengthUpperLimit = 1024;
42 constexpr int kMinShapeSizeInStr = 2;
43 } // namespace
// Registers every command-line flag of the converter tool: the member it binds
// to, the flag name, its help text and its default value. Actual parsing and
// validation happen later in Flags::Init().
Flags::Flags() {
  // Source framework and model/weight file locations.
  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
  AddFlag(&Flags::modelFile, "modelFile",
          "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
  AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
  AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
          "");
  // Tensor data types; "DEFAULT" keeps whatever the model declares.
  AddFlag(&Flags::inputDataTypeStr, "inputDataType",
          "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT",
          "DEFAULT");
  AddFlag(&Flags::outputDataTypeStr, "outputDataType",
          "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | "
          "UINT8 | DEFAULT",
          "DEFAULT");
  // Optional ini-style configuration file (quantization, parallel split, plugins...).
  AddFlag(&Flags::configFile, "configFile",
          "Configuration for post-training, offline split op to parallel,"
          "disable op fusion ability and set plugin so path",
          "");
  AddFlag(&Flags::saveFP16Str, "fp16",
          "Serialize const tensor in Float16 data type, only effective for const tensor in Float32 data type. on | off",
          "off");
  AddFlag(&Flags::trainModelIn, "trainModel",
          "whether the model is going to be trained on device. "
          "true | false",
          "false");
  // MindIR decryption options (only meaningful when fmk is MINDIR).
  AddFlag(&Flags::dec_key, "decryptKey",
          "The key used to decrypt the file, expressed in hexadecimal characters. Only valid when fmkIn is 'MINDIR'",
          "");
  AddFlag(&Flags::dec_mode, "decryptMode",
          "Decryption method for the MindIR file. Only valid when dec_key is set."
          "AES-GCM | AES-CBC",
          "AES-GCM");
  // Optional static input shapes, e.g. "in1:1,32,32,32;in2:1,1,32,32,4".
  AddFlag(&Flags::inTensorShape, "inputShape",
          "Set the dimension of the model input, the order of input dimensions is consistent with the original model. "
          "For some models, the model structure can be further optimized, but the transformed model may lose the "
          "characteristics of dynamic shape. "
          "e.g. \"inTensor1:1,32,32,32;inTensor2:1,1,32,32,4\"",
          "");
  AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
          "Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC");
}
85
InitInputOutputDataType()86 int Flags::InitInputOutputDataType() {
87 if (this->inputDataTypeStr == "FLOAT") {
88 this->inputDataType = TypeId::kNumberTypeFloat32;
89 } else if (this->inputDataTypeStr == "INT8") {
90 this->inputDataType = TypeId::kNumberTypeInt8;
91 } else if (this->inputDataTypeStr == "UINT8") {
92 this->inputDataType = TypeId::kNumberTypeUInt8;
93 } else if (this->inputDataTypeStr == "DEFAULT") {
94 this->inputDataType = TypeId::kTypeUnknown;
95 } else {
96 std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
97 this->inputDataTypeStr.c_str();
98 return RET_INPUT_PARAM_INVALID;
99 }
100
101 if (this->outputDataTypeStr == "FLOAT") {
102 this->outputDataType = TypeId::kNumberTypeFloat32;
103 } else if (this->outputDataTypeStr == "INT8") {
104 this->outputDataType = TypeId::kNumberTypeInt8;
105 } else if (this->outputDataTypeStr == "UINT8") {
106 this->outputDataType = TypeId::kNumberTypeUInt8;
107 } else if (this->outputDataTypeStr == "DEFAULT") {
108 this->outputDataType = TypeId::kTypeUnknown;
109 } else {
110 std::cerr
111 << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
112 this->outputDataTypeStr.c_str();
113 return RET_INPUT_PARAM_INVALID;
114 }
115 return RET_OK;
116 }
117
InitFmk()118 int Flags::InitFmk() {
119 if (this->fmkIn == "CAFFE") {
120 this->fmk = kFmkTypeCaffe;
121 } else if (this->fmkIn == "MINDIR") {
122 this->fmk = kFmkTypeMs;
123 } else if (this->fmkIn == "TFLITE") {
124 this->fmk = kFmkTypeTflite;
125 } else if (this->fmkIn == "ONNX") {
126 this->fmk = kFmkTypeOnnx;
127 } else if (this->fmkIn == "TF") {
128 this->fmk = kFmkTypeTf;
129 } else {
130 std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
131 return RET_INPUT_PARAM_INVALID;
132 }
133
134 if (this->fmk != kFmkTypeCaffe && !weightFile.empty()) {
135 std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag" << std::endl;
136 return RET_INPUT_PARAM_INVALID;
137 }
138 return RET_OK;
139 }
140
InitTrainModel()141 int Flags::InitTrainModel() {
142 if (this->trainModelIn == "true") {
143 this->trainModel = true;
144 } else if (this->trainModelIn == "false") {
145 this->trainModel = false;
146 } else {
147 std::cerr << "INPUT ILLEGAL: trainModel must be true|false " << std::endl;
148 return RET_INPUT_PARAM_INVALID;
149 }
150
151 if (this->trainModel) {
152 if (this->fmk != kFmkTypeMs) {
153 std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format" << std::endl;
154 return RET_INPUT_PARAM_INVALID;
155 }
156 if ((this->inputDataType != TypeId::kNumberTypeFloat32) && (this->inputDataType != TypeId::kTypeUnknown)) {
157 std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 input tensors" << std::endl;
158 return RET_INPUT_PARAM_INVALID;
159 }
160 if ((this->outputDataType != TypeId::kNumberTypeFloat32) && (this->outputDataType != TypeId::kTypeUnknown)) {
161 std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors" << std::endl;
162 return RET_INPUT_PARAM_INVALID;
163 }
164 }
165 return RET_OK;
166 }
167
InitInTensorShape()168 int Flags::InitInTensorShape() {
169 if (this->inTensorShape.empty()) {
170 return RET_OK;
171 }
172 std::string content = this->inTensorShape;
173 std::vector<int64_t> shape;
174 auto shape_strs = lite::StrSplit(content, std::string(";"));
175 for (const auto &shape_str : shape_strs) {
176 if (shape_str.empty()) {
177 continue;
178 }
179 shape.clear();
180 auto string_split = lite::StrSplit(shape_str, std::string(":"));
181 CHECK_LESS_RETURN(string_split.size(), kMinShapeSizeInStr);
182 auto name = string_split[0];
183 if (name.empty()) {
184 MS_LOG(ERROR) << "input tensor name is empty";
185 }
186 auto dim_strs = string_split[1];
187 if (dim_strs.empty()) {
188 MS_LOG(ERROR) << "input tensor dim string is empty";
189 }
190 auto dims = lite::StrSplit(dim_strs, std::string(","));
191 if (dims.empty()) {
192 MS_LOG(ERROR) << "input tensor dim is empty";
193 }
194 for (const auto &dim : dims) {
195 auto dim_value = -1;
196 try {
197 dim_value = std::stoi(dim);
198 } catch (const std::exception &e) {
199 MS_LOG(ERROR) << "Get dim failed: " << e.what();
200 return lite::RET_ERROR;
201 }
202 if (dim_value < 0) {
203 MS_LOG(ERROR) << "Unsupported dim < 0.";
204 return lite::RET_ERROR;
205 } else {
206 shape.push_back(dim_value);
207 }
208 }
209 lite::ConverterContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
210 }
211 return RET_OK;
212 }
213
InitGraphInputFormat()214 int Flags::InitGraphInputFormat() {
215 if (this->graphInputFormatStr == "NHWC") {
216 graphInputFormat = mindspore::NHWC;
217 } else if (this->graphInputFormatStr == "NCHW") {
218 graphInputFormat = mindspore::NCHW;
219 } else if (!this->graphInputFormatStr.empty()) {
220 MS_LOG(ERROR) << "graph input format is invalid.";
221 return RET_INPUT_PARAM_INVALID;
222 }
223 return RET_OK;
224 }
225
InitExtendedIntegrationInfo(const lite::ConfigFileParser & config_file_parser)226 int Flags::InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser) {
227 auto extended_info = config_file_parser.GetRegistryInfoString();
228 if (!extended_info.plugin_path.empty()) {
229 const char *delimiter = ";";
230 auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, *delimiter);
231 if (relative_path.size() > kPluginPathMaxNum) {
232 MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than 10.";
233 return RET_INPUT_PARAM_INVALID;
234 }
235 for (size_t i = 0; i < relative_path.size(); i++) {
236 this->pluginsPath.push_back(lite::RealPath(relative_path[i].c_str()));
237 }
238 }
239
240 if (!extended_info.disable_fusion.empty()) {
241 if (extended_info.disable_fusion == "on") {
242 this->disableFusion = true;
243 } else if (extended_info.disable_fusion == "off") {
244 this->disableFusion = false;
245 } else {
246 std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl;
247 return RET_INPUT_PARAM_INVALID;
248 }
249 }
250 return RET_OK;
251 }
252
InitConfigFile()253 int Flags::InitConfigFile() {
254 lite::ConfigFileParser config_file_parser;
255 auto ret = config_file_parser.ParseConfigFile(this->configFile);
256 if (ret != RET_OK) {
257 MS_LOG(ERROR) << "Parse config file failed.";
258 return ret;
259 }
260 lite::PreprocessParser preprocess_parser;
261 ret = preprocess_parser.ParsePreprocess(config_file_parser.GetDataPreProcessString(), &this->dataPreProcessParam);
262 if (ret != RET_OK) {
263 MS_LOG(ERROR) << "Parse preprocess failed.";
264 return ret;
265 }
266 lite::QuantParamParser quant_param_parser;
267 ret = quant_param_parser.ParseCommonQuant(config_file_parser.GetCommonQuantString(), &this->commonQuantParam);
268 if (ret != RET_OK) {
269 MS_LOG(ERROR) << "Parse common quant param failed.";
270 return ret;
271 }
272 ret = quant_param_parser.ParseFullQuant(config_file_parser.GetFullQuantString(), &this->fullQuantParam);
273 if (ret != RET_OK) {
274 MS_LOG(ERROR) << "Parse full quant param failed.";
275 return ret;
276 }
277 ret = quant_param_parser.ParseMixedBitWeightQuant(config_file_parser.GetMixedBitWeightQuantString(),
278 &this->mixedBitWeightQuantParam);
279 if (ret != RET_OK) {
280 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
281 return ret;
282 }
283 ret = InitExtendedIntegrationInfo(config_file_parser);
284 if (ret != RET_OK) {
285 MS_LOG(ERROR) << "Parse extended integration info failed.";
286 return ret;
287 }
288 (void)CheckOfflineParallelConfig(this->configFile, ¶llel_split_config_);
289 return RET_OK;
290 }
291
// Entry point for flag handling: parses argv, validates the mandatory flags,
// then runs every Init* sub-step in a fixed order (config file first, since
// later steps may depend on values it sets).
// Returns RET_OK on success, RET_SUCCESS_EXIT when only usage was printed,
// RET_INPUT_PARAM_INVALID on any validation failure.
int Flags::Init(int argc, const char **argv) {
  int ret;
  // No arguments at all: print usage and exit successfully.
  if (argc == 1) {
    std::cout << this->Usage() << std::endl;
    return lite::RET_SUCCESS_EXIT;
  }
  lite::Option<std::string> err = this->ParseFlags(argc, argv);

  if (err.IsSome()) {
    std::cerr << err.Get() << std::endl;
    std::cerr << this->Usage() << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->help) {
    std::cout << this->Usage() << std::endl;
    return lite::RET_SUCCESS_EXIT;
  }
  // modelFile, outputFile and fmk are mandatory.
  if (this->modelFile.empty()) {
    std::cerr << "INPUT MISSING: model file path is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  if (this->outputFile.empty()) {
    std::cerr << "INPUT MISSING: output file path is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

#ifdef _WIN32
  // Normalize path separators on Windows before the trailing-separator check.
  replace(this->outputFile.begin(), this->outputFile.end(), '/', '\\');
#endif

  // outputFile must name a file, not a directory (no trailing separator).
  if (this->outputFile.rfind('/') == this->outputFile.length() - 1 ||
      this->outputFile.rfind('\\') == this->outputFile.length() - 1) {
    std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->fmkIn.empty()) {
    std::cerr << "INPUT MISSING: fmk is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  // Optional config file is parsed before the other Init* steps.
  if (!this->configFile.empty()) {
    ret = InitConfigFile();
    if (ret != RET_OK) {
      std::cerr << "Init config file failed." << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }

  if (saveFP16Str == "on") {
    saveFP16 = true;
  } else if (saveFP16Str == "off") {
    saveFP16 = false;
  } else {
    std::cerr << "Init save_fp16 failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  // Data types must be resolved before InitTrainModel, which validates them.
  ret = InitInputOutputDataType();
  if (ret != RET_OK) {
    std::cerr << "Init input output datatype failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitFmk();
  if (ret != RET_OK) {
    std::cerr << "Init fmk failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitTrainModel();
  if (ret != RET_OK) {
    std::cerr << "Init train model failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitInTensorShape();
  if (ret != RET_OK) {
    std::cerr << "Init input tensor shape failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitGraphInputFormat();
  if (ret != RET_OK) {
    std::cerr << "Init graph input format failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}
382
CheckOfflineParallelConfig(const std::string & file,ParallelSplitConfig * parallel_split_config)383 bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
384 // device: [device0 device1] ---> {cpu, gpu}
385 // computeRate: [x: y] x >=0 && y >=0 && x/y < 10
386 MS_ASSERT(parallel_split_config != nullptr);
387 std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
388 auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
389 if (compute_rate_result.empty()) {
390 return false;
391 }
392 std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
393 if (device0_result.empty()) {
394 return false;
395 }
396 std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
397 if (device1_result.empty()) {
398 return false;
399 }
400 bool device0_flag = false;
401 bool device1_flag = false;
402 for (const auto &device : config_devices) {
403 if (device == device0_result) {
404 device0_flag = true;
405 }
406 if (device == device1_result) {
407 device1_flag = true;
408 }
409 }
410 if (!device0_flag || !device1_flag) {
411 return false;
412 }
413 const char *delimiter = ";";
414 std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, *delimiter);
415 const char *colon = ":";
416 for (const auto &device : device_rates) {
417 std::vector<std::string> rate = lite::SplitStringToVector(device, *colon);
418 int64_t compute_rate = 0;
419 try {
420 compute_rate = std::stoi(rate.back());
421 } catch (const std::exception &e) {
422 MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
423 return false;
424 }
425 parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
426 }
427 if (parallel_split_config->parallel_compute_rates_.size() != 2) {
428 return false;
429 }
430 int64_t bigger_rate = INT32_MIN;
431 int64_t smaller_rate = INT32_MAX;
432 for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
433 if (rate <= 0 || rate > INT32_MAX) {
434 return false;
435 }
436 bigger_rate = std::max(rate, bigger_rate);
437 smaller_rate = std::min(rate, smaller_rate);
438 }
439 parallel_split_config->parallel_devices_.push_back(device0_result);
440 parallel_split_config->parallel_devices_.push_back(device1_result);
441 // parall_split_type will extend by other user's attr
442 parallel_split_config->parallel_split_type_ = SplitByUserRatio;
443 // unsuitable rate
444 return bigger_rate / smaller_rate <= kMaxSplitRatio;
445 }
446
GetStrFromConfigFile(const std::string & file,const std::string & target_key)447 std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
448 std::string res;
449 if (file.empty()) {
450 MS_LOG(ERROR) << "file is nullptr";
451 return res;
452 }
453 auto resolved_path = std::make_unique<char[]>(PATH_MAX);
454 if (resolved_path == nullptr) {
455 MS_LOG(ERROR) << "new resolved_path failed";
456 return "";
457 }
458
459 #ifdef _WIN32
460 char *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
461 #else
462 char *real_path = realpath(file.c_str(), resolved_path.get());
463 #endif
464 if (real_path == nullptr || strlen(real_path) == 0) {
465 MS_LOG(ERROR) << "file path is not valid : " << file;
466 return "";
467 }
468 std::ifstream ifs(resolved_path.get());
469 if (!ifs.good()) {
470 MS_LOG(ERROR) << "file: " << real_path << " is not exist";
471 return res;
472 }
473 if (!ifs.is_open()) {
474 MS_LOG(ERROR) << "file: " << real_path << "open failed";
475 return res;
476 }
477 std::string line;
478 while (std::getline(ifs, line)) {
479 lite::Trim(&line);
480 if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
481 continue;
482 }
483 auto index = line.find('=');
484 if (index == std::string::npos) {
485 MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
486 return "";
487 }
488 auto key = line.substr(0, index);
489 auto value = line.substr(index + 1);
490 lite::Trim(&key);
491 lite::Trim(&value);
492 if (key == target_key) {
493 return value;
494 }
495 }
496 return res;
497 }
498 } // namespace converter
499 } // namespace mindspore
500