• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/config_parser/config_file_parser.h"
18 #include "tools/common/parse_config_utils.h"
19 #include "include/errorcode.h"
20 #include "src/common/log_adapter.h"
21 
22 namespace mindspore {
23 namespace lite {
namespace {
// Names of the config-file sections consumed by ConfigFileParser. Each name is looked up
// in the section map filled by mindspore::lite::ParseConfigFile and is also passed to
// SetMapData as the section label used in its error message.
constexpr char kDataPreprocessParam[] = "data_preprocess_param";
constexpr char kCommonQuantParam[] = "common_quant_param";
constexpr char kMixedBitWeightQuantParam[] = "mixed_bit_weight_quant_param";
constexpr char kFullQuantParam[] = "full_quant_param";
constexpr char kRegistry[] = "registry";
}  // namespace
ParseConfigFile(const std::string & config_file_path)31 int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
32   std::map<std::string, std::map<std::string, std::string>> maps;
33   auto ret = mindspore::lite::ParseConfigFile(config_file_path, &maps);
34   if (ret != RET_OK) {
35     MS_LOG(ERROR) << "Parse config file failed.";
36     return ret;
37   }
38   ret = ParseDataPreProcessString(maps);
39   if (ret != RET_OK) {
40     MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
41     return ret;
42   }
43   ret = ParseCommonQuantString(maps);
44   if (ret != RET_OK) {
45     MS_LOG(ERROR) << "ParseCommonQuantString failed.";
46     return ret;
47   }
48   ret = ParseMixedBitQuantString(maps);
49   if (ret != RET_OK) {
50     MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
51     return ret;
52   }
53   ret = ParseFullQuantString(maps);
54   if (ret != RET_OK) {
55     MS_LOG(ERROR) << "ParseFullQuantString failed.";
56     return ret;
57   }
58   ret = ParseRegistryInfoString(maps);
59   if (ret != RET_OK) {
60     MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
61     return ret;
62   }
63   return RET_OK;
64 }
65 
SetMapData(const std::map<std::string,std::string> & input_map,const std::map<std::string,std::string &> & parse_map,const std::string & section)66 int ConfigFileParser::SetMapData(const std::map<std::string, std::string> &input_map,
67                                  const std::map<std::string, std::string &> &parse_map, const std::string &section) {
68   for (const auto &map : input_map) {
69     if (parse_map.find(map.first) == parse_map.end()) {
70       MS_LOG(ERROR) << "INPUT ILLEGAL: `" << map.first << "` is not supported in "
71                     << "[" << section << "]";
72       return RET_INPUT_PARAM_INVALID;
73     } else {
74       parse_map.at(map.first) = map.second;
75     }
76   }
77   return RET_OK;
78 }
79 
ParseDataPreProcessString(const std::map<std::string,std::map<std::string,std::string>> & maps)80 int ConfigFileParser::ParseDataPreProcessString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
81   if (maps.find(kDataPreprocessParam) != maps.end()) {
82     const auto &map = maps.at(kDataPreprocessParam);
83     std::map<std::string, std::string &> parse_map{
84       {"calibrate_path", data_pre_process_string_.calibrate_path},
85       {"calibrate_size", data_pre_process_string_.calibrate_size},
86       {"input_type", data_pre_process_string_.input_type},
87       {"image_to_format", data_pre_process_string_.image_to_format},
88       {"normalize_mean", data_pre_process_string_.normalize_mean},
89       {"normalize_std", data_pre_process_string_.normalize_std},
90       {"resize_width", data_pre_process_string_.resize_width},
91       {"resize_height", data_pre_process_string_.resize_height},
92       {"resize_method", data_pre_process_string_.resize_method},
93       {"center_crop_width", data_pre_process_string_.center_crop_width},
94       {"center_crop_height", data_pre_process_string_.center_crop_height},
95     };
96     return SetMapData(map, parse_map, kDataPreprocessParam);
97   }
98   return RET_OK;
99 }
100 
ParseCommonQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)101 int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
102   if (maps.find(kCommonQuantParam) != maps.end()) {
103     const auto &map = maps.at(kCommonQuantParam);
104     std::map<std::string, std::string &> parse_map{
105       {"quant_type", common_quant_string_.quant_type},
106       {"bit_num", common_quant_string_.bit_num},
107       {"min_quant_weight_size", common_quant_string_.min_quant_weight_size},
108       {"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
109     };
110     return SetMapData(map, parse_map, kCommonQuantParam);
111   }
112   return RET_OK;
113 }
114 
ParseMixedBitQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)115 int ConfigFileParser::ParseMixedBitQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
116   if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
117     const auto &map = maps.at(kMixedBitWeightQuantParam);
118     std::map<std::string, std::string &> parse_map{
119       {"init_scale", mixed_bit_quant_string_.init_scale},
120     };
121     return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
122   }
123   return RET_OK;
124 }
125 
ParseFullQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)126 int ConfigFileParser::ParseFullQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
127   if (maps.find(kFullQuantParam) != maps.end()) {
128     const auto &map = maps.at(kFullQuantParam);
129     std::map<std::string, std::string &> parse_map{
130       {"activation_quant_method", full_quant_string_.activation_quant_method},
131       {"bias_correction", full_quant_string_.bias_correction},
132     };
133     return SetMapData(map, parse_map, kFullQuantParam);
134   }
135   return RET_OK;
136 }
137 
ParseRegistryInfoString(const std::map<std::string,std::map<std::string,std::string>> & maps)138 int ConfigFileParser::ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
139   if (maps.find(kRegistry) != maps.end()) {
140     const auto &map = maps.at(kRegistry);
141     std::map<std::string, std::string &> parse_map{
142       {"plugin_path", registry_info_string_.plugin_path},
143       {"disable_fusion", registry_info_string_.disable_fusion},
144     };
145     return SetMapData(map, parse_map, kRegistry);
146   }
147   return RET_OK;
148 }
149 }  // namespace lite
150 }  // namespace mindspore
151