• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "extendrt/delegate/ascend_ge/ge_dynamic_utils.h"
17 #include "common/utils.h"
18 #include "common/common.h"
19 #include "extendrt/delegate/ascend_ge/ge_utils.h"
20 
21 namespace mindspore {
IsDynamicInputShapes(const std::vector<ShapeVector> & input_shapes)22 bool GeDynamicUtils::IsDynamicInputShapes(const std::vector<ShapeVector> &input_shapes) {
23   return std::any_of(input_shapes.begin(), input_shapes.end(), [](const ShapeVector &shape) {
24     return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
25   });
26 }
27 
IsDynamicInputShapes(const std::vector<std::pair<std::string,ShapeVector>> & input_shapes)28 bool GeDynamicUtils::IsDynamicInputShapes(const std::vector<std::pair<std::string, ShapeVector>> &input_shapes) {
29   return std::any_of(input_shapes.begin(), input_shapes.end(), [](const auto &item) {
30     auto &shape = item.second;
31     return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
32   });
33 }
34 
GetGraphInputShapes(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<GeDynamicShapeInfo> * input_shape_ptr,std::string * input_shape_str_ptr)35 bool GeDynamicUtils::GetGraphInputShapes(const std::shared_ptr<mindspore::Context> &context,
36                                          const ConfigInfos &config_infos,
37                                          std::vector<GeDynamicShapeInfo> *input_shape_ptr,
38                                          std::string *input_shape_str_ptr) {
39   if (input_shape_ptr == nullptr) {
40     MS_LOG(ERROR) << "Input argument input_shape_ptr is nullptr";
41     return false;
42   }
43   // get input shape from AscendDeviceInfo
44   auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
45   if (ascend_info == nullptr) {
46     MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
47     return false;
48   }
49   auto input_shape_str = ascend_info->GetInputShape();
50   if (!input_shape_str.empty()) {
51     MS_LOG(INFO) << "Find input shape " << input_shape_str
52                  << " in AscendDeviceInfo, which may come from [ascend_context] or [acl_option_cfg_param]";
53   }
54   // get options from [ge_graph_options]
55   auto section_it = config_infos.find(lite::kGeGraphOptionsSection);
56   if (section_it != config_infos.end()) {
57     auto &options = section_it->second;
58     auto option_it = options.find("ge.inputShape");
59     if (option_it != options.end()) {
60       input_shape_str = option_it->second;
61       MS_LOG(INFO) << "Find ge.inputShape " << input_shape_str << " in " << lite::kGeGraphOptionsSection;
62     }
63   }
64   // get options from [aoe_tuning_options]
65   section_it = config_infos.find(lite::kAoeTuningOptionsSection);
66   if (section_it != config_infos.end()) {
67     auto &options = section_it->second;
68     auto option_it = options.find("input_shape");
69     if (option_it != options.end()) {
70       input_shape_str = option_it->second;
71       MS_LOG(INFO) << "Find input_shape " << input_shape_str << " in " << lite::kAoeTuningOptionsSection;
72     }
73   }
74   if (!input_shape_str.empty()) {
75     auto input_shape_strs = lite::StrSplit(input_shape_str, ";");
76     std::vector<GeDynamicShapeInfo> input_shapes;
77     for (auto &shape_item_str : input_shape_strs) {
78       GeDynamicShapeInfo dynamic_shape_info;
79       auto split_pos = shape_item_str.rfind(":");
80       if (split_pos == std::string::npos) {
81         MS_LOG(ERROR) << "The input_shape should be in format of name:shape;name:shape, but got " << input_shape_str;
82         return false;
83       }
84       std::string name = shape_item_str.substr(0, split_pos);
85       std::string shape_str = shape_item_str.substr(split_pos + 1);
86       if (!lite::ParseShapeStr(shape_str, &dynamic_shape_info.shape)) {
87         MS_LOG(ERROR) << "Invalid input shape dims: " << shape_str << ", input_shape: " << input_shape_str;
88         return false;
89       }
90       dynamic_shape_info.name = name;
91       dynamic_shape_info.shape_str = shape_str;
92       input_shapes.push_back(dynamic_shape_info);
93     }
94     if (input_shape_str_ptr != nullptr) {
95       *input_shape_str_ptr = input_shape_str;
96     }
97     *input_shape_ptr = input_shapes;
98     return true;
99   }
100   return true;
101 }
102 
UpdateGraphInputShapes(const std::shared_ptr<mindspore::Context> & context,ConfigInfos * config_infos,const std::string & input_shape)103 void GeDynamicUtils::UpdateGraphInputShapes(const std::shared_ptr<mindspore::Context> &context,
104                                             ConfigInfos *config_infos, const std::string &input_shape) {
105   if (config_infos == nullptr) {
106     return;
107   }
108   // get options from [aoe_tuning_options]
109   auto section_it = config_infos->find(lite::kAoeTuningOptionsSection);
110   if (section_it != config_infos->end()) {
111     auto &options = section_it->second;
112     auto option_it = options.find("input_shape");
113     if (option_it != options.end()) {
114       auto &input_shape_str = option_it->second;
115       if (!input_shape_str.empty()) {
116         MS_LOG(INFO) << "Find input_shape " << input_shape_str << " in " << lite::kAoeTuningOptionsSection
117                      << " and updated to " << input_shape;
118         input_shape_str = input_shape;
119         return;
120       }
121     }
122   }
123   // get options from [ge_graph_options]
124   section_it = config_infos->find(lite::kGeGraphOptionsSection);
125   if (section_it != config_infos->end()) {
126     auto &options = section_it->second;
127     auto option_it = options.find("ge.inputShape");
128     if (option_it != options.end()) {
129       auto &input_shape_str = option_it->second;
130       if (!input_shape_str.empty()) {
131         MS_LOG(INFO) << "Find ge.inputShape " << input_shape_str << " in " << lite::kGeGraphOptionsSection
132                      << " and updated to " << input_shape;
133         input_shape_str = input_shape;
134         return;
135       }
136     }
137   }
138   // get input shape from AscendDeviceInfo
139   auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
140   if (ascend_info == nullptr) {
141     MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
142     return;
143   }
144   auto input_shape_str = ascend_info->GetInputShape();
145   if (!input_shape_str.empty()) {
146     MS_LOG(INFO) << "Find input shape " << input_shape_str
147                  << " in AscendDeviceInfo, which may come from [ascend_context] or [acl_option_cfg_param]"
148                  << " and updated to " << input_shape;
149     ascend_info->SetInputShape(input_shape);
150   }
151 }
152 
GetDynamicBatchSize(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<int64_t> * dynamic_batch_size_ptr)153 bool GeDynamicUtils::GetDynamicBatchSize(const std::shared_ptr<mindspore::Context> &context,
154                                          const ConfigInfos &config_infos,
155                                          std::vector<int64_t> *dynamic_batch_size_ptr) {
156   if (dynamic_batch_size_ptr == nullptr) {
157     return false;
158   }
159   // get input shape from AscendDeviceInfo
160   auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
161   if (ascend_info == nullptr) {
162     MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
163     return false;
164   }
165   auto dynamic_batch_size = ascend_info->GetDynamicBatchSize();
166   // get options from [acl_option_cfg_param]
167   auto section_it = config_infos.find(lite::kAclOptionParam);
168   if (section_it != config_infos.end()) {
169     auto &options = section_it->second;
170     auto option_it = options.find("dynamic_batch_size");
171     if (option_it != options.end()) {
172       dynamic_batch_size = option_it->second;
173       MS_LOG(INFO) << "Find dynamic_batch_size " << dynamic_batch_size << " in " << lite::kAclOptionParam;
174     }
175   }
176   // get options from [aoe_tuning_options]
177   section_it = config_infos.find(lite::kAoeTuningOptionsSection);
178   if (section_it != config_infos.end()) {
179     auto &options = section_it->second;
180     auto option_it = options.find("dynamic_batch_size");
181     if (option_it != options.end()) {
182       dynamic_batch_size = option_it->second;
183       MS_LOG(INFO) << "Find dynamic_batch_size " << dynamic_batch_size << " in " << lite::kAoeTuningOptionsSection;
184     }
185   }
186   if (dynamic_batch_size.empty()) {
187     MS_LOG(INFO) << "Not found dynamic_batch_size in AscendDeviceInfo or config file";
188     return true;
189   }
190   // parse dynamic_batch_size
191   std::vector<int64_t> dynamic_batch_size_nums;
192   if (!lite::ParseShapeStr(dynamic_batch_size, &dynamic_batch_size_nums)) {
193     MS_LOG(ERROR) << "Invalid dynamic_batch_size " << dynamic_batch_size;
194     return false;
195   }
196   *dynamic_batch_size_ptr = dynamic_batch_size_nums;
197   return true;
198 }
199 
GetDynamicImageSize(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<std::vector<int64_t>> * dynamic_image_size_ptr)200 bool GeDynamicUtils::GetDynamicImageSize(const std::shared_ptr<mindspore::Context> &context,
201                                          const ConfigInfos &config_infos,
202                                          std::vector<std::vector<int64_t>> *dynamic_image_size_ptr) {
203   if (dynamic_image_size_ptr == nullptr) {
204     return false;
205   }
206   // get input shape from AscendDeviceInfo
207   auto ascend_info = GeUtils::GetAscendDeviceInfo(context);
208   if (ascend_info == nullptr) {
209     MS_LOG(ERROR) << "Cannot find AscendDeviceInfo in context";
210     return false;
211   }
212   auto dynamic_image_size = ascend_info->GetDynamicImageSize();
213   // get options from [acl_option_cfg_param]
214   auto section_it = config_infos.find(lite::kAclOptionParam);
215   if (section_it != config_infos.end()) {
216     auto &options = section_it->second;
217     auto option_it = options.find("dynamic_image_size");
218     if (option_it != options.end()) {
219       dynamic_image_size = option_it->second;
220       MS_LOG(INFO) << "Find dynamic_image_size " << dynamic_image_size << " in " << lite::kAclOptionParam;
221     }
222   }
223   // get options from [aoe_tuning_options]
224   section_it = config_infos.find(lite::kAoeTuningOptionsSection);
225   if (section_it != config_infos.end()) {
226     auto &options = section_it->second;
227     auto option_it = options.find("dynamic_image_size");
228     if (option_it != options.end()) {
229       dynamic_image_size = option_it->second;
230       MS_LOG(INFO) << "Find dynamic_image_size " << dynamic_image_size << " in " << lite::kAoeTuningOptionsSection;
231     }
232   }
233   if (dynamic_image_size.empty()) {
234     MS_LOG(INFO) << "Not found dynamic_image_size in AscendDeviceInfo or config file";
235     return true;
236   }
237   // parse dynamic_image_size
238   auto dynamic_image_strs = lite::StrSplit(dynamic_image_size, ";");
239   if (dynamic_image_strs.empty()) {
240     MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
241     return false;
242   }
243   std::vector<std::vector<int64_t>> dynamic_image_size_nums;
244   for (auto &item : dynamic_image_strs) {
245     std::vector<int64_t> real_dims;
246     if (!lite::ParseShapeStr(item, &real_dims)) {
247       MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
248       return false;
249     }
250     constexpr size_t hw_dim_count = 2;
251     if (real_dims.size() != hw_dim_count) {
252       MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size;
253       return false;
254     }
255     dynamic_image_size_nums.push_back(real_dims);
256   }
257   *dynamic_image_size_ptr = dynamic_image_size_nums;
258   return true;
259 }
260 
GetDynamicDims(const std::shared_ptr<mindspore::Context> &,const ConfigInfos & config_infos,std::vector<std::vector<int64_t>> * dynamic_dims_ptr)261 bool GeDynamicUtils::GetDynamicDims(const std::shared_ptr<mindspore::Context> &, const ConfigInfos &config_infos,
262                                     std::vector<std::vector<int64_t>> *dynamic_dims_ptr) {
263   if (dynamic_dims_ptr == nullptr) {
264     return false;
265   }
266   std::string dynamic_dims;
267   // get options from [acl_option_cfg_param]
268   auto section_it = config_infos.find(lite::kAclOptionParam);
269   if (section_it != config_infos.end()) {
270     auto &options = section_it->second;
271     auto option_it = options.find("dynamic_dims");
272     if (option_it != options.end()) {
273       dynamic_dims = option_it->second;
274       MS_LOG(INFO) << "Find dynamic_dims " << dynamic_dims << " in " << lite::kAclOptionParam;
275     }
276   }
277   // get options from [ge_graph_options]
278   section_it = config_infos.find(lite::kGeGraphOptionsSection);
279   if (section_it != config_infos.end()) {
280     auto &options = section_it->second;
281     auto option_it = options.find("ge.dynamicDims");
282     if (option_it != options.end()) {
283       dynamic_dims = option_it->second;
284       MS_LOG(INFO) << "Find ge.dynamicDims " << dynamic_dims << " in " << lite::kGeGraphOptionsSection;
285     }
286   }
287   // get options from [aoe_tuning_options]
288   section_it = config_infos.find(lite::kAoeTuningOptionsSection);
289   if (section_it != config_infos.end()) {
290     auto &options = section_it->second;
291     auto option_it = options.find("dynamic_dims");
292     if (option_it != options.end()) {
293       dynamic_dims = option_it->second;
294       MS_LOG(INFO) << "Find dynamic_dims " << dynamic_dims << " in " << lite::kAoeTuningOptionsSection;
295     }
296   }
297   if (dynamic_dims.empty()) {
298     MS_LOG(INFO) << "Not found dynamic_dims in AscendDeviceInfo or config file";
299     return true;
300   }
301   // parse dynamic_dims
302   auto dynamic_dims_strs = lite::StrSplit(dynamic_dims, ";");
303   if (dynamic_dims_strs.empty()) {
304     MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims;
305     return false;
306   }
307   std::vector<std::vector<int64_t>> dynamic_dims_nums;
308   for (auto &item : dynamic_dims_strs) {
309     std::vector<int64_t> real_dims;
310     if (!lite::ParseShapeStr(item, &real_dims)) {
311       MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims;
312       return false;
313     }
314     if (!dynamic_dims_nums.empty() && dynamic_dims_nums[0].size() != real_dims.size()) {
315       MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims << ", dims count in all dynamic dims should be same";
316       return false;
317     }
318     dynamic_dims_nums.push_back(real_dims);
319   }
320   *dynamic_dims_ptr = dynamic_dims_nums;
321   return true;
322 }
323 
CheckDynamicDims(const std::vector<int64_t> & dynamic_batch_size,const std::vector<std::vector<int64_t>> & dynamic_image_size,const std::vector<std::vector<int64_t>> & dynamic_dims,const std::string & input_shape_str)324 static bool CheckDynamicDims(const std::vector<int64_t> &dynamic_batch_size,
325                              const std::vector<std::vector<int64_t>> &dynamic_image_size,
326                              const std::vector<std::vector<int64_t>> &dynamic_dims,
327                              const std::string &input_shape_str) {
328   if (!dynamic_dims.empty()) {
329     if (!dynamic_image_size.empty() || !dynamic_batch_size.empty()) {
330       MS_LOG(ERROR) << "Option dynamic_dims, dynamic_image_size and dynamic_batch_size cannot exist simultaneously.";
331       return false;
332     }
333   } else if (!dynamic_image_size.empty()) {
334     if (!dynamic_batch_size.empty()) {
335       MS_LOG(ERROR) << "Option dynamic_dims, dynamic_image_size and dynamic_batch_size cannot exist simultaneously.";
336       return false;
337     }
338   } else if (dynamic_batch_size.empty()) {
339     MS_LOG(ERROR) << "Cannot find dynamic_dims, dynamic_batch_size or dynamic_image_size in AscendDeviceInfo or "
340                      "config file while there are dynamic dims in input shapes "
341                   << input_shape_str;
342     return false;
343   }
344   return true;
345 }
346 
SetDynamicDimsRealValue(const std::vector<int64_t> & dynamic_batch_size,const std::vector<std::vector<int64_t>> & dynamic_image_size,const std::vector<std::vector<int64_t>> & dynamic_dims,const std::string & input_shape_str,std::vector<std::pair<std::string,ShapeVector>> * real_shapes_ptr)347 static bool SetDynamicDimsRealValue(const std::vector<int64_t> &dynamic_batch_size,
348                                     const std::vector<std::vector<int64_t>> &dynamic_image_size,
349                                     const std::vector<std::vector<int64_t>> &dynamic_dims,
350                                     const std::string &input_shape_str,
351                                     std::vector<std::pair<std::string, ShapeVector>> *real_shapes_ptr) {
352   auto &real_shapes = *real_shapes_ptr;
353   if (!dynamic_dims.empty()) {
354     size_t dyn_count = 0;
355     for (auto &input_shape : real_shapes) {
356       for (auto &dim : input_shape.second) {
357         if (dim == -1) {
358           if (dyn_count >= dynamic_dims[0].size()) {
359             MS_LOG(ERROR) << "Invalid dynamic_dims " << dynamic_dims
360                           << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
361                           << ", input shape: " << input_shape_str;
362             return false;
363           }
364           dim = dynamic_dims[0][dyn_count];
365           dyn_count++;
366         }
367       }
368     }
369   } else if (!dynamic_image_size.empty()) {
370     for (auto &input_shape : real_shapes) {
371       size_t dyn_count = 0;
372       for (auto &dim : input_shape.second) {
373         if (dim == -1) {
374           if (dyn_count >= dynamic_image_size[0].size()) {
375             MS_LOG(ERROR) << "Invalid dynamic_image_size " << dynamic_image_size
376                           << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
377                           << ", input shape: " << input_shape_str;
378             return false;
379           }
380           dim = dynamic_image_size[0][dyn_count];
381           dyn_count++;
382         }
383       }
384     }
385   } else {  // dynamic_batch_size
386     for (auto &input_shape : real_shapes) {
387       size_t dyn_count = 0;
388       for (auto &dim : input_shape.second) {
389         if (dim == -1) {
390           if (dyn_count >= 1) {
391             MS_LOG(ERROR) << "Invalid dynamic_batch_size " << dynamic_batch_size
392                           << " while dynamic dims in input_shape is more than " << (dyn_count + 1)
393                           << ", input shape: " << input_shape_str;
394             return false;
395           }
396           dim = dynamic_batch_size[0];
397           dyn_count++;
398         }
399       }
400     }
401   }
402   return true;
403 }
404 
GetGraphOneRealShapes(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos,std::vector<std::pair<std::string,ShapeVector>> * real_shapes_ptr,std::string * input_shape_str_ptr)405 bool GeDynamicUtils::GetGraphOneRealShapes(const std::shared_ptr<mindspore::Context> &context,
406                                            const ConfigInfos &config_infos,
407                                            std::vector<std::pair<std::string, ShapeVector>> *real_shapes_ptr,
408                                            std::string *input_shape_str_ptr) {
409   if (real_shapes_ptr == nullptr) {
410     MS_LOG(ERROR) << "Argument input_shapes_ptr cannot be nullptr";
411     return false;
412   }
413   std::string input_shape_str;
414   std::vector<GeDynamicShapeInfo> input_shapes;
415   auto ret = GetGraphInputShapes(context, config_infos, &input_shapes, &input_shape_str);
416   if (!ret) {
417     MS_LOG(ERROR) << "Failed to get input shape for AscendDeviceInfo or config file";
418     return false;
419   }
420   if (input_shape_str_ptr != nullptr) {
421     *input_shape_str_ptr = input_shape_str;
422   }
423   if (input_shapes.empty()) {
424     MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
425     return true;
426   }
427   std::vector<std::pair<std::string, ShapeVector>> &real_shapes = *real_shapes_ptr;
428   for (auto &item : input_shapes) {
429     std::vector<int64_t> shape;
430     for (auto &dim : item.shape) {
431       if (dim.dim == -1 && dim.min != dim.max) {
432         MS_LOG(ERROR) << "Cannot get one real shape because of shape range, shape: " << input_shape_str;
433         return false;
434       }
435       shape.push_back(dim.dim);
436     }
437     real_shapes.push_back(std::make_pair(item.name, shape));
438   }
439   if (!IsDynamicInputShapes(real_shapes)) {
440     MS_LOG(INFO) << "The dims number in all of input shapes are more than 0, return input shape in AscendDeviceInfo or "
441                     "config file";
442     return true;
443   }
444   std::vector<int64_t> dynamic_batch_size;
445   if (!GetDynamicBatchSize(context, config_infos, &dynamic_batch_size)) {
446     return false;
447   }
448   std::vector<std::vector<int64_t>> dynamic_image_size;
449   if (!GetDynamicImageSize(context, config_infos, &dynamic_image_size)) {
450     return false;
451   }
452   std::vector<std::vector<int64_t>> dynamic_dims;
453   if (!GetDynamicDims(context, config_infos, &dynamic_dims)) {
454     return false;
455   }
456   if (!CheckDynamicDims(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape_str)) {
457     return false;
458   }
459   if (!SetDynamicDimsRealValue(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape_str,
460                                real_shapes_ptr)) {
461     return false;
462   }
463   return true;
464 }
465 }  // namespace mindspore
466