/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include <set>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <numeric>
#include <memory>
#include "base/base.h"

#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/int_symbol.h"
#include "mindspore/core/symbolic_shape/symbol_info.h"
#include "pipeline/jit/ps/action.h"
#include "include/common/utils/parallel_context.h"

#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace parallel {
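// Extract the description of a single shape dimension from an IntSymbol: for a constant symbol, max/min/divisor all
// take the constant value and the remainder is 0; otherwise the range bounds, divisor and remainder of the symbol
// are copied as-is.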
static SymbolElement GetSymbolInfo(const IntSymbol *int_s) {
  MS_EXCEPTION_IF_NULL(int_s);
  SymbolElement tmp;
  if (int_s->is_const()) {  // static shape element
    tmp.max = int_s->value();
    tmp.min = int_s->value();
    tmp.divisor = int_s->value();
    tmp.remainder = 0;
  } else {
    tmp.max = int_s->range_max();
    tmp.min = int_s->range_min();
    tmp.divisor = int_s->divisor();
    tmp.remainder = int_s->remainder();
  }
  return tmp;
}

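// Convert static shapes into Symbols: only the divisor of each SymbolElement is assigned (with the dimension value);
// non-positive dimensions are rejected because the shapes are expected to be fully static here.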
Symbols StaticShapesToSymbols(const Shapes &shapes) {
  Symbols symbols;
  for (auto &shape : shapes) {
    Symbol symbol;
    for (auto &ele : shape) {
      if (ele <= 0) {
        MS_LOG(EXCEPTION) << "it is not a static shape: " << ShapesToString(shapes);
      }
      SymbolElement symbol_ele;
      symbol_ele.divisor = ele;  // assign the divisor
      symbol.push_back(symbol_ele);
    }
    symbols.push_back(symbol);
  }
  return symbols;
}

bool IsDynamicShape(const Shape &shape) { return (std::count(shape.cbegin(), shape.cend(), -1) >= 1); }

bool IsDynamicShapes(const Shapes &shapes) {
  for (auto &shape : shapes) {
    if (std::count(shape.cbegin(), shape.cend(), -1) >= 1) {
      return true;
    }
  }
  return false;
}

bool IsDynamicShapesList(const std::vector<Shapes> &shapes_list) {
  return std::any_of(shapes_list.cbegin(), shapes_list.cend(),
                     [](const Shapes &shapes) { return IsDynamicShapes(shapes); });
}

void CheckRealDivisorSize(const Shapes &shapes, const Shapes &real_divisor_shapes) {
  if (shapes.size() != real_divisor_shapes.size()) {
    MS_LOG(EXCEPTION) << "the size of shapes is " << shapes.size() << ", but the size of real_divisor_shapes is "
                      << real_divisor_shapes.size() << ", they must be equal";
  }
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i].size() != real_divisor_shapes[i].size()) {
      MS_LOG(EXCEPTION) << "the size of shape is " << shapes[i].size() << ", but the size of real_divisor_shapes is "
                        << real_divisor_shapes[i].size() << ", they must be equal, the index is " << i;
    }
  }
}

// real divisor:
// 1, For a static shape element, the real divisor of the symbol is the static shape value.
// 2, For a dynamic shape element, if remainder != 0, the real divisor of the symbol is the greatest common divisor
// of the divisor and the remainder; otherwise it is the divisor itself.
// 3, For a static shape node, the symbols may be empty. For example, the shapes of make_tuple may be [[1], [1]], but
// the symbols of make_tuple may be [[]]; in that case the static shape is used as the real divisor of the symbol.
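// For instance, a symbol with divisor = 8 and remainder = 4 yields a real divisor of gcd(8, 4) = 4, while a symbol
// with divisor = 64 and remainder = 0 keeps 64 as its real divisor.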
Shapes GetRealDivisorSymbols(const Shapes &shapes, const Symbols &symbols) {
  // a dynamic shape graph may contain static shape operators, and their symbols may be empty; use the shapes in this
  // case
  // static shape
  if (!IsDynamicShapes(shapes)) {
    return shapes;
  }

  if (shapes.size() != symbols.size()) {
    MS_LOG(EXCEPTION) << "the size of shapes is " << shapes.size() << ", but the size of symbols is " << symbols.size()
                      << ", they must be equal";
  }

  Shapes real_divisor_shapes;
  for (size_t i = 0; i < shapes.size(); ++i) {
    // a dynamic shape graph may contain static shape operators, and their symbols may be empty; use the shapes in
    // this case
    // static shape
    if (!IsDynamicShape(shapes[i])) {
      real_divisor_shapes.push_back(shapes[i]);
      continue;
    }

    // dynamic shape
    if (shapes[i].size() != symbols[i].size()) {
      MS_LOG(EXCEPTION) << "the size of shape is " << shapes[i].size() << ", but the size of symbol is "
                        << symbols[i].size() << ", they must be equal, the index is " << i;
    }

    Shape real_divisor_shape;
    for (size_t j = 0; j < shapes[i].size(); ++j) {
      // static shape element, use shape
      if (shapes[i][j] > 0) {
        real_divisor_shape.push_back(shapes[i][j]);
        continue;
      }

      // dynamic shape element
      int64_t real_divisor = 1;
      if (symbols[i][j].remainder > 0) {
        real_divisor = std::gcd(symbols[i][j].divisor, symbols[i][j].remainder);
      } else {
        real_divisor = symbols[i][j].divisor;
      }
      real_divisor_shape.push_back(real_divisor);
    }

    real_divisor_shapes.push_back(real_divisor_shape);
  }

  CheckRealDivisorSize(shapes, real_divisor_shapes);

  return real_divisor_shapes;
}

static std::string DivisorOfSymbolToString(const Symbol &symbol) {
  std::string str = "[";
  for (size_t i = 0; i < symbol.size(); ++i) {
    str += std::to_string(symbol[i].divisor);
    if (i < symbol.size() - 1) {
      str += ", ";
    }
  }
  return str + "]";
}

static std::string RemainderOfSymbolToString(const Symbol &symbol) {
  std::string str = "[";
  for (size_t i = 0; i < symbol.size(); ++i) {
    str += std::to_string(symbol[i].remainder);
    if (i < symbol.size() - 1) {
      str += ", ";
    }
  }
  return str + "]";
}

std::string DivisorOfSymbolsToString(const Symbols &symbols) {
  std::string str = "[";
  for (size_t i = 0; i < symbols.size(); ++i) {
    str += DivisorOfSymbolToString(symbols[i]);
    if (i < symbols.size() - 1) {
      str += ", ";
    }
  }
  return str + "]";
}

std::string RemainderOfSymbolsToString(const Symbols &symbols) {
  std::string str = "[";
  for (size_t i = 0; i < symbols.size(); ++i) {
    str += RemainderOfSymbolToString(symbols[i]);
    if (i < symbols.size() - 1) {
      str += ", ";
    }
  }
  return str + "]";
}

void PrintSymbolInfo(const std::vector<symshape::SymbolInfoList> &symbol_infos) {
  for (size_t i = 0; i < symbol_infos.size(); ++i) {
    auto info_list = symbol_infos[i];
    for (size_t j = 0; j < info_list.size(); ++j) {
      MS_LOG(DEBUG) << "SYMBOL, i is " << i << ", j is " << j << ", divisor is " << info_list[j].divisor
                    << ", remainder is " << info_list[j].remainder;
    }
  }
}

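// The parallel dynamic-shape handling only applies in semi_auto_parallel/auto_parallel mode and when the graph
// itself contains dynamic shapes.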
bool IsParallelDynamicShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != parallel::kAutoParallel && parallel_mode != parallel::kSemiAutoParallel) {
    return false;
  }
  return pipeline::IsDynamicShapeGraph(func_graph);
}

bool IsSemiOrAutoParallelMode() {
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  return (parallel_mode == parallel::kAutoParallel || parallel_mode == parallel::kSemiAutoParallel);
}

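// Device number used to scale the batch dimension: taken from the parallel context if set, otherwise queried from
// the world communication group (HCCL on Ascend, NCCL on GPU), and then divided by the pipeline stage number.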
static int64_t GetDeviceNum() {
  int64_t device_num = 1;
  if (parallel::ParallelContext::GetInstance()->device_num_is_set()) {
    device_num = parallel::ParallelContext::GetInstance()->device_num();
  } else {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    std::string world_group;
    if (backend == kAscendDevice || backend == kDavinciDevice) {
      world_group = parallel::HCCL_WORLD_GROUP;
    } else if (backend == kGPUDevice) {
      world_group = parallel::NCCL_WORLD_GROUP;
    } else {
      MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend
                        << " for semi_auto_parallel/auto_parallel mode,"
                           " currently only Ascend/GPU backends are supported.";
    }
    uint32_t world_rank_size = 0;
    if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
      MS_LOG(EXCEPTION) << "Get rank size failed";
    }
    device_num = UintToInt(world_rank_size);
  }

  auto pipeline_stage = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  device_num = device_num / pipeline_stage;
  return device_num;
}

// modify the symbol info according to the dataset strategy
// only used when data sink is false
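// e.g. with no dataset strategy, full_batch = false and 8 devices, a first-dimension symbol {divisor: 2, remainder:
// 0} becomes {divisor: 16, remainder: 0}; with a dataset strategy, every dimension is scaled by its strategy element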
std::vector<symshape::SymbolInfoList> ParallelSymbolInfo(const std::vector<symshape::SymbolInfoList> &symbol_infos,
                                                         bool has_dyn_shape) {
  if (!has_dyn_shape || !IsSemiOrAutoParallelMode()) {  // static shape, or not semi/auto parallel: nothing to handle
    return symbol_infos;
  }

  ParallelContext::GetInstance()->set_symbol_infos(symbol_infos);

  auto parallel_symbol_infos = symbol_infos;
  parallel::Strategies dataset_strategy;
  if (!parallel::ParallelContext::GetInstance()->dataset_strategy().empty()) {
    dataset_strategy = parallel::ParallelContext::GetInstance()->dataset_strategy();
  } else {
    bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
    if (full_batch) {
      return parallel_symbol_infos;
    } else {
      // get device num
      int64_t device_num = GetDeviceNum();

      // set parallel symbol
      for (auto &symbol : parallel_symbol_infos) {
        if (!symbol.empty()) {
          symbol[0].divisor = symbol[0].divisor * device_num;
          symbol[0].remainder = symbol[0].remainder * device_num;
        }
      }
      return parallel_symbol_infos;
    }
  }

  MS_LOG(DEBUG) << "dataset strategy is " << dataset_strategy;
  if (dataset_strategy.size() != parallel_symbol_infos.size()) {
    MS_LOG(EXCEPTION) << "The size of dataset strategy is " << dataset_strategy.size()
                      << ", but the size of symbol info is " << parallel_symbol_infos.size();
  }

  for (size_t i = 0; i < dataset_strategy.size(); ++i) {
    if (dataset_strategy[i].size() != parallel_symbol_infos[i].size()) {
      MS_LOG(EXCEPTION) << "Invalid dataset strategy size for index " << i << ", the size of dataset strategy ele is "
                        << dataset_strategy[i].size() << ", but the size of symbol info ele is "
                        << parallel_symbol_infos[i].size();
    }

    for (size_t j = 0; j < dataset_strategy[i].size(); ++j) {
      parallel_symbol_infos[i][j].divisor = parallel_symbol_infos[i][j].divisor * dataset_strategy[i][j];
      parallel_symbol_infos[i][j].remainder = parallel_symbol_infos[i][j].remainder * dataset_strategy[i][j];
    }
  }

  return parallel_symbol_infos;
}

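// Collect the symbol info of a node from its symbolic shape; for a static operator without a symbolic shape, the
// symbols are constructed from its static shape instead.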
Symbols GetNodeSymbol(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->abstract());
  Symbols symbols;
  MS_LOG(DEBUG) << "node is " << node->ToString() << ", full name is " << node->fullname_with_scope();

  auto sym_shape = node->abstract()->GetSymbolicShape();
  if (sym_shape == nullptr) {
    // for a static operator in a dynamic shape graph, the symbol may be null
    // construct the symbol based on the shape
    MS_EXCEPTION_IF_NULL(node->abstract()->GetShape());
    sym_shape = node->abstract()->GetShape()->BuildSymbolicShape();
  }

  if (sym_shape->symbols().empty()) {
    // dynamic operator, and the input is a scalar
    symbols.push_back(Symbol{});
    return symbols;
  }

  Symbol int_symbol;
  Symbols list_symbol;
  bool int_symbol_flag = false;
  for (const auto &s : sym_shape->symbols()) {
    // There are two situations for sym_shape->symbols():
    // 1, It is a ListSymbol whose elements are IntSymbols: [IntSymbol, IntSymbol, ..., IntSymbol]
    // 2, It is a vector of ListSymbols: [ListSymbol, ListSymbol, ..., ListSymbol]
    // A ListSymbol looks like this: [s46<[64,inf]|64N|=s1*64-64>, 768], all elements are IntSymbols, but 768 is const
    MS_EXCEPTION_IF_NULL(s);
    if (s->is<IntSymbol>()) {
      auto int_s = s->as<IntSymbol>();
      int_symbol.push_back(GetSymbolInfo(int_s));
      int_symbol_flag = true;
      continue;
    } else if (s->is<ListSymbol>()) {
      auto list_s = s->as<ListSymbol>();
      Symbol tmp;
      for (const auto &ele : list_s->symbols()) {
        if (ele->is<IntSymbol>()) {
          auto int_s = ele->as<IntSymbol>();
          tmp.push_back(GetSymbolInfo(int_s));
        }
      }
      list_symbol.push_back(tmp);
    } else {
      MS_LOG(EXCEPTION) << "invalid symbol for " << node->fullname_with_scope();
    }
  }

  if (int_symbol_flag) {
    symbols.push_back(int_symbol);
  } else {
    symbols = list_symbol;
  }

  MS_LOG(DEBUG) << "The symbol is " << DivisorOfSymbolsToString(symbols);
  return symbols;
}

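// Mark every func graph managed by the root with whether it is a dynamic shape graph.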
void TagDynamicShapeFuncGraph(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(root->manager());
  for (auto &fg : root->manager()->func_graphs()) {
    fg->set_dynamic_shape(pipeline::IsDynamicShapeGraph(fg));
  }
}

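// A cnode is considered to be in a dynamic graph when the root graph of its manager is tagged as dynamic shape.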
bool InDynamicGraph(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->manager());
  auto roots = func_graph->manager()->roots();
  FuncGraphPtr root_graph = roots.back();
  MS_EXCEPTION_IF_NULL(root_graph);

  return root_graph->dynamic_shape();
}

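// Rebuild the symbolic shapes of the root graph parameters from the symbol infos saved by ParallelSymbolInfo, then
// clear the saved infos from the parallel context.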
void UpdateParamSymbolicShape(const FuncGraphPtr &root) {
  if (!root->dynamic_shape()) {
    return;
  }
  auto symbol_infos = ParallelContext::GetInstance()->symbol_infos();
  // when an input is None, the corresponding parameter is removed from the root graph
  symbol_infos.erase(std::remove_if(symbol_infos.begin(), symbol_infos.end(),
                                    [](const symshape::SymbolInfoList &s) { return s.empty(); }),
                     symbol_infos.end());
  abstract::AbstractBasePtrList params_abs(root->parameters().size());
  (void)std::transform(root->parameters().begin(), root->parameters().end(), params_abs.begin(),
                       [](const AnfNodePtr &p) { return p->abstract(); });
  std::vector<ListSymbolPtr> original_symbolic_shapes;
  if (!symbol_infos.empty()) {
    original_symbolic_shapes = symshape::BuildSymbolicShapeBySymbolInfo(params_abs, symbol_infos);
  }
  for (size_t i = 0; i < params_abs.size(); i++) {
    if (params_abs[i] == nullptr) {
      continue;
    }
    if (i < original_symbolic_shapes.size()) {
      params_abs[i]->SetSymbolicShape(original_symbolic_shapes[i]);
    } else if (params_abs[i]->GetSymbolicShape() != nullptr) {
      params_abs[i]->SetSymbolicShape(nullptr);
    }
  }
  ParallelContext::GetInstance()->set_symbol_infos({});
}
}  // namespace parallel
}  // namespace mindspore