/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/converter_flags.h"
#include <climits>
#include <cstdlib>
#include <string>
#include <fstream>
#include <vector>
#include <memory>
#include <algorithm>
#include "ir/dtype/type_id.h"
#include "common/file_utils.h"
#include "tools/common/string_util.h"
#include "common/log_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/config_parser/config_file_parser.h"
#include "tools/converter/config_parser/preprocess_parser.h"
#include "tools/converter/config_parser/quant_param_parser.h"

namespace mindspore {
namespace converter {
using mindspore::lite::RET_INPUT_PARAM_INVALID;
using mindspore::lite::RET_OK;
namespace {
constexpr size_t kPluginPathMaxNum = 10;
constexpr int kQuantBitNumInt16 = 16;
constexpr int kPathLengthUpperLimit = 1024;
constexpr int kMinShapeSizeInStr = 2;
}  // namespace
Flags::Flags() {
  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
  AddFlag(&Flags::modelFile, "modelFile",
          "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
  AddFlag(&Flags::outputFile, "outputFile", "Output model file path. The .ms suffix is added automatically", "");
  AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
          "");
  AddFlag(&Flags::inputDataTypeStr, "inputDataType",
          "Data type of input tensors; defaults to the type defined in the model. FLOAT | INT8 | UINT8 | DEFAULT",
          "DEFAULT");
  AddFlag(&Flags::outputDataTypeStr, "outputDataType",
          "Data type of output tensors; defaults to the type defined in the model. FLOAT | INT8 | "
          "UINT8 | DEFAULT",
          "DEFAULT");
  AddFlag(&Flags::configFile, "configFile",
          "Configuration file for post-training quantization, splitting ops offline for parallelism, "
          "disabling op fusion, and setting plugin .so paths",
          "");
  AddFlag(&Flags::saveFP16Str, "fp16",
          "Serialize const tensors in Float16 data type; only effective for const tensors in Float32 data type. "
          "on | off",
          "off");
  AddFlag(&Flags::trainModelIn, "trainModel",
          "Whether the model is going to be trained on device. "
          "true | false",
          "false");
  AddFlag(&Flags::dec_key, "decryptKey",
          "The key used to decrypt the file, expressed in hexadecimal characters. Only valid when fmkIn is 'MINDIR'",
          "");
  AddFlag(&Flags::dec_mode, "decryptMode",
          "Decryption method for the MindIR file. Only valid when dec_key is set. "
          "AES-GCM | AES-CBC",
          "AES-GCM");
  AddFlag(&Flags::inTensorShape, "inputShape",
          "Set the dimensions of the model inputs; the order of the dimensions is consistent with the original model. "
          "For some models, the model structure can be further optimized, but the converted model may lose the "
          "characteristics of dynamic shape. "
          "e.g. \"inTensor1:1,32,32,32;inTensor2:1,1,32,32,4\"",
          "");
  AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
          "Assign the input format of the exported model. Only valid for 4-dimensional input. NHWC | NCHW", "NHWC");
}

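// Illustrative usage (a sketch, not part of this file's build): these flags are
// consumed by the converter command-line tool. Assuming the standard
// converter_lite binary name, a typical invocation could look like:
//
//   ./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model \
//                    --inputDataType=FLOAT --inputShape="in0:1,224,224,3"
//
// The flag names match the AddFlag registrations in Flags::Flags(); the binary
// name and the concrete values are assumptions for illustration only.
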
int Flags::InitInputOutputDataType() {
  if (this->inputDataTypeStr == "FLOAT") {
    this->inputDataType = TypeId::kNumberTypeFloat32;
  } else if (this->inputDataTypeStr == "INT8") {
    this->inputDataType = TypeId::kNumberTypeInt8;
  } else if (this->inputDataTypeStr == "UINT8") {
    this->inputDataType = TypeId::kNumberTypeUInt8;
  } else if (this->inputDataTypeStr == "DEFAULT") {
    this->inputDataType = TypeId::kTypeUnknown;
  } else {
    std::cerr << "INPUT INVALID: inputDataType is invalid: " << this->inputDataTypeStr
              << ", supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->outputDataTypeStr == "FLOAT") {
    this->outputDataType = TypeId::kNumberTypeFloat32;
  } else if (this->outputDataTypeStr == "INT8") {
    this->outputDataType = TypeId::kNumberTypeInt8;
  } else if (this->outputDataTypeStr == "UINT8") {
    this->outputDataType = TypeId::kNumberTypeUInt8;
  } else if (this->outputDataTypeStr == "DEFAULT") {
    this->outputDataType = TypeId::kTypeUnknown;
  } else {
    std::cerr << "INPUT INVALID: outputDataType is invalid: " << this->outputDataTypeStr
              << ", supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::InitFmk() {
  if (this->fmkIn == "CAFFE") {
    this->fmk = kFmkTypeCaffe;
  } else if (this->fmkIn == "MINDIR") {
    this->fmk = kFmkTypeMs;
  } else if (this->fmkIn == "TFLITE") {
    this->fmk = kFmkTypeTflite;
  } else if (this->fmkIn == "ONNX") {
    this->fmk = kFmkTypeOnnx;
  } else if (this->fmkIn == "TF") {
    this->fmk = kFmkTypeTf;
  } else {
    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->fmk != kFmkTypeCaffe && !weightFile.empty()) {
    std::cerr << "INPUT ILLEGAL: weightFile is only valid when fmk is CAFFE" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::InitTrainModel() {
  if (this->trainModelIn == "true") {
    this->trainModel = true;
  } else if (this->trainModelIn == "false") {
    this->trainModel = false;
  } else {
    std::cerr << "INPUT ILLEGAL: trainModel must be true|false" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->trainModel) {
    if (this->fmk != kFmkTypeMs) {
      std::cerr << "INPUT ILLEGAL: train model conversion supports only the MINDIR format" << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
    if ((this->inputDataType != TypeId::kNumberTypeFloat32) && (this->inputDataType != TypeId::kTypeUnknown)) {
      std::cerr << "INPUT ILLEGAL: train model conversion supports only FP32 input tensors" << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
    if ((this->outputDataType != TypeId::kNumberTypeFloat32) && (this->outputDataType != TypeId::kTypeUnknown)) {
      std::cerr << "INPUT ILLEGAL: train model conversion supports only FP32 output tensors" << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

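// InitInTensorShape parses the --inputShape flag. A sketch of the expected
// format, taken from the flag's help text: entries are separated by ';', each
// entry is "<tensorName>:<dim0>,<dim1>,...", and every dim must be a
// non-negative integer. For example (illustrative values):
//
//   "inTensor1:1,32,32,32;inTensor2:1,1,32,32,4"
//
// registers shape {1,32,32,32} for inTensor1 and {1,1,32,32,4} for inTensor2
// via ConverterContext::UpdateGraphInputTensorShape.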
int Flags::InitInTensorShape() {
  if (this->inTensorShape.empty()) {
    return RET_OK;
  }
  std::string content = this->inTensorShape;
  std::vector<int64_t> shape;
  auto shape_strs = lite::StrSplit(content, std::string(";"));
  for (const auto &shape_str : shape_strs) {
    if (shape_str.empty()) {
      continue;
    }
    shape.clear();
    auto string_split = lite::StrSplit(shape_str, std::string(":"));
    CHECK_LESS_RETURN(string_split.size(), kMinShapeSizeInStr);
    auto name = string_split[0];
    if (name.empty()) {
      MS_LOG(ERROR) << "input tensor name is empty";
      return lite::RET_ERROR;
    }
    auto dim_strs = string_split[1];
    if (dim_strs.empty()) {
      MS_LOG(ERROR) << "input tensor dim string is empty";
      return lite::RET_ERROR;
    }
    auto dims = lite::StrSplit(dim_strs, std::string(","));
    if (dims.empty()) {
      MS_LOG(ERROR) << "input tensor dim is empty";
      return lite::RET_ERROR;
    }
    for (const auto &dim : dims) {
      auto dim_value = -1;
      try {
        dim_value = std::stoi(dim);
      } catch (const std::exception &e) {
        MS_LOG(ERROR) << "Get dim failed: " << e.what();
        return lite::RET_ERROR;
      }
      if (dim_value < 0) {
        MS_LOG(ERROR) << "Unsupported dim < 0.";
        return lite::RET_ERROR;
      }
      shape.push_back(dim_value);
    }
    lite::ConverterContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
  }
  return RET_OK;
}

int Flags::InitGraphInputFormat() {
  if (this->graphInputFormatStr == "NHWC") {
    graphInputFormat = mindspore::NHWC;
  } else if (this->graphInputFormatStr == "NCHW") {
    graphInputFormat = mindspore::NCHW;
  } else if (!this->graphInputFormatStr.empty()) {
    MS_LOG(ERROR) << "graph input format is invalid: " << this->graphInputFormatStr << ", must be NHWC or NCHW.";
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

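// InitExtendedIntegrationInfo reads the registry section of the config file.
// A sketch of the accepted values, inferred from the parsing below (the
// concrete key strings live in ConfigFileParser and are not shown here):
//
//   plugin_path     - up to kPluginPathMaxNum (10) library paths, separated
//                     by ';', e.g. "libplugin_a.so;libplugin_b.so"
//                     (illustrative file names)
//   disable_fusion  - "on" or "off"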
int Flags::InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser) {
  auto extended_info = config_file_parser.GetRegistryInfoString();
  if (!extended_info.plugin_path.empty()) {
    const char *delimiter = ";";
    auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, *delimiter);
    if (relative_path.size() > kPluginPathMaxNum) {
      MS_LOG(ERROR) << "The number of extended plugin libraries should not be larger than " << kPluginPathMaxNum
                    << ".";
      return RET_INPUT_PARAM_INVALID;
    }
    for (size_t i = 0; i < relative_path.size(); i++) {
      this->pluginsPath.push_back(lite::RealPath(relative_path[i].c_str()));
    }
  }

  if (!extended_info.disable_fusion.empty()) {
    if (extended_info.disable_fusion == "on") {
      this->disableFusion = true;
    } else if (extended_info.disable_fusion == "off") {
      this->disableFusion = false;
    } else {
      std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on|off" << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

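// InitConfigFile drives all config-file-based setup. It parses, in order:
// data preprocessing, common quantization, full quantization, and mixed-bit
// weight quantization parameters, then the extended integration (registry)
// info, and finally the offline parallel-split settings. Any parser failure
// aborts initialization with the parser's error code.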
int Flags::InitConfigFile() {
  lite::ConfigFileParser config_file_parser;
  auto ret = config_file_parser.ParseConfigFile(this->configFile);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse config file failed.";
    return ret;
  }
  lite::PreprocessParser preprocess_parser;
  ret = preprocess_parser.ParsePreprocess(config_file_parser.GetDataPreProcessString(), &this->dataPreProcessParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse preprocess failed.";
    return ret;
  }
  lite::QuantParamParser quant_param_parser;
  ret = quant_param_parser.ParseCommonQuant(config_file_parser.GetCommonQuantString(), &this->commonQuantParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse common quant param failed.";
    return ret;
  }
  ret = quant_param_parser.ParseFullQuant(config_file_parser.GetFullQuantString(), &this->fullQuantParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse full quant param failed.";
    return ret;
  }
  ret = quant_param_parser.ParseMixedBitWeightQuant(config_file_parser.GetMixedBitWeightQuantString(),
                                                    &this->mixedBitWeightQuantParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
    return ret;
  }
  ret = InitExtendedIntegrationInfo(config_file_parser);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse extended integration info failed.";
    return ret;
  }
  (void)CheckOfflineParallelConfig(this->configFile, &parallel_split_config_);
  return RET_OK;
}

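// Init is the entry point for flag validation. It parses the command line,
// checks the required flags (modelFile, outputFile, fmk), loads the config
// file if one was given, and then runs the Init* helpers above in sequence;
// the first failure is reported as RET_INPUT_PARAM_INVALID.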
int Flags::Init(int argc, const char **argv) {
  int ret;
  if (argc == 1) {
    std::cout << this->Usage() << std::endl;
    return lite::RET_SUCCESS_EXIT;
  }
  lite::Option<std::string> err = this->ParseFlags(argc, argv);

  if (err.IsSome()) {
    std::cerr << err.Get() << std::endl;
    std::cerr << this->Usage() << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->help) {
    std::cout << this->Usage() << std::endl;
    return lite::RET_SUCCESS_EXIT;
  }
  if (this->modelFile.empty()) {
    std::cerr << "INPUT MISSING: model file path is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  if (this->outputFile.empty()) {
    std::cerr << "INPUT MISSING: output file path is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

#ifdef _WIN32
  std::replace(this->outputFile.begin(), this->outputFile.end(), '/', '\\');
#endif

  if (this->outputFile.rfind('/') == this->outputFile.length() - 1 ||
      this->outputFile.rfind('\\') == this->outputFile.length() - 1) {
    std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->fmkIn.empty()) {
    std::cerr << "INPUT MISSING: fmk is necessary" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  if (!this->configFile.empty()) {
    ret = InitConfigFile();
    if (ret != RET_OK) {
      std::cerr << "Init config file failed." << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }

  if (saveFP16Str == "on") {
    saveFP16 = true;
  } else if (saveFP16Str == "off") {
    saveFP16 = false;
  } else {
    std::cerr << "Init save_fp16 failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitInputOutputDataType();
  if (ret != RET_OK) {
    std::cerr << "Init input output datatype failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitFmk();
  if (ret != RET_OK) {
    std::cerr << "Init fmk failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitTrainModel();
  if (ret != RET_OK) {
    std::cerr << "Init train model failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitInTensorShape();
  if (ret != RET_OK) {
    std::cerr << "Init input tensor shape failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitGraphInputFormat();
  if (ret != RET_OK) {
    std::cerr << "Init graph input format failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

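// CheckOfflineParallelConfig validates the offline parallel-split section of
// the config file. A sketch of a valid configuration, assuming the key
// strings behind kSplitDevice0, kSplitDevice1, and kComputeRate match their
// constant names (illustrative; the real key literals are defined elsewhere):
//
//   device0=cpu
//   device1=gpu
//   computeRate=cpu:3;gpu:7
//
// Devices must come from {cpu, gpu, npu}, exactly two rates must be given,
// each rate must be positive, and the bigger/smaller ratio must not exceed
// kMaxSplitRatio.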
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
  // device: [device0 device1] ---> {cpu, gpu}
  // computeRate: [x:y] where x >= 0, y >= 0, and x/y < kMaxSplitRatio
  MS_ASSERT(parallel_split_config != nullptr);
  std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
  auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
  if (compute_rate_result.empty()) {
    return false;
  }
  std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
  if (device0_result.empty()) {
    return false;
  }
  std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
  if (device1_result.empty()) {
    return false;
  }
  bool device0_flag = false;
  bool device1_flag = false;
  for (const auto &device : config_devices) {
    if (device == device0_result) {
      device0_flag = true;
    }
    if (device == device1_result) {
      device1_flag = true;
    }
  }
  if (!device0_flag || !device1_flag) {
    return false;
  }
  const char *delimiter = ";";
  std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, *delimiter);
  const char *colon = ":";
  for (const auto &device : device_rates) {
    std::vector<std::string> rate = lite::SplitStringToVector(device, *colon);
    int64_t compute_rate = 0;
    try {
      compute_rate = std::stoi(rate.back());
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
      return false;
    }
    parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
  }
  if (parallel_split_config->parallel_compute_rates_.size() != 2) {
    return false;
  }
  int64_t bigger_rate = INT32_MIN;
  int64_t smaller_rate = INT32_MAX;
  for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
    if (rate <= 0 || rate > INT32_MAX) {
      return false;
    }
    bigger_rate = std::max(rate, bigger_rate);
    smaller_rate = std::min(rate, smaller_rate);
  }
  parallel_split_config->parallel_devices_.push_back(device0_result);
  parallel_split_config->parallel_devices_.push_back(device1_result);
  // parallel_split_type_ may be extended with other user attributes
  parallel_split_config->parallel_split_type_ = SplitByUserRatio;
  // reject an unsuitable rate ratio
  return bigger_rate / smaller_rate <= kMaxSplitRatio;
}

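// GetStrFromConfigFile performs a minimal "key=value" lookup: it resolves the
// file path, skips blank lines, '#' comments, and '[section]' headers, trims
// whitespace around keys and values, and returns the value of the first
// matching key (or an empty string on any error). For example, the line
// "device0 = cpu" (illustrative) yields "cpu" for target_key "device0".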
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
  std::string res;
  if (file.empty()) {
    MS_LOG(ERROR) << "file path is empty";
    return res;
  }
  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
  if (resolved_path == nullptr) {
    MS_LOG(ERROR) << "new resolved_path failed";
    return "";
  }

#ifdef _WIN32
  char *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
#else
  char *real_path = realpath(file.c_str(), resolved_path.get());
#endif
  if (real_path == nullptr || strlen(real_path) == 0) {
    MS_LOG(ERROR) << "file path is not valid: " << file;
    return "";
  }
  std::ifstream ifs(resolved_path.get());
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
    return res;
  }
  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return res;
  }
  std::string line;
  while (std::getline(ifs, line)) {
    lite::Trim(&line);
    if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
      continue;
    }
    auto index = line.find('=');
    if (index == std::string::npos) {
      MS_LOG(ERROR) << "the config file is invalid: cannot find '=', please check";
      return "";
    }
    auto key = line.substr(0, index);
    auto value = line.substr(index + 1);
    lite::Trim(&key);
    lite::Trim(&value);
    if (key == target_key) {
      return value;
    }
  }
  return res;
}
}  // namespace converter
}  // namespace mindspore