# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType

from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode, _get_enable_parallel_optimizer)
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire parameter datatype.

    Args:
        param (Tensor): The parameter before operation.

    Returns:
        mstype, the datatype of parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")

@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the given datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before operation.

    Returns:
        Tensor, the parameter after operation.
    """
    return F.cast(param, datatype)

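# A minimal sketch of how the two MultitypeFuncGraphs above are meant to be used: they are
# mapped over a tuple of tensors (e.g. with C.Map or ops.HyperMap) rather than called directly.
# The `_CastRoundTrip` cell below is illustrative only and is not referenced anywhere else in
# this file; it records each tensor's dtype, casts everything to float32 and then restores the
# original dtypes, mirroring the pattern used by _BroadCastCell further down.
class _CastRoundTrip(Cell):
    """Illustrative cell: cast a tuple of parameters to float32 and back."""

    def __init__(self, params):
        super(_CastRoundTrip, self).__init__()
        self.map_ = C.Map()
        self.params = tuple(params)

    def construct(self):
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        casted = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        restored = self.map_(F.partial(_cast_datatype), datatypes, casted)
        return restored

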
class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with a loss function. This Cell accepts data and label as inputs and
    returns the computed loss.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor representing the loss value; its shape is usually :math:`()`.

    Raises:
        TypeError: If dtype of `data` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone


class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with a backward cell to compute gradients. A network together with a loss
    function is necessary as argument. If `loss_fn` is None, the network must already wrap a network
    and a loss function. This Cell accepts '\*inputs' as inputs and returns gradients for each
    trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity (gradient with respect to output)
            for backpropagation; its type and shape must be the same as the `network` output. If None,
            a tensor of ones with the same type and shape as the output value is used. Default: None.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with the same shapes as the trainable weights.

    Raises:
        TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()

    def construct(self, *inputs):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(*inputs)
        else:
            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
        return grads


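# A minimal end-to-end sketch for WithGradCell (run in PyNative mode, per the Note above).
# The Dense layer and the synthetic inputs below are illustrative assumptions, not part of
# the library API; the helper is not called anywhere in this file.
def _example_with_grad_cell():
    """Illustrative only: compute gradients of a small Dense layer."""
    import numpy as np
    from mindspore import Tensor
    net = nn.Dense(3, 4)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    grad_net = WithGradCell(net, loss_fn)
    data = Tensor(np.ones([2, 3]).astype(np.float32))
    label = Tensor(np.ones([2, 4]).astype(np.float32))
    # Returns one gradient tensor per trainable parameter of `net`.
    return grad_net(data, label)

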
class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculate the gradients.

    Args:
        network (Cell): The training network.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
            Default: None.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            If sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred
            through the input parameter. Default: False.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
        - **(sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation.
          If the network has a single output, the sens is a tensor.
          If the network has multiple outputs, the sens is a tuple(tensor).

    Outputs:
        - **forward value** - The result of network forward running.
        - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Net(nn.Cell):
        ...    def __init__(self):
        ...        super(Net, self).__init__()
        ...        self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight")
        ...        self.matmul = P.MatMul()
        ...
        ...    def construct(self, x):
        ...        out = self.matmul(x, self.weight)
        ...        return out
        ...
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_with_criterion = nn.WithLossCell(net, criterion)
        >>> weight = ParameterTuple(net.trainable_params())
        >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True)
        >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> labels = Tensor(np.zeros([1, 2]).astype(np.float32))
        >>> result = train_network(inputs, labels)
        >>> print(result)
        (Tensor(shape=[1], dtype=Float32, value=[0.00000000e+00]), ((Tensor(shape=[1, 2], dtype=Float32, value=
        [[1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[1, 2], dtype=Float32, value=
        [[0.00000000e+00, 0.00000000e+00]])), (Tensor(shape=[2, 2], dtype=Float32, value=
        [[5.00000000e-01, 5.00000000e-01],
         [5.00000000e-01, 5.00000000e-01]]),)))
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"The type of training network should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if not isinstance(get_all, bool):
            raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        self.network = network
        if isinstance(network, Cell):
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        grad_inputs = inputs
        if self.sens_param:
            # The last input is the sensitivity; drop it for the forward pass.
            inputs = inputs[:-1]
        loss = self.network(*inputs)
        if self.get_by_list:
            grads = self.grad(self.network, self.weights)(*grad_inputs)
        else:
            grads = self.grad(self.network)(*grad_inputs)
        return loss, grads


class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameters. Different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor representing the loss value; its shape is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> # 1) Using the provided WithLossCell
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, y, label):
        ...        out = self._backbone(x, y)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        return loss


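# A minimal sketch of one optimization step with TrainOneStepCell. The Dense layer, loss,
# optimizer and synthetic inputs below are illustrative assumptions; in practice the cell is
# usually driven by a dataset loop or by Model.train. The helper is not called in this file.
def _example_train_one_step():
    """Illustrative only: run a single forward/backward/update step."""
    import numpy as np
    from mindspore import Tensor
    net = nn.Dense(3, 4)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(WithLossCell(net, loss_fn), optim)
    train_net.set_train()
    data = Tensor(np.ones([2, 3]).astype(np.float32))
    label = Tensor(np.ones([2, 4]).astype(np.float32))
    # One call performs the forward pass, backpropagation and parameter update.
    return train_net(data, label)

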
class GetNextSingleOp(Cell):
    """
    Cell that runs the GetNext operation to fetch data from the dataset queue.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of dataset.
        dataset_shapes (list[tuple[int]]): The shapes of dataset.
        queue_name (str): Queue name to fetch the data.

    For detailed information, refer to `ops.operations.GetNext`.

    Inputs:
        No inputs.

    Outputs:
        tuple[Tensor], the data fetched from the dataset.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> train_dataset = create_custom_dataset()
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> dataset = dataset_helper.iter.dataset
        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
        >>> queue_name = dataset.__transfer_dataset__.queue_name
        >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
        >>> data, label = get_next_single_op_net()
        >>> relu = P.ReLU()
        >>> result = relu(data).asnumpy()
        >>> print(result.shape)
        (32, 1, 32, 32)
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()


class _VirtualDatasetCell(Cell):
    """
    Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph. Inputs and
    outputs of _VirtualDataset are distributed in the data parallel pattern, and tensor redistribution
    Primitives are inserted dynamically during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, *inputs):
        output = self._virtual_dataset(*inputs)
        return self._backbone(*output)


class _MicroBatch(Cell):
    """
    Slice a mini-batch into micro-batches for pipeline parallel.

    Args:
       micro_size (int): The number of micro-batches.
    """
    def __init__(self, micro_size):
        super(_MicroBatch, self).__init__()
        self.shape = P.Shape()
        self.micro_size = micro_size

    def construct(self, i, *inputs):
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            micro_batch_begin = i * input_shape[0] // self.micro_size
            micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
            micro_input = each_input[micro_batch_begin:micro_batch_end]
            micro_inputs += (micro_input,)
        return micro_inputs


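# Worked sketch of the slicing arithmetic used by _MicroBatch above (plain Python, no MindSpore
# needed): with a global batch of 8 samples and micro_size=4, micro-batch i covers rows
# [i * 8 // 4, (i + 1) * 8 // 4), i.e. [0:2], [2:4], [4:6] and [6:8]. The helper is illustrative
# only and unused elsewhere in this file.
def _example_micro_batch_bounds(batch_size=8, micro_size=4):
    """Illustrative only: the row range each micro-batch would receive."""
    return [(i * batch_size // micro_size, (i + 1) * batch_size // micro_size)
            for i in range(micro_size)]

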
class PipelineCell(Cell):
    """
    Wrap the network with micro-batch slicing.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.

    Examples:
        >>> net = Net()
        >>> net = PipelineCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret


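# Hedged wiring sketch for pipeline parallel training (illustrative only): the number of
# micro-batches handed to PipelineCell must be at least the configured number of pipeline
# stages. The stage count of 2 and micro_size of 4 below are arbitrary example values, and the
# helper is not called anywhere in this file.
def _example_pipeline_cell(net_with_loss):
    """Illustrative only: wrap a loss-wrapped network for a 2-stage pipeline."""
    from mindspore import context
    context.set_auto_parallel_context(pipeline_stages=2)
    return PipelineCell(net_with_loss, micro_size=4)

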
def _pipeline_clear_grad(accu_grad, grad):
    """Zero the accumulated gradient buffer once the corresponding gradient has been consumed."""
    accu_grad = F.depend(accu_grad, grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    return F.assign(accu_grad, zeros)


class _TrainPipelineAccuStepCell(TrainOneStepCell):
    """
    Wraps the network with an optimizer in pipeline mode.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.opt_shard:
            # With optimizer parallel enabled, the optimizer consumes the freshly computed gradients.
            succ = self.optimizer(grads)
        else:
            # Otherwise the optimizer consumes the accumulated gradients.
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        # Zero the accumulation buffers after the update.
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss


class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.

    The underlying _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph.
    Inputs and outputs of VirtualDatasetCellTriple are distributed in the data parallel pattern, and tensor
    redistribution Primitives are inserted dynamically during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode. There are exactly three inputs, as opposed
        to the variable number of inputs accepted by _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        a_, b_, c_ = self._virtual_dataset(a, b, c)
        return self._backbone(a_, b_, c_)


class WithEvalCell(Cell):
    r"""
    Cell that returns loss, output and label for evaluation.

    This Cell accepts a network and loss function as arguments and computes the loss for the model.
    It returns loss, output and label to calculate the metrics.

    Args:
        network (Cell): The network Cell.
        loss_fn (Cell): The loss Cell.
        add_cast_fp32 (bool): Whether to cast the network output and label to float32 before computing
            the loss. Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
        and a label Tensor of shape :math:`(N, \ldots)`.

    Raises:
        TypeError: If `add_cast_fp32` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            label = F.mixed_precision_cast(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label


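# A minimal evaluation sketch for WithEvalCell. The Dense layer and synthetic data below are
# illustrative assumptions; in practice the returned (loss, outputs, label) tuple feeds metric
# objects such as nn.Accuracy. The helper is not called anywhere in this file.
def _example_with_eval_cell():
    """Illustrative only: evaluate a small Dense layer on synthetic data."""
    import numpy as np
    from mindspore import Tensor
    net = nn.Dense(3, 4)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    eval_net = WithEvalCell(net, loss_fn)
    eval_net.set_train(False)
    data = Tensor(np.ones([2, 3]).astype(np.float32))
    label = Tensor(np.ones([2, 4]).astype(np.float32))
    return eval_net(data, label)  # (loss, outputs, label)

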
class ParameterUpdate(Cell):
    """
    Cell that updates a parameter.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Inputs:
        - **x** (Tensor) - A tensor whose shape and type are the same as `param`.

    Outputs:
        Tensor, the input `x`.

    Raises:
        TypeError: If `param` is not a Parameter.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> network = nn.Dense(3, 4)
        >>> param = network.parameters_dict()['weight']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
        >>> output = update(weight)
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("`param` must be `Parameter`, but got {}".format(param))
        self._param = param

    def construct(self, x):
        F.assign(self._param, x)
        return x


class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to other devices.

    Args:
       params (list): The parameters of Net.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        from mindspore.communication.management import get_group_size, create_group
        from mindspore import context
        self.map_ = C.Map()
        self.params = tuple(params)
        if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE:
            rank_list = [id for id in range(0, get_group_size())]
            create_group("BroadcastWorldGroup", rank_list)
            self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup")
        else:
            self.broadcast = P.Broadcast(0)

    def construct(self):
        # Record the original dtypes, broadcast in float32, then cast back.
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params