• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_INCLUDE_CONVERTER_H_
17 #define MINDSPORE_LITE_INCLUDE_CONVERTER_H_
18 
19 #include <map>
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include "include/api/format.h"
24 #include "include/api/status.h"
25 #include "include/registry/converter_context.h"
26 #include "include/api/dual_abi_helper.h"
27 
28 namespace mindspore {
29 struct ConverterPara;
30 /// \brief Converter provides C++ API for user to integrate model conversion into user application.
31 ///
32 /// \note Converter C++ API cannot be used in Converter main process.
33 ///
34 /// \note Converter C++ API doesn't support calling with multi-threads in a single process.
35 class MS_API Converter {
36  public:
37   Converter();
38   inline Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
39                    const std::string &weight_file = "");
40   ~Converter() = default;
41 
42   inline void SetConfigFile(const std::string &config_file);
43   inline std::string GetConfigFile() const;
44 
45   inline void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config);
46   inline std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const;
47 
48   void SetWeightFp16(bool weight_fp16);
49   bool GetWeightFp16() const;
50 
51   inline void SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape);
52   inline std::map<std::string, std::vector<int64_t>> GetInputShape() const;
53 
54   void SetInputFormat(Format format);
55   Format GetInputFormat() const;
56 
57   void SetOutputFormat(Format format);
58 
59   void SetInputDataType(DataType data_type);
60   DataType GetInputDataType();
61 
62   void SetOutputDataType(DataType data_type);
63   DataType GetOutputDataType();
64 
65   void SetSaveType(ModelType save_type);
66   ModelType GetSaveType() const;
67 
68   inline void SetDecryptKey(const std::string &key);
69   inline std::string GetDecryptKey() const;
70 
71   inline void SetDecryptMode(const std::string &mode);
72   inline std::string GetDecryptMode() const;
73 
74   void SetEnableEncryption(bool encryption);
75   bool GetEnableEncryption() const;
76 
77   inline void SetEncryptKey(const std::string &key);
78   inline std::string GetEncryptKey() const;
79 
80   void SetInfer(bool infer);
81   bool GetInfer() const;
82 
83   void SetTrainModel(bool train_model);
84   bool GetTrainModel() const;
85 
86   void SetNoFusion(bool no_fusion);
87   bool GetNoFusion();
88 
89   void SetOptimizeTransformer(bool optimize_transformer);
90   bool GetOptimizeTransformer();
91 
92   inline void SetDevice(const std::string &device);
93   inline std::string GetDevice();
94   void SetDeviceId(int32_t device_id);
95   int32_t GetDeviceId();
96   void SetRankId(int32_t rank_id);
97   int32_t GetRankId();
98 
99   inline void SetProvider(const std::string &provider);
100   inline std::string GetProvider();
101 
102   inline void SetChipName(const std::string &device);
103   inline std::string GetChipName();
104 
105   /// \brief Convert model and save .ms format model into `output_file` that passed in constructor.
106   Status Convert();
107 
108   /// \brief Convert model and return converted FlatBuffer model binary buffer.
109   ///
110   /// \param[in] data_size Converted FlatBuffer model's buffer size.
111   ///
112   /// \return A pointer to converted FlatBuffer model buffer.
113   void *Convert(size_t *data_size);
114 
115   /// \brief Convert multiple models and save .ms format models into `output_file` that passed in constructor.
116   inline Status Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
117                         const std::string &weight_file = "");
118 
119  private:
120   Converter(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
121             const std::vector<char> &weight_file);
122   void SetConfigFile(const std::vector<char> &config_file);
123   std::vector<char> GetConfigFileChar() const;
124   void SetConfigInfo(const std::vector<char> &section, const std::map<std::vector<char>, std::vector<char>> &config);
125   std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> GetConfigInfoChar() const;
126   void SetInputShape(const std::map<std::vector<char>, std::vector<int64_t>> &input_shape);
127   std::map<std::vector<char>, std::vector<int64_t>> GetInputShapeChar() const;
128   void SetDecryptKey(const std::vector<char> &key);
129   std::vector<char> GetDecryptKeyChar() const;
130   void SetDecryptMode(const std::vector<char> &mode);
131   std::vector<char> GetDecryptModeChar() const;
132   void SetEncryptKey(const std::vector<char> &key);
133   std::vector<char> GetEncryptKeyChar() const;
134   void SetDevice(const std::vector<char> &device);
135   std::vector<char> GetDeviceChar();
136   void SetProvider(const std::vector<char> &provider);
137   std::vector<char> GetProviderChar();
138   void SetChipName(const std::vector<char> &chip_name);
139   std::vector<char> GetChipNameChar();
140   Status Convert(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
141                  const std::vector<char> &weight_file);
142   std::shared_ptr<ConverterPara> data_;
143 };
144 
Converter(converter::FmkType fmk_type,const std::string & model_file,const std::string & output_file,const std::string & weight_file)145 Converter::Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
146                      const std::string &weight_file)
147     : Converter(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file)) {}
148 
SetConfigFile(const std::string & config_file)149 void Converter::SetConfigFile(const std::string &config_file) { SetConfigFile(StringToChar(config_file)); }
150 
GetConfigFile()151 std::string Converter::GetConfigFile() const { return CharToString(GetConfigFileChar()); }
152 
SetConfigInfo(const std::string & section,const std::map<std::string,std::string> & config)153 void Converter::SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) {
154   SetConfigInfo(StringToChar(section), MapStringToVectorChar(config));
155 }
156 
GetConfigInfo()157 std::map<std::string, std::map<std::string, std::string>> Converter::GetConfigInfo() const {
158   return MapMapCharToString(GetConfigInfoChar());
159 }
160 
SetInputShape(const std::map<std::string,std::vector<int64_t>> & input_shape)161 void Converter::SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape) {
162   SetInputShape(MapStringToChar(input_shape));
163 }
164 
GetInputShape()165 std::map<std::string, std::vector<int64_t>> Converter::GetInputShape() const {
166   return MapCharToString(GetInputShapeChar());
167 }
168 
SetDecryptKey(const std::string & key)169 void Converter::SetDecryptKey(const std::string &key) { SetDecryptKey(StringToChar(key)); }
170 
GetDecryptKey()171 std::string Converter::GetDecryptKey() const { return CharToString(GetDecryptKeyChar()); }
172 
SetDecryptMode(const std::string & mode)173 void Converter::SetDecryptMode(const std::string &mode) { SetDecryptMode(StringToChar(mode)); }
174 
GetDecryptMode()175 std::string Converter::GetDecryptMode() const { return CharToString(GetDecryptModeChar()); }
176 
SetEncryptKey(const std::string & key)177 void Converter::SetEncryptKey(const std::string &key) { SetEncryptKey(StringToChar(key)); }
178 
GetEncryptKey()179 std::string Converter::GetEncryptKey() const { return CharToString(GetEncryptKeyChar()); }
180 
SetDevice(const std::string & device)181 void Converter::SetDevice(const std::string &device) { SetDevice(StringToChar(device)); }
182 
GetDevice()183 std::string Converter::GetDevice() { return CharToString(GetDeviceChar()); }
184 
SetProvider(const std::string & provider)185 void Converter::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
186 
GetProvider()187 std::string Converter::GetProvider() { return CharToString(GetProviderChar()); }
188 
SetChipName(const std::string & chip_name)189 void Converter::SetChipName(const std::string &chip_name) { SetChipName(StringToChar(chip_name)); }
190 
GetChipName()191 std::string Converter::GetChipName() { return CharToString(GetChipNameChar()); }
192 
Convert(converter::FmkType fmk_type,const std::string & model_file,const std::string & output_file,const std::string & weight_file)193 Status Converter::Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
194                           const std::string &weight_file) {
195   return Convert(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file));
196 }
197 }  // namespace mindspore
198 #endif  // MINDSPORE_LITE_INCLUDE_CONVERTER_H_
199