1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H 19 20 #include <string> 21 #include <vector> 22 #include "include/api/format.h" 23 #include "include/registry/converter_context.h" 24 #include "tools/common/flag_parser.h" 25 #include "ir/dtype/type_id.h" 26 #include "schema/inner/model_generated.h" 27 #include "tools/converter/preprocess/preprocess_param.h" 28 #include "tools/converter/quantizer/quant_params.h" 29 30 namespace mindspore { 31 namespace lite { 32 class ConfigFileParser; 33 } // namespace lite 34 namespace converter { 35 using mindspore::schema::QuantType; 36 enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 }; 37 constexpr auto kMaxSplitRatio = 10; 38 constexpr auto kComputeRate = "computeRate"; 39 constexpr auto kSplitDevice0 = "device0"; 40 constexpr auto kSplitDevice1 = "device1"; 41 struct ParallelSplitConfig { 42 ParallelSplitType parallel_split_type_ = SplitNo; 43 std::vector<int64_t> parallel_compute_rates_; 44 std::vector<std::string> parallel_devices_; 45 }; 46 47 class Flags : public virtual mindspore::lite::FlagParser { 48 public: 49 Flags(); 50 51 ~Flags() override = default; 52 53 int InitInputOutputDataType(); 54 55 int InitFmk(); 56 57 int InitTrainModel(); 58 59 int InitConfigFile(); 60 61 int InitInTensorShape(); 62 63 int InitGraphInputFormat(); 64 65 int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser); 66 67 int Init(int argc, const char **argv); 68 69 public: 70 std::string modelFile; 71 std::string outputFile; 72 std::string fmkIn; 73 FmkType fmk; 74 std::string weightFile; 75 TypeId inputDataType; 76 TypeId outputDataType; 77 std::string saveFP16Str = "off"; 78 bool saveFP16 = false; 79 std::string inputDataTypeStr; 80 std::string outputDataTypeStr; 81 ParallelSplitConfig parallel_split_config_{}; 82 std::string configFile; 83 std::string trainModelIn; 84 bool trainModel = false; 85 std::vector<std::string> pluginsPath; 86 bool disableFusion = false; 87 std::string inTensorShape; 88 std::string dec_key = ""; 89 std::string dec_mode = "AES-GCM"; 90 std::string graphInputFormatStr; 91 mindspore::Format graphInputFormat = mindspore::NHWC; 92 93 lite::quant::CommonQuantParam commonQuantParam; 94 lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; 95 lite::quant::FullQuantParam fullQuantParam; 96 lite::preprocess::DataPreProcessParam dataPreProcessParam; 97 }; 98 99 bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config); 100 101 std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key); 102 } // namespace converter 103 } // namespace mindspore 104 105 #endif 106