# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor
from mindspore.ops import functional as F, composite as C
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor


reduce_opt = C.MultitypeFuncGraph("reduce_opt")

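# Note (added comment): reduce_opt is a MultitypeFuncGraph, so the overload that
# runs is selected by the types of the call arguments, matching one of the
# signatures registered below, e.g. the (..., "Tensor") overload for dense
# gradients versus the (..., "RowTensor") overload for sparse gradients.
# Illustrative call shape (a sketch, evaluated inside a graph, assuming degree,
# allgather, allreduce and grad are already built):
#
#     new_grad = reduce_opt(degree, mean, allgather, allreduce, True, grad)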
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
    """ initialize allreduce communication operators"""
    fusion_type = 2 ** 10
    split = 0
    fusion = ()
    for i in range(length):
        fusion = fusion + (fusion_type,)
        if split >= len(split_indices):
            continue
        if split_indices[split] <= i:
            fusion_type += 1
            split += 1
    index = tuple(range(1, length + 1))
    op_list = ()
    for i in range(length):
        op = AllReduce('sum', group)
        op.add_prim_attr('fusion', fusion[i])
        op.add_prim_attr('index', index[i])
        op_list = op_list + (op,)
    return op_list

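# Worked example (added comment, not executed): with length=5 and
# split_indices=[2, 4], the loop above produces
#     fusion = (1024, 1024, 1024, 1025, 1025)
#     index  = (1, 2, 3, 4, 5)
# so gradients 0-2 share fusion id 1024 and are fused into one AllReduce,
# while gradients 3-4 share fusion id 1025 and form a second fused AllReduce.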

def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fusion_type=1):
    """ initialize allreduce communication operators by parameters"""
    op_list = ()
    param_fusion = False
    last_comm_fusion = None
    first_parameter_flag = True
    index = 1
    for parameter in parameters:
        comm_fusion = parameter.comm_fusion
        if first_parameter_flag:
            last_comm_fusion = comm_fusion
            first_parameter_flag = False
        elif not param_fusion:
            if comm_fusion != last_comm_fusion:
                param_fusion = True
                last_comm_fusion = comm_fusion
        op = AllReduce('sum', group)
        op.add_prim_attr('fusion', comm_fusion)
        op.add_prim_attr('index', index)
        index += 1
        op_list = op_list + (op,)
    if not param_fusion:
        if split_indices and fusion_type == 1:
            op_list = _init_allreduce_operators(len(parameters), split_indices, group)
            param_fusion = True
        else:
            op_list = ()
    return op_list, param_fusion

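# Behaviour note (added comment): param_fusion becomes True either when the
# parameters carry differing comm_fusion values (per-parameter fusion set on
# the Parameter objects) or when split_indices is given with the default
# fusion_type == 1; otherwise op_list is discarded and the caller falls back to
# a single fused AllReduce. Illustrative sketch (assumes `net` is an existing
# Cell, `group` an existing communication group, and that comm_fusion is
# settable in this MindSpore version):
#
#     params = net.trainable_params()
#     for p in params[:10]:
#         p.comm_fusion = 1
#     for p in params[10:]:
#         p.comm_fusion = 2
#     # _init_allreduce_operators_by_parameters(params, [], group) would then
#     # return one AllReduce per parameter and param_fusion=True.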

@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allreduce on gradient.

    Args:
        degree (Tensor): The mean coefficient, equal to 1.0 / the device number.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradient.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is True, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad

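# Arithmetic sketch (added comment): with 8 devices and mean=True, `degree` holds
# Tensor(1.0 / 8, mstype.float32). AllReduce sums the per-device gradients, so a
# scalar gradient of 2.0 on every device becomes 16.0 after the reduce, and the
# multiplication by degree restores the per-device average:
#
#     averaged = F.tensor_mul(summed, F.cast(degree, F.dtype(summed)))  # 16.0 * 0.125 = 2.0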
@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor")
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
    """
    Apply the mean coefficient on gradient in PyNative mode, where allreduce has already been applied.

    Args:
        degree (Tensor): The mean coefficient, equal to 1.0 / the device number.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradient.
        allreduce_filter (bool): When it is True, the mean scaling would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allreduce on gradient.

    Args:
        degree (Tensor): The mean coefficient, equal to 1.0 / the device number.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradient.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is True, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.
        ps_parameter (bool): Whether the parameter is hosted on a parameter server.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad

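# Note (added comment): when ps_parameter is True the gradient is returned
# untouched; the push to the parameter server takes the place of the AllReduce,
# so no collective communication is issued for that parameter here.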
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (Tensor): The mean coefficient, equal to 1.0 / the device number.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradient.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is True, allgather would apply.
        grad (RowTensor): The sparse gradient (indices, values and dense_shape) before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensor(indices, dout, grad.dense_shape)
    return grad

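# Sketch (added comment): a RowTensor gradient only carries the rows that were
# actually touched, and each device touches different rows, so the values cannot
# simply be summed element-wise. AllGather concatenates them along axis 0
# instead: with 4 devices each holding indices of shape (16,) and values of
# shape (16, 80), the gathered RowTensor has indices of shape (64,) and values
# of shape (64, 80), with the same dense_shape; the mean coefficient is then
# applied to the gathered values.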
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (Tensor): The mean coefficient, equal to 1.0 / the device number.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradient.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is True, allgather would apply.
        grad (RowTensor): The sparse gradient (indices, values and dense_shape) before operation.
        ps_parameter (bool): Whether the parameter is hosted on a parameter server.

    Returns:
        RowTensor, the gradient after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensor(indices, dout, grad.dense_shape)
    return grad


_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (Tensor): The gradient tensor before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad)


@_get_datatype.register("RowTensor")
def _tensors_get_datatype_with_sparse(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (RowTensor): The gradient before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad.values)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): the destination datatype of gradient.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    return F.cast(grad, datatype)


@_cast_datatype.register("TypeType", "RowTensor")
def _tensors_cast_datatype_with_sparse(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): the destination datatype of gradient.
        grad (RowTensor): The gradient before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    dout = F.cast(grad.values, datatype)
    return RowTensor(grad.indices, dout, grad.dense_shape)

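# Note (added comment): the cast helpers above only touch the numeric payload.
# For a RowTensor gradient only `values` is cast; the integer `indices` and the
# `dense_shape` are carried over unchanged, so casting a float16 RowTensor to
# float32 still leaves its indices as an integer tensor.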

class DistributedGradReducer(Cell):
    """
    A distributed gradient reducer.

    Constructs a gradient reducer Cell, which applies communication and average operations on
    single-process gradient values.

    Args:
        parameters (list): the parameters to be updated.
        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradients. Default: True.
        degree (int): The mean coefficient. Usually it is equal to the device number. Default: None.
        fusion_type (int): The type of all reduce fusion. Default: 1.
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.

    Raises:
        ValueError: If degree is not an int or is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import ops
        >>> from mindspore import context
        >>> from mindspore.context import ParallelMode
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore import nn
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> context.reset_auto_parallel_context()
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
        >>>
        >>> class TrainingWrapper(nn.Cell):
        ...     def __init__(self, network, optimizer, sens=1.0):
        ...         super(TrainingWrapper, self).__init__(auto_prefix=False)
        ...         self.network = network
        ...         self.network.add_flags(defer_inline=True)
        ...         self.weights = optimizer.parameters
        ...         self.optimizer = optimizer
        ...         self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        ...         self.sens = sens
        ...         self.reducer_flag = False
        ...         self.grad_reducer = None
        ...         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        ...         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
        ...             self.reducer_flag = True
        ...         if self.reducer_flag:
        ...             mean = context.get_auto_parallel_context("gradients_mean")
        ...             degree = context.get_auto_parallel_context("device_num")
        ...             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        ...
        ...     def construct(self, *args):
        ...         weights = self.weights
        ...         loss = self.network(*args)
        ...         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        ...         grads = self.grad(self.network, weights)(*args, sens)
        ...         if self.reducer_flag:
        ...             # apply grad reducer on grads
        ...             grads = self.grad_reducer(grads)
        ...         return ops.depend(loss, self.optimizer(grads))
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>>
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> net_with_loss = nn.WithLossCell(network, loss)
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> train_cell = TrainingWrapper(net_with_loss, optimizer)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> output = train_cell(inputs, label)
        >>> print(output)
        256.0
    """

    def __init__(self, parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("Parameter 'degree' in DistributedGradReducer should be a positive int.")
            self.degree = degree
        self.degree = Tensor(1.0 / self.degree, mstype.float32)
        self.mean = mean
        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
        if is_parallel_optimizer and split_indices:
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters), split_indices, group)
        else:
            self.split_fusion = True
            self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices, group,
                                                                                 fusion_type)
            if not param_fusion:
                self.split_fusion = False
                self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
        self.allgather = AllGather(group)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
        self.enable_parameter_server = any(self.ps_parameters)
        self.mode = context.get_context("mode")

    def construct(self, grads):
        """
        Under certain circumstances, the gradients could be a mixture of float16 and float32 precision, which makes
        the result of AllReduce unreliable. To solve the problem, grads must be cast to float32 before AllReduce,
        and cast back after the operation.

        Args:
            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.

        Returns:
            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
        """
        datatypes = self.map_(F.partial(_get_datatype), grads)
        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
        if self.mode == context.PYNATIVE_MODE:
            new_grad = grads
        elif self.split_fusion:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads)
        else:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads)
        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
        return new_grad

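# Dispatch summary (added comment): construct() records each gradient's dtype,
# casts every gradient to float32 so mixed float16/float32 inputs reduce
# consistently, runs reduce_opt (per-parameter fused AllReduce ops when
# split_fusion is True, a single fused AllReduce otherwise, and a plain
# pass-through in PyNative mode), then casts each gradient back to its recorded
# dtype. Minimal sketch of the float32 round-trip on one tensor (assumes a
# float16 gradient `g` and an initialized communication group):
#
#     dtype = F.dtype(g)                                     # mstype.float16
#     g32 = F.cast(g, mstype.float32)                        # reduce in float32
#     g32 = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)(g32)
#     g = F.cast(g32, dtype)                                 # back to float16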