# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Utils of auto parallel"""
16
17import numpy as np
18from mindspore import context, log as logger
19from mindspore.context import ParallelMode
20from mindspore._c_expression import reset_op_id
21from mindspore.common.tensor import Tensor
22from mindspore.common.dtype import dtype_to_nptype
23from mindspore.common import dtype as mstype
24from mindspore.communication.management import get_group_size, get_rank
25from mindspore.parallel._auto_parallel_context import auto_parallel_context
26from mindspore.common.seed import get_seed
27


def _get_parallel_mode():
    """Get parallel mode."""
    return auto_parallel_context().get_parallel_mode()


def _is_in_auto_parallel_mode():
    """Check whether the current parallel mode is semi_auto_parallel or auto_parallel."""
    return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]


def _get_full_batch():
    """Get whether to use full_batch."""
    return auto_parallel_context().get_full_batch()


def _get_pipeline_stages():
    """Get the number of pipeline stages."""
    return auto_parallel_context().get_pipeline_stages()


def _check_task_sink_envs():
    """
    Check whether task sink is enabled via environment variables.

    Task sink is enabled by default; exporting GRAPH_OP_RUN=1 disables it.
    Return True if task sink is enabled, False otherwise.
    """
    task_sink = os.getenv("GRAPH_OP_RUN")
    if task_sink and task_sink.isdigit() and int(task_sink) == 1:
        return False
    return True


def _check_full_batch():
    """
    Check that full_batch is only used under semi_auto_parallel or auto_parallel.

    Raises:
        RuntimeError: If full_batch is used under neither semi_auto_parallel nor auto_parallel.
    """
    parallel_mode = _get_parallel_mode()
    full_batch = _get_full_batch()
    if parallel_mode not in ("semi_auto_parallel", "auto_parallel") and full_batch:
        raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")


def _need_to_full():
    """Check whether inputs need to be expanded to full-batch shapes or tensors."""
    if _get_parallel_mode() not in ("semi_auto_parallel", "auto_parallel"):
        return False
    dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"):
        return True
    return not _get_full_batch()


def _to_full_shapes(shapes, device_num):
    """Expand the batch dimension of each shape according to device_num or the
       dataset strategy, adapting to the MindSpore MindData graph solution."""
    new_shapes = []
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if dataset_strategy:
        if len(shapes) != len(dataset_strategy):
            raise ValueError("The input shapes size {} is not equal to "
                             "dataset strategy size {}".format(len(shapes), len(dataset_strategy)))
        for index, shape in enumerate(shapes):
            if len(shape) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
            new_shape = ()
            for i, item in enumerate(shape):
                new_shape += (item * dataset_strategy[index][i],)
            new_shapes.append(new_shape)
        return new_shapes
    for shape in shapes:
        new_shape = ()
        for i, item in enumerate(shape):
            if i == 0:
                new_shape += (item * device_num,)
            else:
                new_shape += (item,)
        new_shapes.append(new_shape)
    return new_shapes


def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
    """Convert numpy arrays to tensors and expand the batch dimension according to device_num,
       adapting to the feed-data-from-host solution.
    """
    lst = []
    device_num = global_device_num // _get_pipeline_stages()
    if global_rank >= global_device_num:
        raise ValueError("The global rank must be smaller than the device number, the global rank is {}, "
                         "the device num is {}".format(global_rank, global_device_num))
    stage_rank = global_rank % device_num
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if elem and dataset_strategy:
        if len(elem) != len(dataset_strategy):
            raise ValueError("The input size {} is not equal to "
                             "dataset strategy size {}".format(len(elem), len(dataset_strategy)))
    for index, data in enumerate(elem):
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("Elements of the input must be numpy.ndarray or Tensor")
        shape_ = data.shape
        type_ = data.dtype
        new_shape = ()
        if not dataset_strategy:
            batchsize_per_device = 1
            for i, item in enumerate(shape_):
                if i == 0:
                    new_shape += (item * device_num,)
                    batchsize_per_device = item
                else:
                    new_shape += (item,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
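            # Place this rank's local batch at its own offset in the zero-filled
            # full-batch buffer; every other rank fills its own slot likewise.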
            start = stage_rank * batchsize_per_device
            new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
        else:
            if len(shape_) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape_), len(dataset_strategy[index])))
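            # Build a per-dimension slice locating this rank's shard inside the
            # full tensor, as determined by the dataset strategy.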
            slice_index = ()
            for i, item in enumerate(shape_):
                new_shape += (item * dataset_strategy[index][i],)
                start = (stage_rank % dataset_strategy[index][i]) * item
                end = (stage_rank % dataset_strategy[index][i] + 1) * item
                s = slice(start, end, 1)
                slice_index += (s,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
            new_tensor_numpy[slice_index] = data.asnumpy()
        new_tensor = Tensor(new_tensor_numpy)
        lst.append(new_tensor)
    if scaling_sens:
        lst.append(Tensor(scaling_sens, mstype.float32))
    return tuple(lst)


def _get_gradients_mean():
    """Get whether gradients_mean is enabled."""
    return auto_parallel_context().get_gradients_mean()


def _get_device_num():
    """Get the device number."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        return 1

    if not auto_parallel_context().get_device_num_is_set():
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()
    return device_num


def _get_global_rank():
    """Get the global rank."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        return 0

    if not auto_parallel_context().get_global_rank_is_set():
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()
    return global_rank


def _get_parameter_broadcast():
    """Get the parameter broadcast setting."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

    if parallel_mode in ("data_parallel", "hybrid_parallel") and not parameter_broadcast and get_seed() is None:
        logger.warning("It is suggested to use mindspore.context.set_auto_parallel_context(parameter_broadcast=True)"
                       " or mindspore.common.set_seed() to share parameters among multiple devices.")

    return parameter_broadcast


def _get_enable_parallel_optimizer():
    """Get whether the parallel optimizer is enabled."""
    return auto_parallel_context().get_enable_parallel_optimizer()


def _device_number_check(parallel_mode, device_number):
    """
    Check the device number.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.

    Raises:
        ValueError: If parallel_mode is "stand_alone" but device_number is not 1.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode: {1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check parameter broadcast.

    Note:
        If the parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Use the
        same random seed to make sure parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast.

    Raises:
        ValueError: If parameter_broadcast is True but the parallel mode is "stand_alone",
                    "semi_auto_parallel" or "auto_parallel".
    """
    if parameter_broadcast and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast: {1}"
                         .format(parallel_mode, parameter_broadcast))


def _get_python_op(op_name, op_path, instance_name, arglist):
    """Get a Python operator instance by name and module path."""
    module = __import__(op_path, fromlist=["None"])
    cls = getattr(module, op_name)
    if op_path != "mindspore.ops.functional":
        op = cls(*arglist)
    else:
        op = cls
    op.set_prim_instance_name(instance_name)
    return op


def _reset_op_id():
    """Reset op id."""
    reset_op_id()


def _parallel_predict_check():
    """Validate parallel model prediction."""
    if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
        is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"))
        if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp:
            raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" to True.')


def _check_similar_layout(tensor_layout1, tensor_layout2):
    """Check whether two tensor layouts are similar: same tensor map, matching device-matrix dims for mapped axes."""
    if tensor_layout1[1] != tensor_layout2[1]:
        return False
    for i in tensor_layout1[1]:
        if i == -1:
            continue
        if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
            return False
    return True


def _check_same_layout(tensor_layout1, tensor_layout2):
    """Check whether two tensor layouts are identical."""
    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]


def _remove_repeated_slices(tensor_layout):
    """Generate an unrepeated tensor layout by setting unused device-matrix dims to 1."""
    new_tensor_layout = copy.deepcopy(tensor_layout)
    dev_mat = tensor_layout[0][:]
    tensor_map = tensor_layout[1]
    for dim in range(len(dev_mat)):
        if dim not in tensor_map:
            dev_mat[-1-dim] = 1
    new_tensor_layout[0] = dev_mat
    return new_tensor_layout


def _infer_rank_list(train_map, predict_map=None):
    """Infer the checkpoint slices to be loaded for each parameter."""
    ret = {}
    if _get_pipeline_stages() > 1:
        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
    else:
        local_rank = _get_global_rank()
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]
        dev_num = np.array(train_dev_mat).prod()
        new_train_layout = _remove_repeated_slices(train_layout)
        array = np.arange(dev_num).reshape(train_dev_mat)
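        # Arrange all ranks on the device matrix, then keep one representative
        # rank per repeated slice (dims reset to 1 above are indexed at 0).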
        index = ()
        for i in new_train_layout[0]:
            if i == 1:
                index = index + (0,)
            else:
                index = index + (slice(None),)
        rank_list = array[index].flatten()
        if not predict_map:
            ret[param_name] = (rank_list, False)
            continue
        if param_name not in predict_map:
            logger.warning("predict_map does not contain %s", param_name)
            continue
        predict_layout = predict_map[param_name]
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            ret[param_name] = ([local_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                ret[param_name] = ([rank_list[local_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:
            ret[param_name] = (rank_list, False)
    return ret
355