1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ 17 #define MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ 18 19 #include <map> 20 #include <string> 21 #include <vector> 22 #include <set> 23 #include "include/converter.h" 24 #include "mindapi/base/type_id.h" 25 #include "tools/converter/quantizer/quant_params.h" 26 #include "tools/converter/preprocess/preprocess_param.h" 27 #include "tools/converter/adapter/acl/common/acl_types.h" 28 #include "tools/converter/micro/coder/config.h" 29 #include "src/common/config_infos.h" 30 31 namespace mindspore { 32 enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 }; 33 34 struct ParallelSplitConfig { 35 ParallelSplitType parallel_split_type_ = SplitNo; 36 std::vector<int64_t> parallel_compute_rates_; 37 std::vector<std::string> parallel_devices_; 38 }; 39 40 struct ThirdPartyModelParam { 41 std::vector<TypeId> input_dtypes; 42 std::vector<std::vector<int64_t>> input_shapes; 43 std::vector<std::string> input_names; 44 std::vector<schema::Format> input_formats; 45 std::vector<TypeId> output_dtypes; 46 std::vector<std::vector<int64_t>> output_shapes; 47 std::vector<std::string> output_names; 48 std::vector<schema::Format> output_formats; 49 std::map<std::string, std::vector<uint8_t>> extended_parameters; 50 }; 51 52 struct CpuOptionCfg { 53 std::string architecture; 54 std::string instruction; 55 }; 56 57 struct GraphKernelCfg { 58 std::string graph_kernel_flags; 59 }; 60 61 struct AscendGeOptionCfg { 62 std::vector<std::string> plugin_custom_ops; 63 std::map<std::string, std::map<std::string, std::string>> op_attrs_map; 64 std::vector<int64_t> inputs_to_variable; 65 std::vector<int64_t> outputs_to_variable; 66 }; 67 68 struct ConverterPara { 69 converter::FmkType fmk_type; 70 std::string model_file; 71 std::string output_file; 72 std::string weight_file; 73 74 std::string config_file; 75 std::map<std::string, std::map<std::string, std::string>> config_param; 76 bool weight_fp16 = false; 77 std::map<std::string, std::vector<int64_t>> input_shape; 78 Format input_format = NHWC; 79 Format spec_input_format = DEFAULT_FORMAT; 80 Format spec_output_format = DEFAULT_FORMAT; 81 DataType input_data_type = DataType::kNumberTypeFloat32; 82 DataType output_data_type = DataType::kNumberTypeFloat32; 83 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) 84 ModelType save_type = kMindIR; 85 #else 86 ModelType save_type = kMindIR_Lite; 87 #endif 88 std::string decrypt_key; 89 std::string decrypt_mode = "AES-GCM"; 90 std::string encrypt_key; 91 std::string encrypt_mode = "AES-GCM"; // inner 92 #ifdef ENABLE_OPENSSL 93 bool enable_encryption = true; 94 #else 95 bool enable_encryption = false; 96 #endif 97 bool pre_infer = false; 98 bool train_model = false; 99 bool no_fusion = false; 100 bool optimize_transformer = false; 101 bool is_runtime_converter = false; 102 bool enable_memory_offload = false; 103 std::set<std::string> fusion_blacklists; 104 105 // inner 106 std::vector<std::string> plugins_path; 107 lite::quant::CommonQuantParam commonQuantParam; 108 lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; 109 lite::quant::FullQuantParam fullQuantParam; 110 lite::quant::WeightQuantParam weightQuantParam; 111 lite::preprocess::DataPreProcessParam dataPreProcessParam; 112 lite::acl::AclModelOptionCfg aclModelOptionCfgParam; 113 lite::micro::MicroParam microParam; 114 ParallelSplitConfig parallel_split_config; 115 ThirdPartyModelParam thirdPartyModelParam; 116 AscendGeOptionCfg ascendGeOptionCfg; 117 std::string device; 118 std::string provider; 119 std::string chip_name; 120 CpuOptionCfg cpuOptionCfgParam; 121 lite::quant::TransformQuantParam transformQuantParam; 122 lite::quant::DynamicQuantParam dynamicQuantParam; 123 GraphKernelCfg graphKernelParam; 124 // configs parse from config_file 125 ConfigInfos config_infos; 126 }; 127 } // namespace mindspore 128 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ 129