1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "extendrt/delegate/ascend_ge/ge_dynamic_utils.h"
17 #include "common/utils.h"
18 #include "common/common.h"
19 #include "extendrt/delegate/ascend_ge/ge_utils.h"
20
21 namespace mindspore {
IsDynamicInputShapes(const std::vector<ShapeVector> & input_shapes)22 bool GeDynamicUtils::IsDynamicInputShapes(const std::vector<ShapeVector> &input_shapes) {
23 return std::any_of(input_shapes.begin(), input_shapes.end(), [](const ShapeVector &shape) {
24 return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
25 });
26 }
27
IsDynamicInputShapes(const std::vector<std::pair<std::string,ShapeVector>> & input_shapes)28 bool GeDynamicUtils::IsDynamicInputShapes(const std::vector<std::pair<std::string, ShapeVector>> &input_shapes) {
29 return std::any_of(input_shapes.begin(), input_shapes.end(), [](const auto &item) {
30 auto &shape = item.second;
31 return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
32 });
33 }
34
GetGraphInputShapes(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<GeDynamicShapeInfo> * input_shape_ptr,std::string * input_shape_str_ptr)35 bool GeDynamicUtils::GetGraphInputShapes(const std::shared_ptr<mindspore::Context> &context,
36 const ConfigInfos &config_infos,
37 std::vector<GeDynamicShapeInfo> *input_shape_ptr,
38 std::string *input_shape_str_ptr) {
39 if (input_shape_ptr == nullptr) {
40 MS_LOG(ERROR) << "Input argument input_shape_ptr is nullptr";
41 return false;
42 }
43 // get input shape from AscendDeviceInfo
44 auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
45 if (ascend_info == nullptr) {
46 MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
47 return false;
48 }
49 auto input_shape_str = ascend_info->GetInputShape();
50 if (!input_shape_str.empty()) {
51 MS_LOG(INFO) << "Find input shape " << input_shape_str
52 << " in AscendDeviceInfo, which may come from [ascend_context] or [acl_option_cfg_param]";
53 }
54 // get options from [ge_graph_options]
55 auto section_it = config_infos.find(lite::kGeGraphOptionsSection);
56 if (section_it != config_infos.end()) {
57 auto &options = section_it->second;
58 auto option_it = options.find("ge.inputShape");
59 if (option_it != options.end()) {
60 input_shape_str = option_it->second;
61 MS_LOG(INFO) << "Find ge.inputShape " << input_shape_str << " in " << lite::kGeGraphOptionsSection;
62 }
63 }
64 // get options from [aoe_tuning_options]
65 section_it = config_infos.find(lite::kAoeTuningOptionsSection);
66 if (section_it != config_infos.end()) {
67 auto &options = section_it->second;
68 auto option_it = options.find("input_shape");
69 if (option_it != options.end()) {
70 input_shape_str = option_it->second;
71 MS_LOG(INFO) << "Find input_shape " << input_shape_str << " in " << lite::kAoeTuningOptionsSection;
72 }
73 }
74 if (!input_shape_str.empty()) {
75 auto input_shape_strs = lite::StrSplit(input_shape_str, ";");
76 std::vector<GeDynamicShapeInfo> input_shapes;
77 for (auto &shape_item_str : input_shape_strs) {
78 GeDynamicShapeInfo dynamic_shape_info;
79 auto split_pos = shape_item_str.rfind(":");
80 if (split_pos == std::string::npos) {
81 MS_LOG(ERROR) << "The input_shape should be in format of name:shape;name:shape, but got " << input_shape_str;
82 return false;
83 }
84 std::string name = shape_item_str.substr(0, split_pos);
85 std::string shape_str = shape_item_str.substr(split_pos + 1);
86 if (!lite::ParseShapeStr(shape_str, &dynamic_shape_info.shape)) {
87 MS_LOG(ERROR) << "Invalid input shape dims: " << shape_str << ", input_shape: " << input_shape_str;
88 return false;
89 }
90 dynamic_shape_info.name = name;
91 dynamic_shape_info.shape_str = shape_str;
92 input_shapes.push_back(dynamic_shape_info);
93 }
94 if (input_shape_str_ptr != nullptr) {
95 *input_shape_str_ptr = input_shape_str;
96 }
97 *input_shape_ptr = input_shapes;
98 return true;
99 }
100 return true;
101 }
102
UpdateGraphInputShapes(const std::shared_ptr<mindspore::Context> & context,ConfigInfos * config_infos,const std::string & input_shape)103 void GeDynamicUtils::UpdateGraphInputShapes(const std::shared_ptr<mindspore::Context> &context,
104 ConfigInfos *config_infos, const std::string &input_shape) {
105 if (config_infos == nullptr) {
106 return;
107 }
108 // get options from [aoe_tuning_options]
109 auto section_it = config_infos->find(lite::kAoeTuningOptionsSection);
110 if (section_it != config_infos->end()) {
111 auto &options = section_it->second;
112 auto option_it = options.find("input_shape");
113 if (option_it != options.end()) {
114 auto &input_shape_str = option_it->second;
115 if (!input_shape_str.empty()) {
116 MS_LOG(INFO) << "Find input_shape " << input_shape_str << " in " << lite::kAoeTuningOptionsSection
117 << " and updated to " << input_shape;
118 input_shape_str = input_shape;
119 return;
120 }
121 }
122 }
123 // get options from [ge_graph_options]
124 section_it = config_infos->find(lite::kGeGraphOptionsSection);
125 if (section_it != config_infos->end()) {
126 auto &options = section_it->second;
127 auto option_it = options.find("ge.inputShape");
128 if (option_it != options.end()) {
129 auto &input_shape_str = option_it->second;
130 if (!input_shape_str.empty()) {
131 MS_LOG(INFO) << "Find ge.inputShape " << input_shape_str << " in " << lite::kGeGraphOptionsSection
132 << " and updated to " << input_shape;
133 input_shape_str = input_shape;
134 return;
135 }
136 }
137 }
138 // get input shape from AscendDeviceInfo
139 auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
140 if (ascend_info == nullptr) {
141 MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
142 return;
143 }
144 auto input_shape_str = ascend_info->GetInputShape();
145 if (!input_shape_str.empty()) {
146 MS_LOG(INFO) << "Find input shape " << input_shape_str
147 << " in AscendDeviceInfo, which may come from [ascend_context] or [acl_option_cfg_param]"
148 << " and updated to " << input_shape;
149 ascend_info->SetInputShape(input_shape);
150 }
151 }
152
GetDynamicBatchSize(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<int64_t> * dynamic_batch_size_ptr)153 bool GeDynamicUtils::GetDynamicBatchSize(const std::shared_ptr<mindspore::Context> &context,
154 const ConfigInfos &config_infos,
155 std::vector<int64_t> *dynamic_batch_size_ptr) {
156 if (dynamic_batch_size_ptr == nullptr) {
157 return false;
158 }
159 // get input shape from AscendDeviceInfo
160 auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
161 if (ascend_info == nullptr) {
162 MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
163 return false;
164 }
165 auto dynamic_batch_size = ascend_info->GetDynamicBatchSize();
166 // get options from [acl_option_cfg_param]
167 auto section_it = config_infos.find(lite::kAclOptionParam);
168 if (section_it != config_infos.end()) {
169 auto &options = section_it->second;
170 auto option_it = options.find("dynamic_batch_size");
171 if (option_it != options.end()) {
172 dynamic_batch_size = option_it->second;
173 MS_LOG(INFO) << "Find dynamic_batch_size " << dynamic_batch_size << " in " << lite::kAclOptionParam;
174 }
175 }
176 // get options from [aoe_tuning_options]
177 section_it = config_infos.find(lite::kAoeTuningOptionsSection);
178 if (section_it != config_infos.end()) {
179 auto &options = section_it->second;
180 auto option_it = options.find("dynamic_batch_size");
181 if (option_it != options.end()) {
182 dynamic_batch_size = option_it->second;
183 MS_LOG(INFO) << "Find dynamic_batch_size " << dynamic_batch_size << " in " << lite::kAoeTuningOptionsSection;
184 }
185 }
186 if (dynamic_batch_size.empty()) {
187 MS_LOG(INFO) << "Not found dynamic_batch_size in AscendDeviceInfo or config file";
188 return true;
189 }
190 // parse dynamic_batch_size
191 std::vector<int64_t> dynamic_batch_size_nums;
192 if (!lite::ParseShapeStr(dynamic_batch_size, &dynamic_batch_size_nums)) {
193 MS_LOG(ERROR) << "Invalid dynamic_batch_size " << dynamic_batch_size;
194 return false;
195 }
196 *dynamic_batch_size_ptr = dynamic_batch_size_nums;
197 return true;
198 }
199
GetDynamicImageSize(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<std::vector<int64_t>> * dynamic_image_size_ptr)200 bool GeDynamicUtils::GetDynamicImageSize(const std::shared_ptr<mindspore::Context> &context,
201 const ConfigInfos &config_infos,
202 std::vector<std::vector<int64_t>> *dynamic_image_size_ptr) {
203 if (dynamic_image_size_ptr == nullptr) {
204 return false;
205 }
206 // get input shape from AscendDeviceInfo
207 auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
208 if (ascend_info == nullptr) {
209 MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
210 return false;
211 }
212 auto dynamic_image_size = ascend_info->GetDynamicImageSize();
213 // get options from [acl_option_cfg_param]
214 auto section_it = config_infos.find(lite::kAclOptionParam);
215 if (section_it != config_infos.end()) {
216 auto &options = section_it->second;
217 auto option_it = options.find("dynamic_image_size");
218 if (option_it != options.end()) {
219 dynamic_image_size = option_it->second;
220 MS_LOG(INFO) << "Find dynamic_image_size " << dynamic_image_size << " in " << lite::kAclOptionParam;
221 }
222 }
223 // get options from [aoe_tuning_options]
224 section_it = config_infos.find(lite::kAoeTuningOptionsSection);
225 if (section_it != config_infos.end()) {
226 auto &options = section_it->second;
227 auto option_it = options.find("dynamic_image_size");
228 if (option_it != options.end()) {
229 dynamic_image_size = option_it->second;
230 MS_LOG(INFO) << "Find dynamic_image_size " << dynamic_image_size << " in " << lite::kAoeTuningOptionsSection;
231 }
232 }
233 if (dynamic_image_size.empty()) {
234 MS_LOG(INFO) << "Not found dynamic_image_size in AscendDeviceInfo or config file";
235 return true;
236 }
237 // parse dynamic_image_size
238 auto dynamic_image_strs = lite::StrSplit(dynamic_image_size, ";");
239 if (dynamic_image_strs.empty()) {
240 MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
241 return false;
242 }
243 std::vector<std::vector<int64_t>> dynamic_image_size_nums;
244 for (auto &item : dynamic_image_strs) {
245 std::vector<int64_t> real_dims;
246 if (!lite::ParseShapeStr(item, &real_dims)) {
247 MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
248 return false;
249 }
250 constexpr size_t hw_dim_count = 2;
251 if (real_dims.size() != hw_dim_count) {
252 MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
253 return false;
254 }
255 dynamic_image_size_nums.push_back(real_dims);
256 }
257 *dynamic_image_size_ptr = dynamic_image_size_nums;
258 return true;
259 }
260
GetDynamicDims(const std::shared_ptr<mindspore::Context> &,const ConfigInfos & config_infos,std::vector<std::vector<int64_t>> * dynamic_dims_ptr)261 bool GeDynamicUtils::GetDynamicDims(const std::shared_ptr<mindspore::Context> &, const ConfigInfos &config_infos,
262 std::vector<std::vector<int64_t>> *dynamic_dims_ptr) {
263 if (dynamic_dims_ptr == nullptr) {
264 return false;
265 }
266 std::string dynamic_dims;
267 // get options from [acl_option_cfg_param]
268 auto section_it = config_infos.find(lite::kAclOptionParam);
269 if (section_it != config_infos.end()) {
270 auto &options = section_it->second;
271 auto option_it = options.find("dynamic_dims");
272 if (option_it != options.end()) {
273 dynamic_dims = option_it->second;
274 MS_LOG(INFO) << "Find dynamic_dims " << dynamic_dims << " in " << lite::kAclOptionParam;
275 }
276 }
277 // get options from [ge_graph_options]
278 section_it = config_infos.find(lite::kGeGraphOptionsSection);
279 if (section_it != config_infos.end()) {
280 auto &options = section_it->second;
281 auto option_it = options.find("ge.dynamicDims");
282 if (option_it != options.end()) {
283 dynamic_dims = option_it->second;
284 MS_LOG(INFO) << "Find ge.dynamicDims " << dynamic_dims << " in " << lite::kGeGraphOptionsSection;
285 }
286 }
287 // get options from [aoe_tuning_options]
288 section_it = config_infos.find(lite::kAoeTuningOptionsSection);
289 if (section_it != config_infos.end()) {
290 auto &options = section_it->second;
291 auto option_it = options.find("dynamic_dims");
292 if (option_it != options.end()) {
293 dynamic_dims = option_it->second;
294 MS_LOG(INFO) << "Find dynamic_dims " << dynamic_dims << " in " << lite::kAoeTuningOptionsSection;
295 }
296 }
297 if (dynamic_dims.empty()) {
298 MS_LOG(INFO) << "Not found dynamic_dims in AscendDeviceInfo or config file";
299 return true;
300 }
301 // parse dynamic_dims
302 auto dynamic_dims_strs = lite::StrSplit(dynamic_dims, ";");
303 if (dynamic_dims_strs.empty()) {
304 MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims;
305 return false;
306 }
307 std::vector<std::vector<int64_t>> dynamic_dims_nums;
308 for (auto &item : dynamic_dims_strs) {
309 std::vector<int64_t> real_dims;
310 if (!lite::ParseShapeStr(item, &real_dims)) {
311 MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims;
312 return false;
313 }
314 if (!dynamic_dims_nums.empty() && dynamic_dims_nums[0].size() != real_dims.size()) {
315 MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims << ", dims count in all dynamic dims should be same";
316 return false;
317 }
318 dynamic_dims_nums.push_back(real_dims);
319 }
320 *dynamic_dims_ptr = dynamic_dims_nums;
321 return true;
322 }
323
CheckDynamicDims(const std::vector<int64_t> & dynamic_batch_size,const std::vector<std::vector<int64_t>> & dynamic_image_size,const std::vector<std::vector<int64_t>> & dynamic_dims,const std::string & input_shape_str)324 static bool CheckDynamicDims(const std::vector<int64_t> &dynamic_batch_size,
325 const std::vector<std::vector<int64_t>> &dynamic_image_size,
326 const std::vector<std::vector<int64_t>> &dynamic_dims,
327 const std::string &input_shape_str) {
328 if (!dynamic_dims.empty()) {
329 if (!dynamic_image_size.empty() || !dynamic_batch_size.empty()) {
330 MS_LOG(ERROR) << "Option dynamic_dims, dynamic_image_size and dynamic_batch_size cannot exist simultaneously.";
331 return false;
332 }
333 } else if (!dynamic_image_size.empty()) {
334 if (!dynamic_batch_size.empty()) {
335 MS_LOG(ERROR) << "Option dynamic_dims, dynamic_image_size and dynamic_batch_size cannot exist simultaneously.";
336 return false;
337 }
338 } else if (dynamic_batch_size.empty()) {
339 MS_LOG(ERROR) << "Cannot find dynamic_dims, dynamic_batch_size or dynamic_image_size in AscendDeviceInfo or "
340 "config file while there are dynamic dims in input shapes "
341 << input_shape_str;
342 return false;
343 }
344 return true;
345 }
346
SetDynamicDimsRealValue(const std::vector<int64_t> & dynamic_batch_size,const std::vector<std::vector<int64_t>> & dynamic_image_size,const std::vector<std::vector<int64_t>> & dynamic_dims,const std::string & input_shape_str,std::vector<std::pair<std::string,ShapeVector>> * real_shapes_ptr)347 static bool SetDynamicDimsRealValue(const std::vector<int64_t> &dynamic_batch_size,
348 const std::vector<std::vector<int64_t>> &dynamic_image_size,
349 const std::vector<std::vector<int64_t>> &dynamic_dims,
350 const std::string &input_shape_str,
351 std::vector<std::pair<std::string, ShapeVector>> *real_shapes_ptr) {
352 auto &real_shapes = *real_shapes_ptr;
353 if (!dynamic_dims.empty()) {
354 size_t dyn_count = 0;
355 for (auto &input_shape : real_shapes) {
356 for (auto &dim : input_shape.second) {
357 if (dim == -1) {
358 if (dyn_count >= dynamic_dims[0].size()) {
359 MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims
360 << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
361 << ", input shape: " << input_shape_str;
362 return false;
363 }
364 dim = dynamic_dims[0][dyn_count];
365 dyn_count++;
366 }
367 }
368 }
369 } else if (!dynamic_image_size.empty()) {
370 for (auto &input_shape : real_shapes) {
371 size_t dyn_count = 0;
372 for (auto &dim : input_shape.second) {
373 if (dim == -1) {
374 if (dyn_count >= dynamic_image_size[0].size()) {
375 MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size
376 << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
377 << ", input shape: " << input_shape_str;
378 return false;
379 }
380 dim = dynamic_image_size[0][dyn_count];
381 dyn_count++;
382 }
383 }
384 }
385 } else { // dynamic_batch_size
386 for (auto &input_shape : real_shapes) {
387 size_t dyn_count = 0;
388 for (auto &dim : input_shape.second) {
389 if (dim == -1) {
390 if (dyn_count >= 1) {
391 MS_LOG(ERROR) << "Invalid dynamic_batch_size " << dynamic_batch_size
392 << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
393 << ", input shape: " << input_shape_str;
394 return false;
395 }
396 dim = dynamic_batch_size[0];
397 dyn_count++;
398 }
399 }
400 }
401 }
402 return true;
403 }
404
GetGraphOneRealShapes(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<std::pair<std::string,ShapeVector>> * real_shapes_ptr,std::string * input_shape_str_ptr)405 bool GeDynamicUtils::GetGraphOneRealShapes(const std::shared_ptr<mindspore::Context> &context,
406 const ConfigInfos &config_infos,
407 std::vector<std::pair<std::string, ShapeVector>> *real_shapes_ptr,
408 std::string *input_shape_str_ptr) {
409 if (real_shapes_ptr == nullptr) {
410 MS_LOG(ERROR) << "Argument input_shapes_ptr cannot be nullptr";
411 return false;
412 }
413 std::string input_shape_str;
414 std::vector<GeDynamicShapeInfo> input_shapes;
415 auto ret = GetGraphInputShapes(context, config_infos, &input_shapes, &input_shape_str);
416 if (!ret) {
417 MS_LOG(ERROR) << "Failed to get input shape for AscendDeviceInfo or config file";
418 return false;
419 }
420 if (input_shape_str_ptr != nullptr) {
421 *input_shape_str_ptr = input_shape_str;
422 }
423 if (input_shapes.empty()) {
424 MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
425 return true;
426 }
427 std::vector<std::pair<std::string, ShapeVector>> &real_shapes = *real_shapes_ptr;
428 for (auto &item : input_shapes) {
429 std::vector<int64_t> shape;
430 for (auto &dim : item.shape) {
431 if (dim.dim == -1 && dim.min != dim.max) {
432 MS_LOG(ERROR) << "Cannot get one real shape because of shape range, shape: " << input_shape_str;
433 return false;
434 }
435 shape.push_back(dim.dim);
436 }
437 real_shapes.push_back(std::make_pair(item.name, shape));
438 }
439 if (!IsDynamicInputShapes(real_shapes)) {
440 MS_LOG(INFO) << "The dims number in all of input shapes are more than 0, return input shape in AscendDeviceInfo or "
441 "config file";
442 return true;
443 }
444 std::vector<int64_t> dynamic_batch_size;
445 if (!GetDynamicBatchSize(context, config_infos, &dynamic_batch_size)) {
446 return false;
447 }
448 std::vector<std::vector<int64_t>> dynamic_image_size;
449 if (!GetDynamicImageSize(context, config_infos, &dynamic_image_size)) {
450 return false;
451 }
452 std::vector<std::vector<int64_t>> dynamic_dims;
453 if (!GetDynamicDims(context, config_infos, &dynamic_dims)) {
454 return false;
455 }
456 if (!CheckDynamicDims(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape_str)) {
457 return false;
458 }
459 if (!SetDynamicDimsRealValue(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape_str,
460 real_shapes_ptr)) {
461 return false;
462 }
463 return true;
464 }
465 } // namespace mindspore
466