• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H
19 
20 #include <string>
21 #include <vector>
22 #include "include/api/format.h"
23 #include "include/registry/converter_context.h"
24 #include "tools/common/flag_parser.h"
25 #include "ir/dtype/type_id.h"
26 #include "schema/inner/model_generated.h"
27 #include "tools/converter/preprocess/preprocess_param.h"
28 #include "tools/converter/quantizer/quant_params.h"
29 
30 namespace mindspore {
// Forward declaration only: this header refers to lite::ConfigFileParser solely through a const
// reference parameter (Flags::InitExtendedIntegrationInfo), so its full definition is not needed here.
namespace lite { class ConfigFileParser; }  // namespace lite
34 namespace converter {
35 using mindspore::schema::QuantType;
36 enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 };
37 constexpr auto kMaxSplitRatio = 10;
38 constexpr auto kComputeRate = "computeRate";
39 constexpr auto kSplitDevice0 = "device0";
40 constexpr auto kSplitDevice1 = "device1";
41 struct ParallelSplitConfig {
42   ParallelSplitType parallel_split_type_ = SplitNo;
43   std::vector<int64_t> parallel_compute_rates_;
44   std::vector<std::string> parallel_devices_;
45 };
46 
47 class Flags : public virtual mindspore::lite::FlagParser {
48  public:
49   Flags();
50 
51   ~Flags() override = default;
52 
53   int InitInputOutputDataType();
54 
55   int InitFmk();
56 
57   int InitTrainModel();
58 
59   int InitConfigFile();
60 
61   int InitInTensorShape();
62 
63   int InitGraphInputFormat();
64 
65   int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);
66 
67   int Init(int argc, const char **argv);
68 
69  public:
70   std::string modelFile;
71   std::string outputFile;
72   std::string fmkIn;
73   FmkType fmk;
74   std::string weightFile;
75   TypeId inputDataType;
76   TypeId outputDataType;
77   std::string saveFP16Str = "off";
78   bool saveFP16 = false;
79   std::string inputDataTypeStr;
80   std::string outputDataTypeStr;
81   ParallelSplitConfig parallel_split_config_{};
82   std::string configFile;
83   std::string trainModelIn;
84   bool trainModel = false;
85   std::vector<std::string> pluginsPath;
86   bool disableFusion = false;
87   std::string inTensorShape;
88   std::string dec_key = "";
89   std::string dec_mode = "AES-GCM";
90   std::string graphInputFormatStr;
91   mindspore::Format graphInputFormat = mindspore::NHWC;
92 
93   lite::quant::CommonQuantParam commonQuantParam;
94   lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
95   lite::quant::FullQuantParam fullQuantParam;
96   lite::preprocess::DataPreProcessParam dataPreProcessParam;
97 };
98 
99 bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
100 
101 std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
102 }  // namespace converter
103 }  // namespace mindspore
104 
105 #endif
106