1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"

#include <algorithm>
#include <memory>
#include <numeric>  // std::gcd, used by GetRealDivisorSymbols
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base/base.h"
#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/int_symbol.h"
#include "mindspore/core/symbolic_shape/symbol_info.h"
#include "pipeline/jit/ps/action.h"
#include "include/common/utils/parallel_context.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
36
37 namespace mindspore {
38 namespace parallel {
GetSymbolInfo(const IntSymbol * int_s)39 static SymbolElement GetSymbolInfo(const IntSymbol *int_s) {
40 MS_EXCEPTION_IF_NULL(int_s);
41 SymbolElement tmp;
42 if (int_s->is_const()) { // static shape element
43 tmp.max = int_s->value();
44 tmp.min = int_s->value();
45 tmp.divisor = int_s->value();
46 tmp.remainder = 0;
47 } else {
48 tmp.max = int_s->range_max();
49 tmp.min = int_s->range_min();
50 tmp.divisor = int_s->divisor();
51 tmp.remainder = int_s->remainder();
52 }
53 return tmp;
54 }
55
StaticShapesToSymbols(const Shapes & shapes)56 Symbols StaticShapesToSymbols(const Shapes &shapes) {
57 Symbols symbols;
58 for (auto &shape : shapes) {
59 Symbol symbol;
60 for (auto &ele : shape) {
61 if (ele <= 0) {
62 MS_LOG(EXCEPTION) << "it is not static shape: " << ShapesToString(shapes);
63 }
64 SymbolElement symbol_ele;
65 symbol_ele.divisor = ele; // assign the divisor
66 symbol.push_back(symbol_ele);
67 }
68 symbols.push_back(symbol);
69 }
70 return symbols;
71 }
72
IsDynamicShape(const Shape & shape)73 bool IsDynamicShape(const Shape &shape) { return (std::count(shape.cbegin(), shape.cend(), -1) >= 1); }
74
IsDynamicShapes(const Shapes & shapes)75 bool IsDynamicShapes(const Shapes &shapes) {
76 for (auto &shape : shapes) {
77 if (std::count(shape.cbegin(), shape.cend(), -1) >= 1) {
78 return True;
79 }
80 }
81 return False;
82 }
83
IsDynamicShapesList(const std::vector<Shapes> & shapes_list)84 bool IsDynamicShapesList(const std::vector<Shapes> &shapes_list) {
85 return std::any_of(shapes_list.cbegin(), shapes_list.cend(),
86 [](const Shapes &shapes) { return IsDynamicShapes(shapes); });
87 }
88
CheckRealDivisorSize(const Shapes & shapes,const Shapes & real_divisor_shapes)89 void CheckRealDivisorSize(const Shapes &shapes, const Shapes &real_divisor_shapes) {
90 if (shapes.size() != real_divisor_shapes.size()) {
91 MS_LOG(EXCEPTION) << "the size of shapes is " << shapes.size() << ", but the size of real_divisor_shapes is "
92 << real_divisor_shapes.size() << ", they must be equal";
93 }
94 for (size_t i = 0; i < shapes.size(); ++i) {
95 if (shapes[i].size() != real_divisor_shapes[i].size()) {
96 MS_LOG(EXCEPTION) << "the size of shape is " << shapes[i].size() << ", but the size of real_divisor_shapes is "
97 << real_divisor_shapes[i].size() << ", they must be equal, the index is " << i;
98 }
99 }
100 }
101
102 // real divisor:
103 // 1, For static shape elements, the real divisor of symbol is static shape.
104 // 2, For dynamic shape elements, if remainder != 0, the real divisor of symbol is the maximum common divisor of divisor
105 // and remainder, else equal to divisor
106 // 3, For static shape node, the symbols may be empty. For example, the shapes of make_tuple may be [[1], [1]], but
107 // the symbols of make_tuple may be [[]], using the static shape as the real divisor of symbol
GetRealDivisorSymbols(const Shapes & shapes,const Symbols & symbols)108 Shapes GetRealDivisorSymbols(const Shapes &shapes, const Symbols &symbols) {
109 // dynamic shape graph may be has static operator, and its symbol may be empty, use shapes in this case
110 // static shape
111 if (!IsDynamicShapes(shapes)) {
112 return shapes;
113 }
114
115 if (shapes.size() != symbols.size()) {
116 MS_LOG(EXCEPTION) << "the size of shapes is " << shapes.size() << ", but the size of symbols is " << symbols.size()
117 << ", they must be equal";
118 }
119
120 Shapes real_divisor_shapes;
121 for (size_t i = 0; i < shapes.size(); ++i) {
122 // dynamic shape graph may be has static operator, and its symbol may be empty, use shapes in this case
123 // static shape
124 if (!IsDynamicShape(shapes[i])) {
125 real_divisor_shapes.push_back(shapes[i]);
126 continue;
127 }
128
129 // dynamic shape
130 if (shapes[i].size() != symbols[i].size()) {
131 MS_LOG(EXCEPTION) << "the size of shape is " << shapes[i].size() << ", but the size of symbol is "
132 << symbols[i].size() << ", they must be equal, the index is " << i;
133 }
134
135 Shape real_divisor_shape;
136 for (size_t j = 0; j < shapes[i].size(); ++j) {
137 // static shape element, use shape
138 if (shapes[i][j] > 0) {
139 real_divisor_shape.push_back(shapes[i][j]);
140 continue;
141 }
142
143 // dynamic shape element
144 int64_t real_divisor = 1;
145 if (symbols[i][j].remainder > 0) {
146 real_divisor = std::gcd(symbols[i][j].divisor, symbols[i][j].remainder);
147 } else {
148 real_divisor = symbols[i][j].divisor;
149 }
150 real_divisor_shape.push_back(real_divisor);
151 }
152
153 real_divisor_shapes.push_back(real_divisor_shape);
154 }
155
156 CheckRealDivisorSize(shapes, real_divisor_shapes);
157
158 return real_divisor_shapes;
159 }
160
DivisorOfSymbolToString(const Symbol & symbol)161 static std::string DivisorOfSymbolToString(const Symbol &symbol) {
162 std::string str = "[";
163 for (size_t i = 0; i < symbol.size(); ++i) {
164 str += std::to_string(symbol[i].divisor);
165 if (i < symbol.size() - 1) {
166 str += ", ";
167 }
168 }
169 return str + "]";
170 }
171
RemainderOfSymbolToString(const Symbol & symbol)172 static std::string RemainderOfSymbolToString(const Symbol &symbol) {
173 std::string str = "[";
174 for (size_t i = 0; i < symbol.size(); ++i) {
175 str += std::to_string(symbol[i].remainder);
176 if (i < symbol.size() - 1) {
177 str += ", ";
178 }
179 }
180 return str + "]";
181 }
182
DivisorOfSymbolsToString(const Symbols & symbols)183 std::string DivisorOfSymbolsToString(const Symbols &symbols) {
184 std::string str = "[";
185 for (size_t i = 0; i < symbols.size(); ++i) {
186 str += DivisorOfSymbolToString(symbols[i]);
187 if (i < symbols.size() - 1) {
188 str += ", ";
189 }
190 }
191 return str + "]";
192 }
193
RemainderOfSymbolsToString(const Symbols & symbols)194 std::string RemainderOfSymbolsToString(const Symbols &symbols) {
195 std::string str = "[";
196 for (size_t i = 0; i < symbols.size(); ++i) {
197 str += RemainderOfSymbolToString(symbols[i]);
198 if (i < symbols.size() - 1) {
199 str += ", ";
200 }
201 }
202 return str + "]";
203 }
204
PrintSymbolInfo(const std::vector<symshape::SymbolInfoList> & symbol_infos)205 void PrintSymbolInfo(const std::vector<symshape::SymbolInfoList> &symbol_infos) {
206 for (size_t i = 0; i < symbol_infos.size(); ++i) {
207 auto info_list = symbol_infos[i];
208 for (size_t j = 0; j < info_list.size(); ++j) {
209 MS_LOG(DEBUG) << "SYMBOL, i is " << i << ", j is " << j << ", divisor is " << info_list[j].divisor
210 << ", remainder is " << info_list[j].remainder;
211 }
212 }
213 }
214
IsParallelDynamicShape(const FuncGraphPtr & func_graph)215 bool IsParallelDynamicShape(const FuncGraphPtr &func_graph) {
216 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
217 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
218 if (parallel_mode != parallel::kAutoParallel && parallel_mode != parallel::kSemiAutoParallel) {
219 return false;
220 }
221 return pipeline::IsDynamicShapeGraph(func_graph);
222 }
223
IsSemiOrAutoParallelMode()224 bool IsSemiOrAutoParallelMode() {
225 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
226 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
227 return (parallel_mode == parallel::kAutoParallel || parallel_mode == parallel::kSemiAutoParallel);
228 }
229
GetDeviceNum()230 static int64_t GetDeviceNum() {
231 int64_t device_num = 1;
232 if (parallel::ParallelContext::GetInstance()->device_num_is_set()) {
233 device_num = parallel::ParallelContext::GetInstance()->device_num();
234 } else {
235 auto ms_context = MsContext::GetInstance();
236 MS_EXCEPTION_IF_NULL(ms_context);
237 std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
238 std::string world_group;
239 if (backend == kAscendDevice || backend == kDavinciDevice) {
240 world_group = parallel::HCCL_WORLD_GROUP;
241 } else if (backend == kGPUDevice) {
242 world_group = parallel::NCCL_WORLD_GROUP;
243 } else {
244 MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend
245 << " for semi_auto_parallel/auto_parallel mode,"
246 " currently only support Ascend/GPU backend.";
247 }
248 uint32_t world_rank_size = 0;
249 if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
250 MS_LOG(EXCEPTION) << "Get rank size failed";
251 }
252 device_num = UintToInt(world_rank_size);
253 }
254
255 auto pipeline_stage = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
256 device_num = device_num / pipeline_stage;
257 return device_num;
258 }
259
// modify symbol info by dataset strategy
// only for data sink is false
// Returns a copy of symbol_infos in which divisor/remainder are multiplied by the
// dataset split factor (either the explicit dataset strategy per dimension, or the
// device number on dim 0 by default). The untouched infos are cached in the
// ParallelContext so UpdateParamSymbolicShape can restore them later.
std::vector<symshape::SymbolInfoList> ParallelSymbolInfo(const std::vector<symshape::SymbolInfoList> &symbol_infos,
                                                         bool has_dyn_shape) {
  if (!has_dyn_shape || !IsSemiOrAutoParallelMode()) {  // static shape or sink mode no need to handle symbol info here
    return symbol_infos;
  }

  // Cache the original infos for the later restore in UpdateParamSymbolicShape.
  ParallelContext::GetInstance()->set_symbol_infos(symbol_infos);

  auto parallel_symbol_infos = symbol_infos;
  parallel::Strategies dataset_strategy;
  if (!parallel::ParallelContext::GetInstance()->dataset_strategy().empty()) {
    dataset_strategy = parallel::ParallelContext::GetInstance()->dataset_strategy();
  } else {
    bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
    if (full_batch) {
      // full_batch: no split is applied, so the symbol info is unchanged.
      return parallel_symbol_infos;
    } else {
      // get device num
      int64_t device_num = GetDeviceNum();

      // set parallel symbol: without an explicit dataset strategy only the first
      // dimension of each input is scaled, by the device number.
      for (auto &symbol : parallel_symbol_infos) {
        if (!symbol.empty()) {
          symbol[0].divisor = symbol[0].divisor * device_num;
          symbol[0].remainder = symbol[0].remainder * device_num;
        }
      }
      return parallel_symbol_infos;
    }
  }

  MS_LOG(DEBUG) << "dataset strategy is " << dataset_strategy;
  if (dataset_strategy.size() != parallel_symbol_infos.size()) {
    MS_LOG(EXCEPTION) << "The size of dataset strategy is " << dataset_strategy.size()
                      << ", but the size of symbol info is " << parallel_symbol_infos.size();
  }

  // Explicit dataset strategy: scale every dimension by its configured split number.
  for (size_t i = 0; i < dataset_strategy.size(); ++i) {
    if (dataset_strategy[i].size() != parallel_symbol_infos[i].size()) {
      MS_LOG(EXCEPTION) << "Invalid dataset strategy size for index " << i << ", the size of dataset strategy ele is "
                        << dataset_strategy[i].size() << ", but the size of symbol info ele is "
                        << parallel_symbol_infos[i].size();
    }

    for (size_t j = 0; j < dataset_strategy[i].size(); ++j) {
      parallel_symbol_infos[i][j].divisor = parallel_symbol_infos[i][j].divisor * dataset_strategy[i][j];
      parallel_symbol_infos[i][j].remainder = parallel_symbol_infos[i][j].remainder * dataset_strategy[i][j];
    }
  }

  return parallel_symbol_infos;
}
314
// Extract the symbolic-shape info of `node` as Symbols: one Symbol per output
// shape, one SymbolElement per dimension. For a static operator inside a dynamic
// graph (no symbolic shape attached), the symbol is rebuilt from the static shape.
Symbols GetNodeSymbol(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->abstract());
  Symbols symbols;
  MS_LOG(DEBUG) << ", node is " << node->ToString() << ",full name is " << node->fullname_with_scope();

  auto sym_shape = node->abstract()->GetSymbolicShape();
  if (sym_shape == nullptr) {
    // for static operator in dynamic shape graph, the symbol maybe null
    // construct symbol base on shape
    MS_EXCEPTION_IF_NULL(node->abstract()->GetShape());
    sym_shape = node->abstract()->GetShape()->BuildSymbolicShape();
  }

  if (sym_shape->symbols().empty()) {
    // dynamic operator, and the input is scalar: return a single empty Symbol
    symbols.push_back(Symbol{});
    return symbols;
  }

  Symbol int_symbol;    // accumulates dims when top-level elements are IntSymbols (case 1 below)
  Symbols list_symbol;  // accumulates shapes when top-level elements are ListSymbols (case 2 below)
  bool int_symbol_flag = false;
  for (const auto &s : sym_shape->symbols()) {
    // There are two situations in sym_shape->symbols():
    // 1, It is a ListSymbol, its elements are IntSymbols: [IntSymbol, IntSymbol, ..., IntSymbol]
    // 2, It is a vector of ListSymbol, its elements are ListSymbols: [ListSymbols, ListSymbols, ..., ListSymbols]
    // The ListSymbol like this: [s46<[64,inf]|64N|=s1*64-64>, 768], all elements are IntSymbols, but 768 is const
    MS_EXCEPTION_IF_NULL(s);
    if (s->is<IntSymbol>()) {
      auto int_s = s->as<IntSymbol>();
      int_symbol.push_back(GetSymbolInfo(int_s));
      int_symbol_flag = true;
      continue;
    } else if (s->is<ListSymbol>()) {
      auto list_s = s->as<ListSymbol>();
      Symbol tmp;
      for (const auto &ele : list_s->symbols()) {
        // non-IntSymbol elements are silently skipped; only integer dims carry divisor info
        if (ele->is<IntSymbol>()) {
          auto int_s = ele->as<IntSymbol>();
          tmp.push_back(GetSymbolInfo(int_s));
        }
      }
      list_symbol.push_back(tmp);
    } else {
      MS_LOG(EXCEPTION) << "invalid symbol for " << node->fullname_with_scope();
    }
  }

  if (int_symbol_flag) {
    // Case 1: the whole node has a single shape, wrap it as one Symbol.
    symbols.push_back(int_symbol);
  } else {
    // Case 2: multiple shapes, one Symbol each.
    symbols = list_symbol;
  }

  MS_LOG(DEBUG) << "The symbol is " << DivisorOfSymbolsToString(symbols);
  return symbols;
}
373
TagDynamicShapeFuncGraph(const FuncGraphPtr & root)374 void TagDynamicShapeFuncGraph(const FuncGraphPtr &root) {
375 MS_EXCEPTION_IF_NULL(root);
376 MS_EXCEPTION_IF_NULL(root->manager());
377 for (auto &fg : root->manager()->func_graphs()) {
378 fg->set_dynamic_shape(pipeline::IsDynamicShapeGraph(fg));
379 }
380 }
381
InDynamicGraph(const CNodePtr & node)382 bool InDynamicGraph(const CNodePtr &node) {
383 MS_EXCEPTION_IF_NULL(node);
384 auto func_graph = node->func_graph();
385 MS_EXCEPTION_IF_NULL(func_graph);
386 MS_EXCEPTION_IF_NULL(func_graph->manager());
387 auto roots = func_graph->manager()->roots();
388 FuncGraphPtr root_graph = roots.back();
389 MS_EXCEPTION_IF_NULL(root_graph);
390
391 return root_graph->dynamic_shape();
392 }
393
// Restore the symbolic shapes of the root graph's parameters from the symbol
// infos cached in the ParallelContext (see ParallelSymbolInfo), then clear the
// cache. No-op for static-shape graphs.
void UpdateParamSymbolicShape(const FuncGraphPtr &root) {
  if (!root->dynamic_shape()) {
    return;
  }
  auto symbol_infos = ParallelContext::GetInstance()->symbol_infos();
  // when input is None, the parameter is removed from root graph.
  // Drop the matching empty SymbolInfoLists so infos stay aligned with the remaining parameters.
  symbol_infos.erase(std::remove_if(symbol_infos.begin(), symbol_infos.end(),
                                    [](const symshape::SymbolInfoList &s) { return s.empty(); }),
                     symbol_infos.end());
  // Collect the abstracts of all root parameters (may contain nullptr entries).
  abstract::AbstractBasePtrList params_abs(root->parameters().size());
  (void)std::transform(root->parameters().begin(), root->parameters().end(), params_abs.begin(),
                       [](const AnfNodePtr &p) { return p->abstract(); });
  std::vector<ListSymbolPtr> original_symbolic_shapes;
  if (!symbol_infos.empty()) {
    original_symbolic_shapes = symshape::BuildSymbolicShapeBySymbolInfo(params_abs, symbol_infos);
  }
  for (size_t i = 0; i < params_abs.size(); i++) {
    if (params_abs[i] == nullptr) {
      continue;
    }
    if (i < original_symbolic_shapes.size()) {
      // Parameter has cached info: attach the rebuilt symbolic shape.
      params_abs[i]->SetSymbolicShape(original_symbolic_shapes[i]);
    } else if (params_abs[i]->GetSymbolicShape() != nullptr) {
      // Parameters beyond the provided infos must not keep a stale symbolic shape.
      params_abs[i]->SetSymbolicShape(nullptr);
    }
  }
  // The cached infos are one-shot; clear them so the next compilation starts clean.
  ParallelContext::GetInstance()->set_symbol_infos({});
}
422 } // namespace parallel
423 } // namespace mindspore
424