# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
16"""Cell_wrapper."""
17from __future__ import absolute_import
18from __future__ import division
19
20import os
21from types import FunctionType, MethodType
22
23from mindspore import log as logger
24from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
25    _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
26from mindspore.context import ParallelMode, GRAPH_MODE, get_context
27from mindspore import _checkparam as validator
28from mindspore import ops, nn
29from mindspore.common import dtype as mstype
30from mindspore.common.parameter import Parameter, ParameterTuple
31from mindspore.common.tensor import Tensor
32from mindspore.ops.primitive import _primexpr
33from mindspore.ops import composite as C
34from mindspore.ops import functional as F
35from mindspore.ops import operations as P
36from mindspore.ops.operations.comm_ops import _VirtualDataset
37from mindspore.nn.cell import Cell
38from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
39
40_get_datatype = C.MultitypeFuncGraph("_get_datatype")
41
42
43@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire the parameter datatype.

    Args:
        param (Tensor): The parameter before operation.

    Returns:
        mstype, the datatype of the parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the destination datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before operation.

    Returns:
        Tensor, the parameter after operation.
    """
    return F.cast(param, datatype)


class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with a loss function. This Cell accepts data and label as inputs and
    returns the computed loss.

    Args:
        backbone (Cell): The backbone network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, the loss value, whose shape is usually :math:`()`.

    Raises:
        TypeError: If dtype of `data` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor, nn
        >>> import numpy as np
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self._get_attr_from_cell(backbone)

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, the backbone network.

        Examples:
            >>> from mindspore import nn
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
            >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
            >>> backbone = net_with_criterion.backbone_network
        """
        return self._backbone


class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with a backward cell to compute gradients. A network with a loss function is required
    as the argument. If `loss_fn` is None, the network must already wrap both the network and the loss function.
    This Cell accepts '\*inputs' as inputs and returns gradients for each trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports a single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: ``None`` .
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity (gradient with respect to output) for
            backpropagation; its type and shape must be the same as the `network` output. If ``None`` , a tensor
            of ones with the same type and shape as the output value will be used. Default: ``None`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with identical shapes as trainable weights.

    Raises:
        TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Define a network without loss function, taking LeNet5 as an example.
        >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(*inputs)
        else:
            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
        return grads


class ForwardValueAndGrad(Cell):
    r"""
    Encapsulates the training network.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculate gradients.

    Args:
        network (Union[Cell, Function, MethodType]): The training network.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
            Default: ``None`` .
        get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
        get_by_list (bool): If ``True`` , get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both ``False`` , get the gradient with respect to first input.
            If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and Parameter
            variables at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: ``False`` .
        sens_param (bool): Whether to append the sensitivity (gradient with respect to output) as an input.
            If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically.
            If sens_param is ``True`` , the sensitivity (gradient with respect to output) needs to be passed in
            through the input parameters. Default: ``False`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
        - **sens** - A sensitivity (gradient with respect to output) as the input of backpropagation.
          If the network has a single output, `sens` is a tensor.
          If the network has multiple outputs, `sens` is a tuple of tensors.

    Outputs:
        - **forward value** - The result of the network forward running.
        - **gradients** (tuple(tensor)) - The gradients of the network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, nn, ops, ParameterTuple, Parameter
        >>>
        >>> class Net(nn.Cell):
        ...    def __init__(self):
        ...        super(Net, self).__init__()
        ...        self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight")
        ...        self.matmul = ops.MatMul()
        ...
        ...    def construct(self, x):
        ...        out = self.matmul(x, self.weight)
        ...        return out
        ...
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_with_criterion = nn.WithLossCell(net, criterion)
        >>> weight = ParameterTuple(net.trainable_params())
        >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True)
        >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> labels = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> result = train_network(inputs, labels)
        >>> print(result)
         (Tensor(shape=[1], dtype=Float32, value= [ 1.38629436e+00]), ((Tensor(shape=[1, 2], dtype=Float32, value=
        [[ -1.00000000e+00,  -1.00000000e+00]]), Tensor(shape=[1, 2], dtype=Float32, value=
        [[ 0.00000000e+00,  0.00000000e+00]])), (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ -5.00000000e-01,  -5.00000000e-01],
         [ -5.00000000e-01,  -5.00000000e-01]]),)))
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the argument 'network' must be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_all' must be bool, but got '{type(get_all)}'")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'")
        if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"when 'get_by_list' is set to True, the argument 'weights' must be "
                            f"Parameters array, but got '{type(weights)}'")
        self.network = network
        if isinstance(network, Cell):
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
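        # When sens_param is True, the last element of `inputs` is the sensitivity; it is
        # forwarded to the grad function only, while the forward pass runs on the remaining inputs.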
        grad_inputs = inputs
        if self.sens_param:
            inputs = inputs[:-1]
        loss = self.network(*inputs)
        if self.get_by_list:
            grads = self.grad(self.network, self.weights)(*grad_inputs)
        else:
            grads = self.grad(self.network)(*grad_inputs)
        return loss, grads


class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the `network` with the `optimizer`. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameters. Different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports a single output.
        optimizer (Cell): Optimizer for updating the network parameters.
        sens (numbers.Number): The scaling number to be filled as the sensitivity input of backpropagation.
            Default value is ``None`` , which is treated as ``1.0`` .
        return_grad (bool): Whether to return the gradients. If ``True`` , the gradients are returned as a dict
            together with the loss. The keys of the dict are the parameter names corresponding to the gradients,
            and the values are the gradient values. Default value is ``False`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, the loss value, whose shape is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a numbers.Number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> #1) Using the WithLossCell provided by MindSpore
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> #2) Using user-defined WithLossCell
        >>> class MyWithLossCell(nn.Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, y, label):
        ...        out = self._backbone(x, y)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=None, return_grad=False):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.grad_no_sens = C.GradOperation(get_by_list=True)
        self.sens = sens
        if self.sens == 0:
            raise ValueError("The input argument of 'sens' can not be 0.")
        self.sense_flag = True
        if self.sens is None:
            self.sense_flag = False
            self.sens = 1.0
        self.return_grad = return_grad
        if return_grad:
            self.weights_name = [i.name for i in self.optimizer.parameters]
        self.reducer_flag = False
        self.grad_reducer = nn.Identity()
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
                            _is_pynative_parallel()
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            from mindspore.communication.management import GlobalComm
            group = GlobalComm.WORLD_COMM_GROUP
            if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
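                # For AdaSum optimizers, the devices are partitioned into server-sized groups
                # (8 devices each) and gradients are reduced only inside the group that
                # contains the current rank.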
                from mindspore.communication.management import get_group_size, create_group, get_rank
                group_number = get_group_size() // 8
                self.degree = int(self.degree / group_number)
                group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)]
                current_index = get_rank() // 8
                server_group_name = "allreduce_" + str(current_index)
                create_group(server_group_name, group_list[current_index])
                group = server_group_name
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
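        # Forward pass, then backward pass scaled by `sens`; gradients are reduced across
        # devices when a parallel mode requires it, and finally applied by the optimizer.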
        if not self.sense_flag:
            return self._no_sens_impl(*inputs)
        loss = self.network(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        if self.return_grad:
            grad_with_param_name = {}
            for index, value in enumerate(grads):
                grad_with_param_name[self.weights_name[index]] = value
            return loss, grad_with_param_name
        return loss

    def _no_sens_impl(self, *inputs):
        """construct implementation when the 'sens' parameter is not passed in."""
        loss = self.network(*inputs)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        if self.return_grad:
            grad_with_param_name = {}
            for index, value in enumerate(grads):
                grad_with_param_name[self.weights_name[index]] = value
            return loss, grad_with_param_name
        return loss


class GetNextSingleOp(Cell):
    """
    Cell that runs the GetNext operation to fetch the next data.

    For detailed information, refer to :class:`mindspore.ops.GetNext`.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of the dataset.
        dataset_shapes (list[tuple[int]]): The shapes of the dataset.
        queue_name (str): Queue name to fetch the data.

    Outputs:
        tuple[Tensor], the data fetched from the Dataset.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import ops, nn
        >>> from mindspore import dataset as ds
        >>> from mindspore import dtype as mstype
        >>>
        >>> data_path = "/path/to/MNIST_Data/train/"
        >>> train_dataset = ds.MnistDataset(data_path, num_samples=10)
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> dataset = dataset_helper.iter.dataset
        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
        >>> queue_name = dataset.__transfer_dataset__.queue_name
        >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
        >>> data, label = get_next_single_op_net()
        >>> relu = ops.ReLU()
        >>> result = relu(data.astype(mstype.float32))
        >>> print(result.shape)
        (28, 28, 1)
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()


class _VirtualDatasetCell(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph. Inputs and outputs
    of _VirtualDataset are distributed in a data parallel pattern, and tensor redistribution Primitives are
    inserted dynamically during graph compilation.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()
        self._get_attr_from_cell(backbone)

    def construct(self, *inputs):
        output = self._virtual_dataset(*inputs)
        return self._backbone(*output)


@_primexpr
def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
    if not F.isconstant(input_shape[0]):
        return
    if input_shape[0] % micro_size != 0:
        raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
                         f"divisible by the micro size({micro_size})")


class _MicroBatch(Cell):
    """
    Transform a mini-batch into micro-batches for pipeline parallel training.

    Args:
        micro_size (int): The number of micro-batches.
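
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring); it assumes
        >>> # PyNative mode and a mini-batch whose 0th dimension is divisible by micro_size.
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> micro = _MicroBatch(2)
        >>> x = Tensor(np.arange(8).reshape(4, 2).astype(np.float32))
        >>> (first,) = micro(0, x)   # rows 0..1 of the mini-batch
        >>> (second,) = micro(1, x)  # rows 2..3 of the mini-batch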
    """
    def __init__(self, micro_size):
        super(_MicroBatch, self).__init__()
        self.shape = P.Shape()
        self.micro_size = micro_size
        self.strided_slice = P.StridedSlice()

    def construct(self, i, *inputs):
        """construct for _MicroBatch."""
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
            micro_batch_begin = (input_shape[0] // self.micro_size) * i
            micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1)
            strided_slice_begin = (micro_batch_begin,)
            strided_slice_strides = (1,)
            for _ in range(len(input_shape) - 1):
                strided_slice_begin += (0,)
                strided_slice_strides += (1,)
            strided_slice_end = (micro_batch_end,)
            strided_slice_end += input_shape[1:]
            micro_input = self.strided_slice(each_input, strided_slice_begin, strided_slice_end, strided_slice_strides)
            micro_inputs += (micro_input,)
        return micro_inputs


class MicroBatchInterleaved(Cell):
    """
    Splits the input along the 0th dimension into `interleave_num` pieces and then performs
    the computation of the wrapped cell. Application scenario: when there is model parallelism in the network in
    semi-automatic mode, while the first slice of data is computing the forward pass, the second slice of data can
    execute the communication operators at the same time, so that communication and computation overlap to
    accelerate performance.

    Note:
        The output of the input network must be a single tensor.

    Args:
        network (Cell): The target network to wrap.
        interleave_num (int, optional): The number of pieces the batch size is split into. Default: ``2`` .

    Inputs:
        tuple[Tensor]. It's the same as the inputs of the `network` .

    Outputs:
        Tensor. The output of the input `network` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.MicroBatchInterleaved(net, 2)
    """
    def __init__(self, network, interleave_num=2):
        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
        if not isinstance(interleave_num, int):
            raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be an integer, "
                            "but got the type : {}.".format(type(interleave_num)))
        if interleave_num <= 0:
            raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be greater than 0, "
                             "but got {}.".format(interleave_num))
        self.network = network
        self.interleave_num = interleave_num
        self.interleave_inputs = nn.CellList()
        self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True)
        for _ in range(interleave_num):
            interleave_data = _MicroBatch(interleave_num)
            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
            interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
            self.interleave_inputs.append(interleave_data)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
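        # Each interleaved slice runs through the network separately and the partial outputs
        # are summed, which is why the wrapped network must return a single tensor.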
        output = 0.0
        for i in range(self.interleave_num):
            interleave_input = self.interleave_inputs[i](i, *inputs)
            output = self.add(output, self.network(*interleave_input))
        return output


class PipelineCell(Cell):
    """
    Slice the MiniBatch into finer-grained MicroBatches for use in pipeline-parallel training.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): The number of MicroBatches into which each MiniBatch is sliced.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.PipelineCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'PipelineCell', the argument 'network' must be cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'PipelineCell', the argument 'micro_size' must be an integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'PipelineCell', the argument 'micro_size' must be greater than 0, "
                             "but got {}.".format(micro_size))
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
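        # Run the network once per micro-batch and accumulate the partial outputs with the
        # per-micro-batch Add operators tagged "pipeline_end".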
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret


class GradAccumulationCell(Cell):
    """
    Wrap the network with MicroBatch slicing to enable gradient accumulation in semi_auto_parallel/auto_parallel mode.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): The number of MicroBatches into which each MiniBatch is sliced.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.GradAccumulationCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(GradAccumulationCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'GradAccumulationCell', the argument 'network' must be cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be an integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be greater than 0, "
                             "but got {}.".format(micro_size))
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("forward_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret


def _pipeline_clear_grad(accu_grad, grad):
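    """Zero the accumulated gradient buffer once the corresponding gradient has been consumed."""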
    accu_grad = F.depend(accu_grad, grad)
    zeros = F.zeros_like(accu_grad)
    return F.assign(accu_grad, zeros)


class _TrainGradAccuStepCell(TrainOneStepCell):
    """
    Wraps the network with an optimizer in pipeline mode.
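
    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the network parameters.
        sens (numbers.Number): The scaling number to be filled as the sensitivity input of backpropagation.
            Default: ``None`` .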
    """
    def __init__(self, network, optimizer, sens=None):
        super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens)
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()
        self._get_attr_from_cell(network)
        self.enable_mindio = False
        mode = get_context("mode")
        device_type = get_context("device_target")
        if device_type != "Ascend" or mode != GRAPH_MODE:
            return
        graceful_exit = os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT")
        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
        ttp_path_check = ttp_lib_path is not None and os.path.isfile(ttp_lib_path)
        if graceful_exit == "true" and ttp_path_check:
            self.g_one = Tensor([0.1])
            self.allreduce_sum = ops.AllReduce()
            self.enable_mindio = True

    def construct(self, *inputs):
        if not self.sense_flag:
            return self._no_sens_impl(*inputs)
        loss = self.network(*inputs)
        sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.enable_mindio:
            g_one = ops.depend(self.g_one, accu_grads)
            g_one_res = self.allreduce_sum(g_one)
            accu_grads = ops.depend(accu_grads, g_one_res)
            grads = ops.depend(grads, g_one_res)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss

    def _no_sens_impl(self, *inputs):
        """construct implementation when the 'sens' parameter is not passed in."""
        loss = self.network(*inputs)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.enable_mindio:
            g_one = ops.depend(self.g_one, accu_grads)
            g_one_res = self.allreduce_sum(g_one)
            accu_grads = ops.depend(accu_grads, g_one_res)
            grads = ops.depend(grads, g_one_res)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss


class AllreduceGraph(Cell):
    """
    An allreduce graph to broadcast parameters.
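
    Args:
        inputs (list): The parameters to be reduced across the group and written back in place.
        group_name (str): Name of the communication group used by the AllReduce operators.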
    """
    def __init__(self, inputs, group_name):
        super(AllreduceGraph, self).__init__()
        self.input_num = len(inputs)
        self.inputs = inputs
        self.allreduces = []
        self.assigns = []
        for _ in range(self.input_num):
            self.allreduces.append(ops.AllReduce(op="sum", group=group_name))
            self.assigns.append(ops.Assign())

    def construct(self):
        for i in range(self.input_num):
            res = self.allreduces[i](self.inputs[i])
            self.assigns[i](self.inputs[i], res)
        return self.inputs


class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the model parallel layout.

    VirtualDatasetCellTriple is a virtual Primitive; it does not exist in the final executing graph. Inputs and
    outputs of VirtualDatasetCellTriple are distributed in a data parallel pattern, and tensor redistribution
    Primitives are inserted dynamically during graph compilation.

    Note:
        Only used in semi auto parallel and auto parallel mode. There are three inputs, in contrast to the two
        inputs in _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
        self._backbone = backbone
        self._get_attr_from_cell(backbone)

    def construct(self, a, b, c):
        return self._backbone(a, b, c)


class WithEvalCell(Cell):
    r"""
    Wraps the forward network with the loss function.

    It returns loss, forward output and label to calculate the metrics.

    Args:
        network (Cell): The forward network.
        loss_fn (Cell): The loss function.
        add_cast_fp32 (bool): Whether to adjust the data type to float32. Default: ``False`` .

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
        and a label Tensor of shape :math:`(N, \ldots)`.

    Raises:
        TypeError: If `add_cast_fp32` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define a forward network without loss function, taking LeNet5 as an example.
        >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
        self._get_attr_from_cell(network)

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            label = F.mixed_precision_cast(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label


class ParameterUpdate(Cell):
    """
    Cell that updates a parameter.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Inputs:
        - **x** (Tensor) - A tensor whose shape and type are the same as `param`.

    Outputs:
        Tensor, the updated value.

    Raises:
        KeyError: If parameter with the specified name does not exist.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import nn, Tensor
        >>> network = nn.Dense(3, 4)
        >>> param = network.parameters_dict()['weight']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
        >>> output = update(weight)
        >>> print(output)
        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(type(param)))
        self._param = param

    def construct(self, x):
        F.assign(self._param, x)
        return x


class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to other devices.

    Args:
        params (list): The parameters of the network.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        from mindspore.communication.management import get_group_size, create_group
        from mindspore import context
        self.map_ = C.Map()
        self.params = tuple(params)
        if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE:
            rank_list = list(range(get_group_size()))
            create_group("BroadcastWorldGroup", rank_list)
            self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup")
        else:
            self.broadcast = P.Broadcast(0)
        self.add_flags(skip_auto_parallel_compile=True)

    def construct(self):
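        # Cast every parameter to float32, broadcast from rank 0, then cast each result back
        # to its original dtype so parameters of any dtype can share the same Broadcast op.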
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params