1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
18 #define MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
19
#include <functional>
#include <map>
#include <string>
#include <type_traits>
#include <utility>
#include "src/common/utils.h"
#include "tools/common/option.h"
26
27 namespace mindspore {
28 namespace lite {
// Empty placeholder type. Flag parse callbacks return Option<Nothing>: a present
// Nothing() value signals success, None signals that the value could not be parsed/stored.
struct Nothing {};
30
31 class FlagParser {
32 public:
FlagParser()33 FlagParser() { AddFlag(&FlagParser::help, helpStr, "print usage message", ""); }
34
35 virtual ~FlagParser() = default;
36
37 // only support read flags from command line
38 virtual Option<std::string> ParseFlags(int argc, const char *const *argv, bool supportUnknown = false,
39 bool supportDuplicate = false);
40 std::string Usage(const Option<std::string> &usgMsg = Option<std::string>(None())) const;
41
42 template <typename Flags, typename T1, typename T2>
43 void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2);
44
45 template <typename Flags, typename T1, typename T2>
46 void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2);
47
48 // non-Option type fields in class
49 template <typename Flags, typename T1, typename T2>
50 void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2);
51
52 template <typename Flags, typename T1, typename T2>
53 void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2);
54
55 template <typename Flags, typename T>
56 void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo);
57
58 // Option-type fields
59 template <typename Flags, typename T>
60 void AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo);
61 bool help{};
62
63 protected:
64 template <typename Flags>
AddFlag(std::string Flags::* t1,const std::string & flagName,const std::string & helpInfo,const char * t2)65 void AddFlag(std::string Flags::*t1, const std::string &flagName, const std::string &helpInfo, const char *t2) {
66 AddFlag(t1, flagName, helpInfo, std::string(t2));
67 }
68
69 std::string binName;
70 Option<std::string> usageMsg;
71 std::string helpStr = "help";
72
73 private:
74 struct FlagInfo {
75 std::string flagName;
76 bool isRequired = false;
77 bool isBoolean = false;
78 std::string helpInfo;
79 bool isParsed = false;
80 std::function<Option<Nothing>(FlagParser *, const std::string &)> parse = nullptr;
81 };
82
83 inline void AddFlag(const FlagInfo &flag);
84
85 // construct a temporary flag
86 template <typename Flags, typename T>
87 void ConstructFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag);
88
89 // construct a temporary flag
90 template <typename Flags, typename T1>
91 void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag);
92
93 Option<std::string> InnerParseFlags(std::multimap<std::string, Option<std::string>> *values);
94
95 static bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName);
96
97 std::map<std::string, FlagInfo> flags;
98 };
99
100 // convert to std::string
101 template <typename Flags, typename T>
ConvertToString(T Flags::* t,const FlagParser & baseFlag)102 Option<std::string> ConvertToString(T Flags::*t, const FlagParser &baseFlag) {
103 const Flags *flag = dynamic_cast<Flags *>(&baseFlag);
104 if (flag != nullptr) {
105 return std::to_string(flag->*t);
106 }
107
108 return Option<std::string>(None());
109 }
110
111 // construct for a Option-type flag
112 template <typename Flags, typename T>
ConstructFlag(Option<T> Flags::* t1,const std::string & flagName,const std::string & helpInfo,FlagInfo * flag)113 void FlagParser::ConstructFlag(Option<T> Flags::*t1, const std::string &flagName, const std::string &helpInfo,
114 FlagInfo *flag) {
115 if (flag == nullptr) {
116 MS_LOG(ERROR) << "FlagInfo is nullptr";
117 return;
118 }
119 flag->flagName = flagName;
120 flag->helpInfo = helpInfo;
121
122 flag->isBoolean = typeid(T) == typeid(bool);
123 flag->isParsed = false;
124 }
125
126 // construct a temporary flag
127 template <typename Flags, typename T>
ConstructFlag(T Flags::* t1,const std::string & flagName,const std::string & helpInfo,FlagInfo * flag)128 void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) {
129 if (flag == nullptr) {
130 MS_LOG(ERROR) << "FlagInfo is nullptr";
131 return;
132 }
133 if (t1 == nullptr) {
134 MS_LOG(ERROR) << "t1 is nullptr";
135 return;
136 }
137 flag->flagName = flagName;
138 flag->helpInfo = helpInfo;
139 flag->isBoolean = typeid(T) == typeid(bool);
140 flag->isParsed = false;
141 }
142
AddFlag(const FlagInfo & flagItem)143 inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; }
144
145 template <typename Flags, typename T>
AddFlag(T Flags::* t,const std::string & flagName,const std::string & helpInfo)146 void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) {
147 if (t == nullptr) {
148 MS_LOG(ERROR) << "t1 is nullptr";
149 return;
150 }
151 AddFlag(t, flagName, helpInfo, static_cast<const T *>(nullptr));
152 }
153
154 template <typename Flags, typename T1, typename T2>
AddFlag(T1 Flags::* t1,const std::string & flagName,const std::string & helpInfo,const T2 & t2)155 void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) {
156 if (t1 == nullptr) {
157 MS_LOG(ERROR) << "t1 is nullptr";
158 return;
159 }
160 AddFlag(t1, flagName, helpInfo, &t2);
161 }
162
163 // just for test
164 template <typename Flags, typename T1, typename T2>
AddFlag(T1 * t1,const std::string & flagName,const std::string & helpInfo,const T2 & t2)165 void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) {
166 if (t1 == nullptr) {
167 MS_LOG(ERROR) << "t1 is nullptr";
168 return;
169 }
170 AddFlag(t1, flagName, helpInfo, &t2);
171 }
172
173 template <typename Flags, typename T1, typename T2>
AddFlag(T1 * t1,const std::string & flagName,const std::string & helpInfo,const T2 * t2)174 void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) {
175 if (t1 == nullptr) {
176 MS_LOG(ERROR) << "t1 is nullptr";
177 return;
178 }
179
180 FlagInfo flagItem;
181
182 // flagItem is as an output parameter
183 ConstructFlag(t1, flagName, helpInfo, flagItem);
184 flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
185 if (base != nullptr) {
186 Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
187 if (ret.IsNone()) {
188 return Option<T1>(None());
189 } else {
190 *t1 = ret.Get();
191 }
192 }
193
194 return Option<Nothing>(Nothing());
195 };
196
197 if (t2 != nullptr) {
198 flagItem.isRequired = false;
199 *t1 = *t2;
200 }
201
202 flagItem.helpInfo +=
203 !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: ";
204 if (t2 != nullptr) {
205 flagItem.helpInfo += ToString(*t2).Get();
206 }
207 flagItem.helpInfo += ")";
208
209 // add this flag to a std::map
210 AddFlag(flagItem);
211 }
212
213 template <typename Flags, typename T1, typename T2>
AddFlag(T1 Flags::* t1,const std::string & flagName,const std::string & helpInfo,const T2 * t2)214 void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) {
215 if (t1 == nullptr) {
216 MS_LOG(ERROR) << "t1 is nullptr";
217 return;
218 }
219
220 auto *flag = dynamic_cast<Flags *>(this);
221 if (flag == nullptr) {
222 return;
223 }
224
225 FlagInfo flagItem;
226
227 // flagItem is as a output parameter
228 ConstructFlag(t1, flagName, helpInfo, &flagItem);
229 flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
230 auto *flag = dynamic_cast<Flags *>(base);
231 if (flag == nullptr) {
232 return Option<Nothing>(None());
233 }
234 if (base != nullptr) {
235 Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
236 if (ret.IsNone()) {
237 return Option<Nothing>(None());
238 } else {
239 flag->*t1 = ret.Get();
240 }
241 }
242
243 return Option<Nothing>(Nothing());
244 };
245
246 if (t2 != nullptr) {
247 flagItem.isRequired = false;
248 flag->*t1 = *t2;
249 } else {
250 flagItem.isRequired = true;
251 }
252
253 flagItem.helpInfo +=
254 !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: ";
255 if (t2 != nullptr) {
256 flagItem.helpInfo += ToString(*t2).Get();
257 }
258 flagItem.helpInfo += ")";
259
260 // add this flag to a std::map
261 AddFlag(flagItem);
262 }
263
264 // option-type add flag
265 template <typename Flags, typename T>
AddFlag(Option<T> Flags::* t,const std::string & flagName,const std::string & helpInfo)266 void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo) {
267 if (t == nullptr) {
268 MS_LOG(ERROR) << "t is nullptr";
269 return;
270 }
271
272 auto *flag = dynamic_cast<Flags *>(this);
273 if (flag == nullptr) {
274 MS_LOG(ERROR) << "dynamic_cast failed";
275 return;
276 }
277
278 FlagInfo flagItem;
279 // flagItem is as a output parameter
280 ConstructFlag(t, flagName, helpInfo, &flagItem);
281 flagItem.isRequired = false;
282 flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> {
283 if (base == nullptr) {
284 return Option<Nothing>(Nothing());
285 }
286 auto *flag = dynamic_cast<Flags *>(base);
287 if (flag != nullptr) {
288 Option<T> ret = Option<std::string>(GenericParseValue<T>(value));
289 if (ret.IsNone()) {
290 return Option<Nothing>(None());
291 } else {
292 flag->*t = Option<T>(Some(ret.Get()));
293 }
294 }
295
296 return Option<Nothing>(Nothing());
297 };
298
299 // add this flag to a std::map
300 AddFlag(flagItem);
301 }
302 } // namespace lite
303 } // namespace mindspore
304
305 #endif // MINDSPORE_LITE_TOOLS_COMMON_FLAG_PARSER_H
306