1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_LITE_INCLUDE_CONVERTER_H_
17 #define MINDSPORE_LITE_INCLUDE_CONVERTER_H_
18
19 #include <map>
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include "include/api/format.h"
24 #include "include/api/status.h"
25 #include "include/registry/converter_context.h"
26 #include "include/api/dual_abi_helper.h"
27
28 namespace mindspore {
29 struct ConverterPara;
30 /// \brief Converter provides C++ API for user to integrate model conversion into user application.
31 ///
32 /// \note Converter C++ API cannot be used in Converter main process.
33 ///
34 /// \note Converter C++ API doesn't support calling with multi-threads in a single process.
35 class MS_API Converter {
36 public:
37 Converter();
38 inline Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
39 const std::string &weight_file = "");
40 ~Converter() = default;
41
42 inline void SetConfigFile(const std::string &config_file);
43 inline std::string GetConfigFile() const;
44
45 inline void SetConfigInfo(const std::string §ion, const std::map<std::string, std::string> &config);
46 inline std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const;
47
48 void SetWeightFp16(bool weight_fp16);
49 bool GetWeightFp16() const;
50
51 inline void SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape);
52 inline std::map<std::string, std::vector<int64_t>> GetInputShape() const;
53
54 void SetInputFormat(Format format);
55 Format GetInputFormat() const;
56
57 void SetOutputFormat(Format format);
58
59 void SetInputDataType(DataType data_type);
60 DataType GetInputDataType();
61
62 void SetOutputDataType(DataType data_type);
63 DataType GetOutputDataType();
64
65 void SetSaveType(ModelType save_type);
66 ModelType GetSaveType() const;
67
68 inline void SetDecryptKey(const std::string &key);
69 inline std::string GetDecryptKey() const;
70
71 inline void SetDecryptMode(const std::string &mode);
72 inline std::string GetDecryptMode() const;
73
74 void SetEnableEncryption(bool encryption);
75 bool GetEnableEncryption() const;
76
77 inline void SetEncryptKey(const std::string &key);
78 inline std::string GetEncryptKey() const;
79
80 void SetInfer(bool infer);
81 bool GetInfer() const;
82
83 void SetTrainModel(bool train_model);
84 bool GetTrainModel() const;
85
86 void SetNoFusion(bool no_fusion);
87 bool GetNoFusion();
88
89 void SetOptimizeTransformer(bool optimize_transformer);
90 bool GetOptimizeTransformer();
91
92 inline void SetDevice(const std::string &device);
93 inline std::string GetDevice();
94 void SetDeviceId(int32_t device_id);
95 int32_t GetDeviceId();
96 void SetRankId(int32_t rank_id);
97 int32_t GetRankId();
98
99 inline void SetProvider(const std::string &provider);
100 inline std::string GetProvider();
101
102 inline void SetChipName(const std::string &device);
103 inline std::string GetChipName();
104
105 /// \brief Convert model and save .ms format model into `output_file` that passed in constructor.
106 Status Convert();
107
108 /// \brief Convert model and return converted FlatBuffer model binary buffer.
109 ///
110 /// \param[in] data_size Converted FlatBuffer model's buffer size.
111 ///
112 /// \return A pointer to converted FlatBuffer model buffer.
113 void *Convert(size_t *data_size);
114
115 /// \brief Convert multiple models and save .ms format models into `output_file` that passed in constructor.
116 inline Status Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
117 const std::string &weight_file = "");
118
119 private:
120 Converter(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
121 const std::vector<char> &weight_file);
122 void SetConfigFile(const std::vector<char> &config_file);
123 std::vector<char> GetConfigFileChar() const;
124 void SetConfigInfo(const std::vector<char> §ion, const std::map<std::vector<char>, std::vector<char>> &config);
125 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> GetConfigInfoChar() const;
126 void SetInputShape(const std::map<std::vector<char>, std::vector<int64_t>> &input_shape);
127 std::map<std::vector<char>, std::vector<int64_t>> GetInputShapeChar() const;
128 void SetDecryptKey(const std::vector<char> &key);
129 std::vector<char> GetDecryptKeyChar() const;
130 void SetDecryptMode(const std::vector<char> &mode);
131 std::vector<char> GetDecryptModeChar() const;
132 void SetEncryptKey(const std::vector<char> &key);
133 std::vector<char> GetEncryptKeyChar() const;
134 void SetDevice(const std::vector<char> &device);
135 std::vector<char> GetDeviceChar();
136 void SetProvider(const std::vector<char> &provider);
137 std::vector<char> GetProviderChar();
138 void SetChipName(const std::vector<char> &chip_name);
139 std::vector<char> GetChipNameChar();
140 Status Convert(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
141 const std::vector<char> &weight_file);
142 std::shared_ptr<ConverterPara> data_;
143 };
144
Converter(converter::FmkType fmk_type,const std::string & model_file,const std::string & output_file,const std::string & weight_file)145 Converter::Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
146 const std::string &weight_file)
147 : Converter(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file)) {}
148
SetConfigFile(const std::string & config_file)149 void Converter::SetConfigFile(const std::string &config_file) { SetConfigFile(StringToChar(config_file)); }
150
GetConfigFile()151 std::string Converter::GetConfigFile() const { return CharToString(GetConfigFileChar()); }
152
SetConfigInfo(const std::string & section,const std::map<std::string,std::string> & config)153 void Converter::SetConfigInfo(const std::string §ion, const std::map<std::string, std::string> &config) {
154 SetConfigInfo(StringToChar(section), MapStringToVectorChar(config));
155 }
156
GetConfigInfo()157 std::map<std::string, std::map<std::string, std::string>> Converter::GetConfigInfo() const {
158 return MapMapCharToString(GetConfigInfoChar());
159 }
160
SetInputShape(const std::map<std::string,std::vector<int64_t>> & input_shape)161 void Converter::SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape) {
162 SetInputShape(MapStringToChar(input_shape));
163 }
164
GetInputShape()165 std::map<std::string, std::vector<int64_t>> Converter::GetInputShape() const {
166 return MapCharToString(GetInputShapeChar());
167 }
168
SetDecryptKey(const std::string & key)169 void Converter::SetDecryptKey(const std::string &key) { SetDecryptKey(StringToChar(key)); }
170
GetDecryptKey()171 std::string Converter::GetDecryptKey() const { return CharToString(GetDecryptKeyChar()); }
172
SetDecryptMode(const std::string & mode)173 void Converter::SetDecryptMode(const std::string &mode) { SetDecryptMode(StringToChar(mode)); }
174
GetDecryptMode()175 std::string Converter::GetDecryptMode() const { return CharToString(GetDecryptModeChar()); }
176
SetEncryptKey(const std::string & key)177 void Converter::SetEncryptKey(const std::string &key) { SetEncryptKey(StringToChar(key)); }
178
GetEncryptKey()179 std::string Converter::GetEncryptKey() const { return CharToString(GetEncryptKeyChar()); }
180
SetDevice(const std::string & device)181 void Converter::SetDevice(const std::string &device) { SetDevice(StringToChar(device)); }
182
GetDevice()183 std::string Converter::GetDevice() { return CharToString(GetDeviceChar()); }
184
SetProvider(const std::string & provider)185 void Converter::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
186
GetProvider()187 std::string Converter::GetProvider() { return CharToString(GetProviderChar()); }
188
SetChipName(const std::string & chip_name)189 void Converter::SetChipName(const std::string &chip_name) { SetChipName(StringToChar(chip_name)); }
190
GetChipName()191 std::string Converter::GetChipName() { return CharToString(GetChipNameChar()); }
192
Convert(converter::FmkType fmk_type,const std::string & model_file,const std::string & output_file,const std::string & weight_file)193 Status Converter::Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
194 const std::string &weight_file) {
195 return Convert(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file));
196 }
197 } // namespace mindspore
198 #endif // MINDSPORE_LITE_INCLUDE_CONVERTER_H_
199