1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/common/config_infos.h"
18 #include <vector>
19 #include <algorithm>
20 #include "src/common/log_adapter.h"
21 #include "src/common/common.h"
22 #include "src/common/utils.h"
23
24 namespace mindspore {
ParseRangeStr(const std::string & range_str,int64_t * min_ptr,int64_t * max_ptr)25 bool ProfileParser::ParseRangeStr(const std::string &range_str, int64_t *min_ptr, int64_t *max_ptr) {
26 if (min_ptr == nullptr || max_ptr == nullptr) {
27 return false;
28 }
29 auto min_and_max = lite::StrSplit(range_str, "~");
30 int min = 0;
31 int max = 0;
32 constexpr size_t number_range_count = 1;
33 constexpr size_t min_max_range_count = 2;
34 if (min_and_max.size() == number_range_count) {
35 if (!lite::ConvertStrToInt(min_and_max[0], &min)) {
36 MS_LOG(ERROR) << "Invalid dynamic dim value range, dim value range or format is invalid: " << range_str
37 << ". It should be 'min~max' or a number.";
38 return false;
39 }
40 max = min;
41 } else if (min_and_max.size() == min_max_range_count) {
42 if (!lite::ConvertStrToInt(min_and_max[0], &min) || !lite::ConvertStrToInt(min_and_max[1], &max)) {
43 MS_LOG(ERROR) << "Invalid dynamic dim value range, dim value range or format is invalid: " << range_str
44 << ". It should be 'min~max' or a number.";
45 return false;
46 }
47 } else {
48 MS_LOG(ERROR) << "Invalid dynamic dim value range, dim value range or format is invalid: " << range_str
49 << ". It should be 'min~max' or a number.";
50 return false;
51 }
52 if (min > max || min <= 0) {
53 MS_LOG(ERROR) << "Invalid dimension range string format of '" << lite::kDynamicDimsKey << "': " << range_str;
54 return false;
55 }
56 *min_ptr = min;
57 *max_ptr = max;
58 return true;
59 }
60
ParseOptDimStr(const std::string & opt_dim_str,int64_t * opt_ptr)61 bool ProfileParser::ParseOptDimStr(const std::string &opt_dim_str, int64_t *opt_ptr) {
62 if (opt_ptr == nullptr) {
63 return false;
64 }
65 int opt = 0;
66 if (!lite::ConvertStrToInt(opt_dim_str, &opt)) {
67 MS_LOG(ERROR) << "Invalid opt dim value range, dim value range or format is invalid: " << opt_dim_str
68 << ". It should be a number.";
69 return false;
70 }
71 if (opt <= 0) {
72 MS_LOG(ERROR) << "Invalid opt dim value range '" << lite::kOptimizeDimsKey << "': " << opt_dim_str;
73 return false;
74 }
75 *opt_ptr = opt;
76 return true;
77 }
78
ParseInputShape(const std::string & input_shapes_str,ProfileConfigs * profile_configs_ptr)79 bool ProfileParser::ParseInputShape(const std::string &input_shapes_str, ProfileConfigs *profile_configs_ptr) {
80 auto &profile_configs = *profile_configs_ptr;
81 auto input_slices = lite::StrSplit(input_shapes_str, ";");
82 for (auto &input_slice : input_slices) {
83 auto split_pos = input_slice.rfind(':');
84 if (split_pos == std::string::npos) {
85 MS_LOG(ERROR) << "The input_shape should be in format of name:shape;name:shape, but got [" << input_shapes_str
86 << "]";
87 return false;
88 }
89 std::string name = input_slice.substr(0, split_pos);
90 std::string shape_str = input_slice.substr(split_pos + 1);
91 if (shape_str.front() != '[' || shape_str.back() != ']') {
92 MS_LOG(ERROR) << "shape format check fail.";
93 return false;
94 }
95 constexpr size_t trim_len = 2;
96 shape_str = shape_str.substr(1, shape_str.size() - trim_len);
97 ProfileInputInfo info;
98 info.name = name;
99 ShapeVector &shape = info.input_shape;
100 if (!lite::ParseShapeStr(shape_str, &shape)) {
101 MS_LOG(ERROR) << "Invalid input shape dims: " << shape_str;
102 return false;
103 }
104 info.is_dynamic_shape = std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
105 profile_configs.input_infos.push_back(info);
106 }
107 return true;
108 }
109
ParseDynamicDims(const std::string & dynamic_dims_str,ProfileConfigs * profile_configs_ptr)110 bool ProfileParser::ParseDynamicDims(const std::string &dynamic_dims_str, ProfileConfigs *profile_configs_ptr) {
111 auto &profile_configs = *profile_configs_ptr;
112 auto inputs_of_str = lite::StrSplit(dynamic_dims_str, ";");
113 if (inputs_of_str.size() != profile_configs.input_infos.size()) {
114 MS_LOG(ERROR) << "The input count " << inputs_of_str.size() << " in '" << lite::kDynamicDimsKey
115 << "' != the input count " << profile_configs.input_infos.size() << " '"
116 << " in '" << lite::kInputShapeKey;
117 return false;
118 }
119 // for every input
120 for (size_t input_index = 0; input_index != inputs_of_str.size(); ++input_index) {
121 auto &info = profile_configs.input_infos[input_index];
122 auto one_input_str = inputs_of_str[input_index];
123 auto profiles_of_str = lite::StrSplit(one_input_str, "],[");
124 if (profiles_of_str.empty()) {
125 MS_LOG(ERROR) << "The profile count of " << input_index << "th input in '" << lite::kDynamicDimsKey << "' is 0";
126 return false;
127 }
128 if (profile_configs.profiles.empty()) {
129 profile_configs.profiles.resize(profiles_of_str.size());
130 for (auto &profile : profile_configs.profiles) {
131 profile.inputs.resize(profile_configs.input_infos.size());
132 }
133 }
134 if (profiles_of_str.size() != profile_configs.profiles.size()) {
135 MS_LOG(ERROR) << "The profile count " << profiles_of_str.size() << " of " << input_index << "th input in '"
136 << lite::kDynamicDimsKey << "' != profile count " << profile_configs.profiles.size() << " of "
137 << (input_index - 1) << " th input";
138 return false;
139 }
140 // for every profile in one input, parse input range: min, max
141 for (size_t profile_index = 0; profile_index != profiles_of_str.size(); ++profile_index) {
142 ProfileItem &profile_item = profile_configs.profiles[profile_index];
143 ProfileInputRange &input_range = profile_item.inputs[input_index];
144
145 auto one_profile_str = profiles_of_str[profile_index];
146 while (one_profile_str.front() == '[' || one_profile_str.front() == ' ') {
147 one_profile_str = one_profile_str.substr(1);
148 }
149 while (one_profile_str.back() == ']' || one_profile_str.back() == ' ') {
150 one_profile_str = one_profile_str.substr(0, one_profile_str.size() - 1);
151 }
152 auto dim_ranges = lite::StrSplit(one_profile_str, ",");
153
154 auto &input_shape = info.input_shape;
155 size_t dynamic_nbdims = std::count(input_shape.begin(), input_shape.end(), -1);
156 if (dim_ranges.size() != dynamic_nbdims) {
157 MS_LOG(ERROR) << "Number of dynamic dims in config '" << lite::kDynamicDimsKey << "' " << dim_ranges.size()
158 << " != that in '" << lite::kInputShapeKey << "' " << dynamic_nbdims << ".";
159 return false;
160 }
161 size_t range_index = 0;
162 input_range.min_dims = input_shape;
163 input_range.max_dims = input_shape;
164 for (size_t i = 0; i < input_shape.size(); ++i) {
165 if (input_shape[i] != -1) {
166 continue;
167 }
168 if (!ParseRangeStr(dim_ranges[range_index++], &input_range.min_dims[i], &input_range.max_dims[i])) {
169 return false;
170 }
171 }
172 input_range.opt_dims = input_range.min_dims; // default
173 }
174 }
175 return true;
176 }
177
ParseOptDims(const std::string & opt_dims_str,ProfileConfigs * profile_configs_ptr)178 bool ProfileParser::ParseOptDims(const std::string &opt_dims_str, ProfileConfigs *profile_configs_ptr) {
179 if (opt_dims_str.empty()) {
180 MS_LOG(ERROR) << "The option " << lite::kOptimizeDimsKey << " cannot be empty in [gpu_context]";
181 return false;
182 }
183 auto &profile_configs = *profile_configs_ptr;
184 auto inputs_of_str = lite::StrSplit(opt_dims_str, ";");
185 if (inputs_of_str.size() != profile_configs.input_infos.size()) {
186 MS_LOG(ERROR) << "The input count " << inputs_of_str.size() << " in '" << lite::kOptimizeDimsKey
187 << "' != the input count " << profile_configs.input_infos.size() << " '"
188 << " in '" << lite::kInputShapeKey;
189 return false;
190 }
191 // for every input
192 for (size_t input_index = 0; input_index != inputs_of_str.size(); ++input_index) {
193 auto &info = profile_configs.input_infos[input_index];
194 auto one_input_str = inputs_of_str[input_index];
195 auto profiles_of_str = lite::StrSplit(one_input_str, "],[");
196 if (profiles_of_str.size() != profile_configs.profiles.size()) {
197 MS_LOG(ERROR) << "The profile count " << profiles_of_str.size() << " of " << input_index << "th input in '"
198 << lite::kOptimizeDimsKey << "' != profile count " << profile_configs.profiles.size() << " in '"
199 << lite::kDynamicDimsKey << "'";
200 return false;
201 }
202 // for every profile in one input, parse input range: min, max
203 for (size_t profile_index = 0; profile_index != profiles_of_str.size(); ++profile_index) {
204 ProfileItem &profile_item = profile_configs.profiles[profile_index];
205 ProfileInputRange &input_range = profile_item.inputs[input_index];
206
207 auto one_profile_str = profiles_of_str[profile_index];
208 while (one_profile_str.front() == '[' || one_profile_str.front() == ' ') {
209 one_profile_str = one_profile_str.substr(1);
210 }
211 while (one_profile_str.back() == ']' || one_profile_str.back() == ' ') {
212 one_profile_str = one_profile_str.substr(0, one_profile_str.size() - 1);
213 }
214 auto opt_dims_vec = lite::StrSplit(one_profile_str, ",");
215
216 auto &input_shape = info.input_shape;
217 size_t dynamic_nbdims = std::count(input_shape.begin(), input_shape.end(), -1);
218 if (opt_dims_vec.size() != dynamic_nbdims) {
219 MS_LOG(ERROR) << "Number of dynamic dims in config '" << lite::kOptimizeDimsKey << "' " << opt_dims_vec.size()
220 << " != that in '" << lite::kInputShapeKey << "' " << dynamic_nbdims << ".";
221 return false;
222 }
223 size_t dynamic_index = 0;
224 input_range.opt_dims = input_shape;
225 for (size_t i = 0; i < input_shape.size(); ++i) {
226 if (input_shape[i] != -1) {
227 continue;
228 }
229 if (!ParseOptDimStr(opt_dims_vec[dynamic_index++], &input_range.opt_dims[i])) {
230 return false;
231 }
232 }
233 }
234 }
235 return true;
236 }
237
Parse(const std::map<std::string,std::string> & context,bool require_opt_when_dym,ProfileConfigs * profile_configs_ptr)238 bool ProfileParser::Parse(const std::map<std::string, std::string> &context, bool require_opt_when_dym,
239 ProfileConfigs *profile_configs_ptr) {
240 if (profile_configs_ptr == nullptr) {
241 return false;
242 }
243 auto &profile_configs = *profile_configs_ptr;
244 auto input_shapes = GetOption(context, lite::kInputShapeKey, "");
245 auto dynamic_dims = GetOption(context, lite::kDynamicDimsKey, "");
246 auto opt_dims = GetOption(context, lite::kOptimizeDimsKey, "");
247 if (input_shapes.empty() && dynamic_dims.empty() && opt_dims.empty()) {
248 MS_LOG(INFO) << "Do not found config of input range('" << lite::kInputShapeKey << "')";
249 return true;
250 }
251 if (input_shapes.empty()) {
252 MS_LOG(ERROR) << "Config of '" << lite::kInputShapeKey << " cannot be empty when '" << lite::kInputShapeKey
253 << "' or '" << lite::kOptimizeDimsKey << "' is not empty";
254 return false;
255 }
256 if (!ParseInputShape(input_shapes, &profile_configs)) {
257 MS_LOG(ERROR) << "parse input shape failed.";
258 return false;
259 }
260 if (dynamic_dims.empty()) {
261 ProfileItem profile_item;
262 for (size_t i = 0; i < profile_configs.input_infos.size(); i++) {
263 if (profile_configs.input_infos[i].is_dynamic_shape) {
264 MS_LOG(ERROR) << "Config of '" << lite::kDynamicDimsKey << "' cannot be empty when " << lite::kInputShapeKey
265 << " is dynamic";
266 return false;
267 }
268 auto &input_shape = profile_configs.input_infos[i].input_shape;
269 ProfileInputRange input_range;
270 input_range.min_dims = input_shape;
271 input_range.max_dims = input_shape;
272 input_range.opt_dims = input_shape;
273 profile_item.inputs.push_back(input_range);
274 }
275 profile_configs.profiles.push_back(profile_item);
276 return true;
277 }
278 if (!ParseDynamicDims(dynamic_dims, &profile_configs)) {
279 MS_LOG(ERROR) << "parse dynamic dims failed.";
280 return false;
281 }
282 if (require_opt_when_dym) {
283 if (!ParseOptDims(opt_dims, &profile_configs)) {
284 MS_LOG(ERROR) << "parse optimization dims failed.";
285 return false;
286 }
287 }
288 return true;
289 }
290
ReorderByInputNames(const std::vector<std::string> & input_names,ProfileConfigs * profile_configs)291 bool ProfileParser::ReorderByInputNames(const std::vector<std::string> &input_names, ProfileConfigs *profile_configs) {
292 if (input_names.size() != profile_configs->input_infos.size()) {
293 MS_LOG(ERROR) << "Input name size " << input_names.size() << " != profile config input size "
294 << profile_configs->input_infos.size();
295 return false;
296 }
297 ProfileConfigs new_profile_configs = *profile_configs;
298 auto &input_infos = profile_configs->input_infos;
299 auto &profiles = profile_configs->profiles;
300
301 for (size_t input_index = 0; input_index < input_names.size(); input_index++) {
302 const auto &input_name = input_names[input_index];
303 size_t i = 0;
304 for (; i < input_infos.size(); i++) {
305 if (input_infos[i].name == input_name) {
306 new_profile_configs.input_infos[input_index] = input_infos[i];
307 for (size_t profile_index = 0; profile_index < profiles.size(); profile_index++) {
308 new_profile_configs.profiles[profile_index].inputs[input_index] = profiles[profile_index].inputs[i];
309 }
310 break;
311 }
312 }
313 if (i >= input_infos.size()) {
314 MS_LOG(ERROR) << "Cannot find input " << input_name << " in profile '" << lite::kInputShapeKey << "' config";
315 return false;
316 }
317 }
318 *profile_configs = new_profile_configs;
319 return true;
320 }
321
GetOption(const std::map<std::string,std::string> & context,const std::string & option,const std::string & default_val)322 std::string ProfileParser::GetOption(const std::map<std::string, std::string> &context, const std::string &option,
323 const std::string &default_val) {
324 auto it = context.find(option);
325 if (it == context.end()) {
326 return default_val;
327 }
328 return it->second;
329 }
330 } // namespace mindspore
331