• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_
17 #define MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_
18 
19 #include <map>
20 #include <string>
21 #include <vector>
22 #include <set>
23 #include "include/converter.h"
24 #include "mindapi/base/type_id.h"
25 #include "tools/converter/quantizer/quant_params.h"
26 #include "tools/converter/preprocess/preprocess_param.h"
27 #include "tools/converter/adapter/acl/common/acl_types.h"
28 #include "tools/converter/micro/coder/config.h"
29 #include "src/common/config_infos.h"
30 
31 namespace mindspore {
// How a model is split for parallel execution (see ParallelSplitConfig).
// Kept as an unscoped enum with explicit values because callers refer to the
// enumerators unqualified and may rely on their integral values.
enum ParallelSplitType {
  SplitNo = 0,           // no split
  SplitByUserRatio = 1,  // split according to user-provided ratios
  SplitByUserAttr = 2    // split according to user-provided attributes
};
33 
34 struct ParallelSplitConfig {
35   ParallelSplitType parallel_split_type_ = SplitNo;
36   std::vector<int64_t> parallel_compute_rates_;
37   std::vector<std::string> parallel_devices_;
38 };
39 
40 struct ThirdPartyModelParam {
41   std::vector<TypeId> input_dtypes;
42   std::vector<std::vector<int64_t>> input_shapes;
43   std::vector<std::string> input_names;
44   std::vector<schema::Format> input_formats;
45   std::vector<TypeId> output_dtypes;
46   std::vector<std::vector<int64_t>> output_shapes;
47   std::vector<std::string> output_names;
48   std::vector<schema::Format> output_formats;
49   std::map<std::string, std::vector<uint8_t>> extended_parameters;
50 };
51 
// CPU target options for conversion; both strings are empty unless set.
struct CpuOptionCfg {
  std::string architecture{};  // target CPU architecture name
  std::string instruction{};   // instruction-set option name
};
56 
// Graph-kernel options, carried as a single flags string (empty by default).
struct GraphKernelCfg {
  std::string graph_kernel_flags{};
};
60 
// Ascend GE option set used by the converter.
// NOTE(review): op_attrs_map is presumably op name -> (attribute -> value), and
// inputs_to_variable / outputs_to_variable presumably hold graph input/output
// indices to be turned into variables — confirm with the GE adapter.
struct AscendGeOptionCfg {
  std::vector<std::string> plugin_custom_ops{};
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map{};
  std::vector<int64_t> inputs_to_variable{};
  std::vector<int64_t> outputs_to_variable{};
};
67 
68 struct ConverterPara {
69   converter::FmkType fmk_type;
70   std::string model_file;
71   std::string output_file;
72   std::string weight_file;
73 
74   std::string config_file;
75   std::map<std::string, std::map<std::string, std::string>> config_param;
76   bool weight_fp16 = false;
77   std::map<std::string, std::vector<int64_t>> input_shape;
78   Format input_format = NHWC;
79   Format spec_input_format = DEFAULT_FORMAT;
80   Format spec_output_format = DEFAULT_FORMAT;
81   DataType input_data_type = DataType::kNumberTypeFloat32;
82   DataType output_data_type = DataType::kNumberTypeFloat32;
83 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
84   ModelType save_type = kMindIR;
85 #else
86   ModelType save_type = kMindIR_Lite;
87 #endif
88   std::string decrypt_key;
89   std::string decrypt_mode = "AES-GCM";
90   std::string encrypt_key;
91   std::string encrypt_mode = "AES-GCM";  // inner
92 #ifdef ENABLE_OPENSSL
93   bool enable_encryption = true;
94 #else
95   bool enable_encryption = false;
96 #endif
97   bool pre_infer = false;
98   bool train_model = false;
99   bool no_fusion = false;
100   bool optimize_transformer = false;
101   bool is_runtime_converter = false;
102   bool enable_memory_offload = false;
103   std::set<std::string> fusion_blacklists;
104 
105   // inner
106   std::vector<std::string> plugins_path;
107   lite::quant::CommonQuantParam commonQuantParam;
108   lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
109   lite::quant::FullQuantParam fullQuantParam;
110   lite::quant::WeightQuantParam weightQuantParam;
111   lite::preprocess::DataPreProcessParam dataPreProcessParam;
112   lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
113   lite::micro::MicroParam microParam;
114   ParallelSplitConfig parallel_split_config;
115   ThirdPartyModelParam thirdPartyModelParam;
116   AscendGeOptionCfg ascendGeOptionCfg;
117   std::string device;
118   std::string provider;
119   std::string chip_name;
120   CpuOptionCfg cpuOptionCfgParam;
121   lite::quant::TransformQuantParam transformQuantParam;
122   lite::quant::DynamicQuantParam dynamicQuantParam;
123   GraphKernelCfg graphKernelParam;
124   // configs parse from config_file
125   ConfigInfos config_infos;
126 };
127 }  // namespace mindspore
128 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_
129