# Copyright 2023-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils of auto parallel"""
from importlib import import_module
import numpy as np
import mindspore as ms
from mindspore import context, log as logger
from mindspore._c_expression import reset_op_id, reset_op_id_with_offset
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed
from mindspore._c_expression import GraphExecutor_
from mindspore.parallel._tensor import _load_tensor_by_layout

SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore"]


def _get_parallel_mode():
    """Get parallel mode."""
    return auto_parallel_context().get_parallel_mode()


def _is_sharding_propagation():
    """Is sharding propagation."""
    return (auto_parallel_context().get_strategy_search_mode() == "sharding_propagation") or (
        auto_parallel_context().get_sharding_propagation())


def _is_in_auto_parallel_mode():
    return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]


def _is_in_data_parallel_mode():
    return _get_parallel_mode() == ms.ParallelMode.DATA_PARALLEL


def _is_in_hybrid_parallel_mode():
    return _get_parallel_mode() == ms.ParallelMode.HYBRID_PARALLEL


def _is_pynative_parallel():
    parallel_mode = context.get_auto_parallel_context('parallel_mode')
    return context.get_context('mode') == context.PYNATIVE_MODE and parallel_mode in (
        context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)


def _get_full_batch():
    """Get whether to use full_batch."""
    return auto_parallel_context().get_full_batch()


def _get_pipeline_stages():
    """Get pipeline stages"""
    return auto_parallel_context().get_pipeline_stages()


def _check_full_batch():
    """
    Check that full_batch is only used under semi_auto_parallel or auto_parallel.

    Raises:
        RuntimeError: If full_batch is used under neither semi_auto_parallel nor auto_parallel.
    """
    parallel_mode = _get_parallel_mode()
    full_batch = _get_full_batch()
    if (parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch:
        raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")


def _need_to_full():
    """Check whether to convert input to full shape or tensor."""
    if _get_parallel_mode() not in ("semi_auto_parallel", "auto_parallel"):
        return False
    dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"):
        return True
    return not _get_full_batch()


def _slice_parameter(parameter, phase, layout):
    """Slice python parameter obj according to the layout."""
    is_train_phase = phase.startswith('train')
    is_prefill_phase = phase.startswith('prefill')
    if layout is not None and parameter.from_ckpt and not is_train_phase:
        is_opt_shard_group = layout[5]
        if not parameter.sliced and is_prefill_phase and is_opt_shard_group:
            rank = get_rank()
            new_tensor = _load_tensor_by_layout(parameter, layout, rank)
            parameter.set_data(new_tensor, True)
            return
        layout_shape = layout[2]
        parameter.shape = tuple(layout_shape)
        return
    graph_executor = GraphExecutor_.get_instance()
    new_param = parameter.init_data(layout, set_sliced=True)
    parameter = new_param
    graph_executor.updata_param_node_default_input(phase, {parameter.name: parameter})
    if layout is None:
        parameter.sliced = True
        return
    if not parameter.sliced:
        rank = get_rank()
        new_tensor = _load_tensor_by_layout(parameter, layout, rank)
        parameter.set_data(new_tensor, True)


def _slice_tensor(tensor, layout, rank_id):
    """Slice python tensor obj according to the layout."""
    new_tensor = _load_tensor_by_layout(tensor, layout, rank_id)
    return new_tensor


def _init_optimizer_state(parameter, phase):
    """init optimizer state"""
    if not parameter.has_init:
        return
    graph_executor = GraphExecutor_.get_instance()
    new_param = parameter.init_data()
    parameter = new_param
    graph_executor.updata_param_node_default_input(phase, {parameter.name: parameter})


def _to_full_shapes(shapes, device_num):
    """Expand the batch dimension according to device_num; adapted to the MindSpore MindData graph solution."""
    new_shapes = []
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy:
        if len(shapes) != len(dataset_strategy):
            raise ValueError("The input shapes size {} is not equal to "
                             "dataset strategy size {}".format(len(shapes), len(dataset_strategy)))
        for index, shape in enumerate(shapes):
            if len(shape) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
            new_shape = []
            for i, item in enumerate(shape):
                if item > 0:
                    new_shape += (item * dataset_strategy[index][i],)  # static shape
                else:
                    new_shape += (item,)  # dynamic shape
            new_shapes.append(new_shape)
        return new_shapes
    for shape in shapes:
        shape_v = []
        for i, item in enumerate(shape):
            if i == 0 and item > 0:
                shape_v += (item * device_num,)  # only for static shape
            else:
                shape_v += (item,)
        new_shapes.append(shape_v)
    return new_shapes
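
# Illustrative sketch for _to_full_shapes (added comment, not part of the original
# module). Assuming the default "data_parallel"/"full_batch" dataset_strategy, only
# the leading (batch) dimension of each static shape is scaled by device_num:
#     >>> _to_full_shapes([(32, 3, 224, 224), (32,)], device_num=8)
#     [[256, 3, 224, 224], [256]]
# A dynamic batch dimension (-1) is left untouched:
#     >>> _to_full_shapes([(-1, 16)], device_num=8)
#     [[-1, 16]]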


def _origin_shapes(shapes):
    """Restore the original shapes after the full-shape expansion."""
    if _need_to_full():
        device_num = _get_device_num() // _get_pipeline_stages()
    else:
        return shapes
    new_shapes = []
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy:
        if len(shapes) != len(dataset_strategy):
            raise ValueError("The input shapes size {} is not equal to "
                             "dataset strategy size {}".format(len(shapes), len(dataset_strategy)))
        for index, shape in enumerate(shapes):
            if len(shape) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
            new_shape = []
            for i, item in enumerate(shape):
                if item > 0:
                    new_shape += (item // dataset_strategy[index][i],)  # static shape
                else:
                    new_shape += (item,)  # dynamic shape
            new_shapes.append(new_shape)
        return new_shapes
    for shape in shapes:
        shape_v = []
        for i, item in enumerate(shape):
            if i == 0 and item > 0:
                shape_v += (item // device_num,)  # only for static shape
            else:
                shape_v += (item,)
        new_shapes.append(shape_v)
    return new_shapes
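
# Illustrative sketch for _origin_shapes (added comment, not part of the original
# module). Assuming semi_auto_parallel with full_batch disabled and 8 devices per
# pipeline stage, it undoes the expansion performed by _to_full_shapes:
#     >>> _origin_shapes([[256, 3, 224, 224], [256]])
#     [[32, 3, 224, 224], [32]]
# Outside (semi_)auto_parallel mode, or when full_batch is enabled, the shapes are
# returned unchanged.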


def _dynamic_shape_for_dataset(dataset_shapes, dynamic_shapes):
    """convert static dataset shapes to dynamic shape"""
    if len(dataset_shapes) != len(dynamic_shapes):
        raise ValueError("The dataset shapes size of {} is not equal to "
                         "dynamic shapes size of {}".format(dataset_shapes, dynamic_shapes))
    ret = dataset_shapes
    for i in range(len(dynamic_shapes)):
        if len(dataset_shapes[i]) != len(dynamic_shapes[i]):
            raise ValueError("The dataset shapes size of {} is not equal to "
                             "dynamic shapes size of {}".format(dataset_shapes, dynamic_shapes))
        for j in range(len(dynamic_shapes[i])):
            if dynamic_shapes[i][j] == -1:
                ret[i][j] = -1
    return ret
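
# Illustrative example for _dynamic_shape_for_dataset (added comment, not part of the
# original module): every position that is -1 in dynamic_shapes is marked dynamic in
# the dataset shapes as well; note that dataset_shapes is modified in place.
#     >>> _dynamic_shape_for_dataset([[32, 16], [32]], [[-1, 16], [-1]])
#     [[-1, 16], [-1]]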


def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
    """Convert numpy data to tensors, expanding the batch dimension according to device_num; adapted to the
       feed-data-from-host solution.
    """
    lst = []
    device_num = global_device_num // _get_pipeline_stages()
    stage_rank = global_rank % device_num
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    if stage_rank >= device_num:
        raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                         "the device num is {}".format(stage_rank, device_num))
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if elem and dataset_strategy:
        if len(elem) != len(dataset_strategy):
            raise ValueError("The input size {} is not equal to "
                             "dataset strategy size {}".format(len(elem), len(dataset_strategy)))
    for index, data in enumerate(elem):
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        shape_ = data.shape
        type_ = data.dtype
        new_shape = ()
        if not dataset_strategy:
            batchsize_per_device = 1
            for i, item in enumerate(shape_):
                if i == 0:
                    new_shape += (item * device_num,)
                    batchsize_per_device = item
                else:
                    new_shape += (item,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
            start = stage_rank * batchsize_per_device
            new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
        else:
            if len(shape_) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape_), len(dataset_strategy[index])))
            slice_index = ()
            for i, item in enumerate(shape_):
                new_shape += (item * dataset_strategy[index][i],)
                start = (stage_rank % dataset_strategy[index][i]) * item
                end = (stage_rank % dataset_strategy[index][i] + 1) * item
                s = slice(start, end, 1)
                slice_index += (s,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
            new_tensor_numpy[slice_index] = data.asnumpy()
        new_tensor = Tensor(new_tensor_numpy, dtype=type_)
        lst.append(new_tensor)
    if scaling_sens:
        lst.append(Tensor(scaling_sens, mstype.float32))
    return tuple(lst)
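
# Illustrative sketch for _to_full_tensor (added comment, not part of the original
# module). Assuming the default dataset_strategy, a single pipeline stage with 4
# devices and global_rank 1, a per-device batch of shape (2, 4) is zero-padded to the
# full batch (8, 4), with rows 2:4 holding this rank's data:
#     >>> out = _to_full_tensor(Tensor(np.ones((2, 4), np.float32)), 4, 1)
#     >>> out[0].shape
#     (8, 4)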


def _get_gradients_mean():
    """Get if using gradients_mean."""
    return auto_parallel_context().get_gradients_mean()


def _get_device_num():
    """Get the device num."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        device_num = 1
        return device_num

    if auto_parallel_context().get_device_num_is_set() is False:
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()
    return device_num


def _get_stage_device_num():
    """Get the device number of each pipeline stage"""
    return _get_device_num() // _get_pipeline_stages()


def _get_global_rank():
    """Get the global rank."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        global_rank = 0
        return global_rank

    if auto_parallel_context().get_global_rank_is_set() is False:
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()
    return global_rank


def _get_parameter_broadcast():
    """Get the parameter broadcast."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
        logger.warning("It is suggested to use mindspore.context.set_auto_parallel_context(parameter_broadcast=True)"
                       " or mindspore.common.set_seed() to share parameters among multiple devices.")

    return parameter_broadcast


def _get_enable_parallel_optimizer():
    """Get if using parallel optimizer."""
    return auto_parallel_context().get_enable_parallel_optimizer()


def _get_grad_accumulation_shard():
    """Get if using parallel shard."""
    return auto_parallel_context().get_grad_accumulation_shard()


def _device_number_check(parallel_mode, device_number):
    """
    Check device num.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check parameter broadcast.

    Note:
        If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Use the same
        random seed to make sure parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast.

    Raises:
        ValueError: If parameter_broadcast is True but the parallel mode is "stand_alone",
                    "semi_auto_parallel" or "auto_parallel".
    """
    if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
                         .format(parallel_mode, parameter_broadcast))


def _get_python_op(op_name, op_path, instance_name, arglist):
    """Get python operator."""
    module = import_module(op_path)
    cls = getattr(module, op_name)
    if op_path != "mindspore.ops.functional":
        # The AllGather attrs contain group_name and group_ranks; pop group_ranks.
        if op_name == "AllGather" and len(arglist) == 2:
            arglist.pop()
        op = cls(*arglist)
    else:
        op = cls
    op.set_prim_instance_name(instance_name)
    return op
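
# Illustrative sketch for _get_python_op (added comment, not part of the original
# module); the op name "Cast" and instance name "cast_0" are arbitrary examples:
#     >>> cast = _get_python_op("Cast", "mindspore.ops.operations", "cast_0", [])
# which returns a Cast primitive whose instance name is set to "cast_0".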


def _reset_op_id():
    """Reset op id."""
    reset_op_id()


def _reset_op_id_with_offset():
    """Reset op id with offset."""
    reset_op_id_with_offset()


def _parallel_predict_check():
    """validate parallel model prediction"""
    if _is_in_auto_parallel_mode():
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
        is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"))
        if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp:
            logger.warning('Using non full-batch dataset in model prediction may lead to incorrect data.')


def _check_similar_layout(tensor_layout1, tensor_layout2):
    """check if two tensor layouts are similar"""
    if tensor_layout1[1] != tensor_layout2[1]:
        return False
    for i in tensor_layout1[1]:
        if i == -1:
            continue
        if tensor_layout1[0][-1 - i] != tensor_layout2[0][-1 - i]:
            return False
    return True
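
# Illustrative example for _check_similar_layout (added comment, not part of the
# original module): two layouts are "similar" when their tensor maps match and the
# device-matrix dimensions actually referenced by the map agree; dimensions that only
# produce repeated slices (tensor map entry -1) are ignored.
#     >>> _check_similar_layout(([4, 2], [1, -1]), ([4, 8], [1, -1]))
#     True
#     >>> _check_similar_layout(([4, 2], [1, -1]), ([2, 4], [1, -1]))
#     False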


def _check_same_layout(tensor_layout1, tensor_layout2):
    """check if two tensor layouts are the same"""
    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]


def _remove_repeated_slices(tensor_layout):
    """generate unrepeated tensor layout"""
    import copy
    new_tensor_layout = copy.deepcopy(tensor_layout)
    dev_mat = tensor_layout[0][:]
    tensor_map = tensor_layout[1]
    for dim in range(len(dev_mat)):
        if dim not in tensor_map:
            dev_mat[-1 - dim] = 1
    new_tensor_layout[0] = dev_mat
    return new_tensor_layout
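
# Illustrative example for _remove_repeated_slices (added comment, not part of the
# original module): device-matrix dimensions that do not appear in the tensor map
# only produce repeated slices, so they are collapsed to 1.
#     >>> _remove_repeated_slices([[4, 2], [1, -1]])
#     [[4, 1], [1, -1]]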


def _infer_rank_list(train_map, predict_map=None):
    """
    infer checkpoint slices to be loaded.
    map value format: [dev_mat, tensor_map, param_split_shape, field_size, opt_shard_stride, opt_shard_size]
    """
    ret = {}
    if _get_pipeline_stages() > 1:
        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
    else:
        local_rank = _get_global_rank()
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]
        dev_num = np.array(train_dev_mat).prod()
        new_train_layout = _remove_repeated_slices(train_layout)
        array = np.arange(dev_num).reshape(train_dev_mat)
        index = ()
        for i in new_train_layout[0]:
            if i == 1:
                index = index + (0,)
            else:
                index = index + (slice(None),)
        rank_list = array[index].flatten()
        if not predict_map:
            ret[param_name] = (rank_list, False)
            continue
        if param_name not in predict_map:
            logger.warning("predict_map does not contain %s", param_name)
            continue
        predict_layout = predict_map[param_name]
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            ret[param_name] = ([local_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                ret[param_name] = ([rank_list[local_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:
            ret[param_name] = (rank_list, False)
    return ret
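
# Illustrative sketch for _infer_rank_list (added comment, not part of the original
# module). Assuming a single pipeline stage and no predict_map, a parameter sharded
# over a [4, 2] device matrix with tensor map [1, -1] has its distinct slices held by
# ranks 0, 2, 4 and 6; the boolean indicates whether a slice can be used directly:
#     >>> _infer_rank_list({"w": [[4, 2], [1, -1]]})
#     {'w': (array([0, 2, 4, 6]), False)}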


def _handle_symbol_inputs(symbol_inputs):
    """handle symbol inputs"""
    dataset_strategy = ()
    divisor_key = "divisor"
    # dataset strategy is set
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy:
        if len(symbol_inputs) != len(dataset_strategy):
            raise ValueError("The symbol_inputs size {} is not equal to "
                             "dataset strategy size {}".format(len(symbol_inputs), len(dataset_strategy)))
        for index, shape in enumerate(symbol_inputs):
            dataset_ele_s = dataset_strategy[index]
            if len(shape) != len(dataset_ele_s):
                raise ValueError("The symbol_inputs item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape), len(dataset_ele_s)))

            for i, item in enumerate(shape):
                if isinstance(item, dict):  # symbol
                    symbol_inputs[index][i][divisor_key] = symbol_inputs[index][i][divisor_key] * dataset_ele_s[i]
                else:  # common shape
                    symbol_inputs[index][i] = item * dataset_ele_s[i]

        return symbol_inputs

    # full batch is set
    device_num = _get_device_num() // _get_pipeline_stages()
    for index, shape in enumerate(symbol_inputs):
        for i, item in enumerate(shape):
            if i == 0 and isinstance(item, dict):  # symbol
                symbol_inputs[index][i][divisor_key] = symbol_inputs[index][i][divisor_key] * device_num

    return symbol_inputs
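
# Illustrative sketch for _handle_symbol_inputs (added comment, not part of the
# original module). Assuming the default dataset_strategy and 8 devices per pipeline
# stage, only the divisor of the leading (batch) symbol is scaled:
#     >>> _handle_symbol_inputs([[{"divisor": 1}, 16]])
#     [[{'divisor': 8}, 16]]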


def _no_need_to_change_symbols(shapes):
    """No need to handle the symbols if full_batch is true or not in parallel mode."""
    if not _need_to_full():
        return True

    # if static shape, return
    is_dynamic_shape = False
    for shape in shapes:
        if any(i < 0 for i in shape):
            is_dynamic_shape = True
            break
    if is_dynamic_shape is False:
        return True

    return False


def _change_symbols_for_parallel(shapes, symbol_inputs=None):
    """create or modify symbol inputs"""
    if _no_need_to_change_symbols(shapes) is True:
        return symbol_inputs
    # the symbol_inputs is [[{'divisor': 8}, 16], [{'divisor': 8}, 16]]
    # the dataset_shapes is [(-1, 16), (-1, 16)]
    # if symbol_inputs is [None, None, ..., None], reset it
    if symbol_inputs is not None and all(s is None for s in symbol_inputs):
        symbol_inputs = []

    # if symbol inputs is none or empty, create default symbol inputs
    # if symbol inputs is not none, handle the empty symbol
    divisor_key = "divisor"
    if symbol_inputs is None or bool(symbol_inputs) is False:
        symbol_inputs = [list(shape) for shape in shapes]  # tuple to list
        for i, s in enumerate(symbol_inputs):
            for j, item in enumerate(s):
                if item == -1:
                    symbol_inputs[i][j] = {divisor_key: 1}
    else:
        for i, s in enumerate(symbol_inputs):
            # the symbol_inputs may be [None, [{'divisor': 8}, 16]]
            # and the dataset_shapes is [(-1, 16), (-1, 16)], need to handle None
            if s is None:
                symbol_inputs[i] = shapes[i]
                for k, item in enumerate(symbol_inputs[i]):
                    if item == -1:
                        symbol_inputs[i][k] = {divisor_key: 1}
                s = symbol_inputs[i]
            for j, item in enumerate(s):
                if isinstance(item, dict) and bool(item) is False:  # the item is empty
                    symbol_inputs[i][j] = {divisor_key: 1}

    return _handle_symbol_inputs(symbol_inputs)
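
# Illustrative sketch for _change_symbols_for_parallel (added comment, not part of
# the original module). Assuming semi_auto_parallel with full_batch disabled and 8
# devices per pipeline stage, a dynamic dataset shape with no user-provided symbols
# gets a default divisor that is then scaled by the device number:
#     >>> _change_symbols_for_parallel([(-1, 16)], None)
#     [[{'divisor': 8}, 16]]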


def _grads_divided_by_device_num_if_recomputation(grads):
    """
    If in pynative parallel and full_batch is True, divide grads by device num to ensure that the gradients are
    correct.
    """
    if not _is_pynative_parallel() or not _get_full_batch():
        return grads

    device_num = _get_device_num()
    logger.info(f"In PyNative mode, when parallel mode is in "
                f"({context.ParallelMode.SEMI_AUTO_PARALLEL}, {context.ParallelMode.AUTO_PARALLEL}) and "
                f"full_batch is True, the gradients will be automatically divided by device_num({device_num}).")

    if not isinstance(grads, (tuple, Tensor)):
        raise ValueError(f"The type of grads must be either Tuple[Tensor] or Tensor, but got {type(grads)}.")

    if isinstance(grads, tuple):
        new_grads = ()
        if grads:
            device_num_tensor = Tensor(device_num, grads[0].dtype)
            for grad in grads:
                new_grads += (grad / device_num_tensor,)
    else:
        device_num_tensor = Tensor(device_num, grads.dtype)
        new_grads = grads / device_num_tensor
    return new_grads

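# Illustrative sketch for _grads_divided_by_device_num_if_recomputation (added
# comment, not part of the original module). Assuming PyNative semi_auto_parallel
# with full_batch enabled on 8 devices, every gradient is divided by 8, e.g.
#     grads = (Tensor([8.0]), Tensor([16.0]))
#     _grads_divided_by_device_num_if_recomputation(grads)  # -> (Tensor([1.0]), Tensor([2.0]))
# In any other configuration the gradients are returned unchanged.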