• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/converter.h"
17 #include "include/api/data_type.h"
18 #include "tools/converter/cxx_api/converter_para.h"
19 #include "tools/converter/converter_context.h"
20 #include "tools/converter/converter.h"
21 #include "src/common/log_adapter.h"
22 
23 namespace mindspore {
namespace {
// Upper bound on the number of config sections accepted by Converter::SetConfigInfo.
constexpr size_t kMaxSectionNum{100};
// Upper bound on the number of key/value entries accepted for a single config section.
constexpr size_t kMaxConfigNumPerSection{1000};
}  // namespace
28 namespace lite {
29 int RunConverter(const std::shared_ptr<ConverterPara> &data_);
30 }
Converter()31 Converter::Converter() {
32   data_ = std::make_shared<ConverterPara>();
33   if (data_ == nullptr) {
34     MS_LOG(ERROR) << "Create ConverterPara failed";
35   }
36 }
37 
Converter(converter::FmkType fmk_type,const std::vector<char> & model_file,const std::vector<char> & output_file,const std::vector<char> & weight_file)38 Converter::Converter(converter::FmkType fmk_type, const std::vector<char> &model_file,
39                      const std::vector<char> &output_file, const std::vector<char> &weight_file) {
40   data_ = std::make_shared<ConverterPara>();
41   if (data_ != nullptr) {
42     data_->fmk_type = fmk_type;
43     data_->model_file = CharToString(model_file);
44     data_->output_file = CharToString(output_file);
45     data_->weight_file = CharToString(weight_file);
46   } else {
47     MS_LOG(ERROR) << "Create ConverterPara failed";
48   }
49 }
50 
SetConfigFile(const std::vector<char> & config_file)51 void Converter::SetConfigFile(const std::vector<char> &config_file) {
52   if (data_ != nullptr) {
53     data_->config_file = CharToString(config_file);
54   }
55 }
56 
GetConfigFileChar() const57 std::vector<char> Converter::GetConfigFileChar() const {
58   std::string cfg_file = "";
59   if (data_ != nullptr) {
60     cfg_file = data_->config_file;
61   }
62   return StringToChar(cfg_file);
63 }
64 
SetConfigInfo(const std::vector<char> & section,const std::map<std::vector<char>,std::vector<char>> & config)65 void Converter::SetConfigInfo(const std::vector<char> &section,
66                               const std::map<std::vector<char>, std::vector<char>> &config) {
67   auto section_str = CharToString(section);
68   auto config_str = MapVectorCharToString(config);
69   if (data_ != nullptr) {
70     if (data_->config_param.size() > kMaxSectionNum) {
71       MS_LOG(ERROR) << "Section num " << data_->config_param.size() << "exceeds max num " << kMaxSectionNum;
72       return;
73     }
74     if (data_->config_param.find(section_str) != data_->config_param.end()) {
75       MS_LOG(WARNING) << "Section " << section_str << "already exists, "
76                       << "value will be overwrite.";
77     }
78     if (config.size() > kMaxConfigNumPerSection) {
79       MS_LOG(ERROR) << "Config num " << config.size() << " exceeds max num " << kMaxConfigNumPerSection << " in "
80                     << section_str;
81       return;
82     }
83     data_->config_param[section_str] = config_str;
84   }
85 }
86 
GetConfigInfoChar() const87 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> Converter::GetConfigInfoChar() const {
88   return MapMapStringToChar(data_->config_param);
89 }
90 
SetWeightFp16(bool weight_fp16)91 void Converter::SetWeightFp16(bool weight_fp16) {
92   if (data_ != nullptr) {
93     data_->weight_fp16 = weight_fp16;
94   }
95 }
96 
GetWeightFp16() const97 bool Converter::GetWeightFp16() const {
98   if (data_ != nullptr) {
99     return data_->weight_fp16;
100   } else {
101     return false;
102   }
103 }
104 
SetInputShape(const std::map<std::vector<char>,std::vector<int64_t>> & input_shape)105 void Converter::SetInputShape(const std::map<std::vector<char>, std::vector<int64_t>> &input_shape) {
106   auto input_shape_str = MapCharToString(input_shape);
107   if (data_ != nullptr) {
108     for (auto &it : input_shape_str) {
109       lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(it.first, it.second);
110     }
111     data_->input_shape = input_shape_str;
112   }
113 }
114 
GetInputShapeChar() const115 std::map<std::vector<char>, std::vector<int64_t>> Converter::GetInputShapeChar() const {
116   std::map<std::string, std::vector<int64_t>> input_shape = {};
117   if (data_ != nullptr) {
118     input_shape = data_->input_shape;
119   }
120   return MapStringToChar(input_shape);
121 }
122 
SetInputFormat(Format format)123 void Converter::SetInputFormat(Format format) {
124   if (data_ != nullptr) {
125     if (format != DEFAULT_FORMAT) {
126       data_->input_format = format;
127     }
128     data_->spec_input_format = format;
129   }
130 }
131 
GetInputFormat() const132 Format Converter::GetInputFormat() const {
133   if (data_ != nullptr) {
134     return data_->input_format;
135   } else {
136     return DEFAULT_FORMAT;
137   }
138 }
139 
SetOutputFormat(Format format)140 void Converter::SetOutputFormat(Format format) {
141   if (data_ != nullptr) {
142     data_->spec_output_format = format;
143   }
144 }
145 
SetInputDataType(DataType data_type)146 void Converter::SetInputDataType(DataType data_type) {
147   if (data_ != nullptr) {
148     data_->input_data_type = data_type;
149   }
150 }
151 
GetInputDataType()152 DataType Converter::GetInputDataType() {
153   if (data_ != nullptr) {
154     return data_->input_data_type;
155   } else {
156     return DataType::kTypeUnknown;
157   }
158 }
159 
SetOutputDataType(DataType data_type)160 void Converter::SetOutputDataType(DataType data_type) {
161   if (data_ != nullptr) {
162     data_->output_data_type = data_type;
163   }
164 }
165 
GetOutputDataType()166 DataType Converter::GetOutputDataType() {
167   if (data_ != nullptr) {
168     return data_->output_data_type;
169   } else {
170     return DataType::kTypeUnknown;
171   }
172 }
173 
SetSaveType(ModelType save_type)174 void Converter::SetSaveType(ModelType save_type) {
175   if (data_ != nullptr) {
176     data_->save_type = save_type;
177   }
178 }
179 
GetSaveType() const180 ModelType Converter::GetSaveType() const {
181   if (data_ != nullptr) {
182     return data_->save_type;
183   } else {
184     return kMindIR_Lite;
185   }
186 }
187 
SetDecryptKey(const std::vector<char> & key)188 void Converter::SetDecryptKey(const std::vector<char> &key) {
189   if (data_ != nullptr) {
190     data_->decrypt_key = CharToString(key);
191   }
192 }
193 
GetDecryptKeyChar() const194 std::vector<char> Converter::GetDecryptKeyChar() const {
195   std::string decrypt_key = "";
196   if (data_ != nullptr) {
197     decrypt_key = data_->decrypt_key;
198   }
199   return StringToChar(decrypt_key);
200 }
201 
SetDecryptMode(const std::vector<char> & mode)202 void Converter::SetDecryptMode(const std::vector<char> &mode) {
203   if (data_ != nullptr) {
204     data_->decrypt_mode = CharToString(mode);
205   }
206 }
207 
GetDecryptModeChar() const208 std::vector<char> Converter::GetDecryptModeChar() const {
209   std::string decrypt_mode = "";
210   if (data_ != nullptr) {
211     decrypt_mode = data_->decrypt_mode;
212   }
213   return StringToChar(decrypt_mode);
214 }
215 
SetEnableEncryption(bool encryption)216 void Converter::SetEnableEncryption(bool encryption) {
217   if (data_ != nullptr) {
218     data_->enable_encryption = encryption;
219   }
220 }
221 
GetEnableEncryption() const222 bool Converter::GetEnableEncryption() const {
223   if (data_ != nullptr) {
224     return data_->enable_encryption;
225   } else {
226     return false;
227   }
228 }
229 
SetEncryptKey(const std::vector<char> & key)230 void Converter::SetEncryptKey(const std::vector<char> &key) {
231   if (data_ != nullptr) {
232     data_->encrypt_key = CharToString(key);
233   }
234 }
235 
GetEncryptKeyChar() const236 std::vector<char> Converter::GetEncryptKeyChar() const {
237   std::string encrypt_key = "";
238   if (data_ != nullptr) {
239     encrypt_key = data_->encrypt_key;
240   }
241   return StringToChar(encrypt_key);
242 }
243 
SetInfer(bool infer)244 void Converter::SetInfer(bool infer) {
245   if (data_ != nullptr) {
246     data_->pre_infer = infer;
247   }
248 }
249 
GetInfer() const250 bool Converter::GetInfer() const {
251   if (data_ != nullptr) {
252     return data_->pre_infer;
253   } else {
254     return false;
255   }
256 }
257 
SetTrainModel(bool train_model)258 void Converter::SetTrainModel(bool train_model) {
259   if (data_ != nullptr) {
260     data_->train_model = train_model;
261   }
262 }
263 
GetTrainModel() const264 bool Converter::GetTrainModel() const {
265   if (data_ != nullptr) {
266     return data_->train_model;
267   } else {
268     return false;
269   }
270 }
271 
SetNoFusion(bool no_fusion)272 void Converter::SetNoFusion(bool no_fusion) {
273   if (data_ != nullptr) {
274     data_->no_fusion = no_fusion;
275   }
276 }
277 
GetNoFusion()278 bool Converter::GetNoFusion() {
279   if (data_ != nullptr) {
280     return data_->no_fusion;
281   } else {
282     return false;
283   }
284 }
285 
SetOptimizeTransformer(bool optimizeTransformer)286 void Converter::SetOptimizeTransformer(bool optimizeTransformer) {
287   if (data_ != nullptr) {
288     data_->optimize_transformer = optimizeTransformer;
289   }
290 }
291 
GetOptimizeTransformer()292 bool Converter::GetOptimizeTransformer() {
293   if (data_ != nullptr) {
294     return data_->optimize_transformer;
295   } else {
296     return false;
297   }
298 }
299 
SetDevice(const std::vector<char> & device)300 void Converter::SetDevice(const std::vector<char> &device) {
301   if (data_ != nullptr) {
302     data_->device = CharToString(device);
303   }
304 }
305 
GetDeviceChar()306 std::vector<char> Converter::GetDeviceChar() {
307   std::string device = "";
308   if (data_ != nullptr) {
309     device = data_->device;
310   }
311   return StringToChar(device);
312 }
313 
SetDeviceId(int32_t device_id)314 void Converter::SetDeviceId(int32_t device_id) {
315   if (data_ != nullptr) {
316     data_->aclModelOptionCfgParam.device_id = device_id;
317   }
318 }
319 
GetDeviceId()320 int32_t Converter::GetDeviceId() {
321   if (data_ != nullptr) {
322     return data_->aclModelOptionCfgParam.device_id;
323   }
324   return 0;
325 }
326 
SetRankId(int32_t rank_id)327 void Converter::SetRankId(int32_t rank_id) {
328   if (data_ != nullptr) {
329     data_->aclModelOptionCfgParam.rank_id = rank_id;
330   }
331 }
332 
GetRankId()333 int32_t Converter::GetRankId() {
334   if (data_ != nullptr) {
335     return data_->aclModelOptionCfgParam.rank_id;
336   }
337   return 0;
338 }
339 
SetProvider(const std::vector<char> & provider)340 void Converter::SetProvider(const std::vector<char> &provider) {
341   if (data_ != nullptr) {
342     data_->provider = CharToString(provider);
343   }
344 }
345 
GetProviderChar()346 std::vector<char> Converter::GetProviderChar() {
347   std::string provider = "";
348   if (data_ != nullptr) {
349     provider = data_->provider;
350   }
351   return StringToChar(provider);
352 }
353 
SetChipName(const std::vector<char> & chip_name)354 void Converter::SetChipName(const std::vector<char> &chip_name) {
355   if (data_ != nullptr) {
356     data_->chip_name = CharToString(chip_name);
357   }
358 }
359 
GetChipNameChar()360 std::vector<char> Converter::GetChipNameChar() {
361   std::string chip_name = "";
362   if (data_ != nullptr) {
363     chip_name = data_->chip_name;
364   }
365   return StringToChar(chip_name);
366 }
367 
Convert()368 Status Converter::Convert() {
369   if (data_ != nullptr) {
370     Status ret = Status(static_cast<StatusCode>(lite::RunConverter(data_, nullptr, nullptr, false)));
371     data_->decrypt_key.clear();  // clear key
372     data_->encrypt_key.clear();  // clear key
373     if (ret != kSuccess) {
374       MS_LOG(ERROR) << "Convert model failed, ret=" << ret;
375     }
376     return ret;
377   } else {
378     return kLiteError;
379   }
380 }
381 
Convert(size_t * data_size)382 void *Converter::Convert(size_t *data_size) {
383   void *model_data = nullptr;
384   if (data_ != nullptr) {
385     Status ret = Status(static_cast<StatusCode>(lite::RunConverter(data_, &model_data, data_size, true)));
386     data_->decrypt_key.clear();  // clear key
387     data_->encrypt_key.clear();  // clear key
388     if (ret != kSuccess) {
389       MS_LOG(ERROR) << "Convert model failed, ret=" << ret;
390     }
391   } else {
392     MS_LOG(ERROR) << "Convert model failed, data is null.";
393   }
394   return model_data;
395 }
396 
Convert(converter::FmkType fmk_type,const std::vector<char> & model_file,const std::vector<char> & output_file,const std::vector<char> & weight_file)397 Status Converter::Convert(converter::FmkType fmk_type, const std::vector<char> &model_file,
398                           const std::vector<char> &output_file, const std::vector<char> &weight_file) {
399   if (data_ != nullptr) {
400     data_->fmk_type = fmk_type;
401     data_->model_file = CharToString(model_file);
402     data_->output_file = CharToString(output_file);
403     data_->weight_file = CharToString(weight_file);
404     Status ret = Converter::Convert();
405     if (ret != kSuccess) {
406       MS_LOG(ERROR) << "Convert model " << CharToString(model_file) << " failed, ret=" << ret;
407     }
408     lite::ConverterInnerContext::GetInstance()->Free();
409     return ret;
410   } else {
411     MS_LOG(ERROR) << "Convert model " << CharToString(model_file) << " failed, data is null.";
412     return kLiteError;
413   }
414 }
415 }  // namespace mindspore
416