• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""cell"""
16import gc
17import inspect
18import os
19import time
20from collections import OrderedDict
21
22import numpy
23
24from mindspore._checkparam import args_type_check
25from mindspore import log as logger
26from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
27from mindspore.common._decorator import deprecated
28from mindspore.context import ParallelMode
29from .. import context
30from .._c_expression import init_pipeline, Cell_, FuncGraph
31from .._checkparam import Validator
32from ..common import dtype as mstype
33from ..common.api import _cell_graph_executor, _pynative_executor
34from ..common.parameter import Parameter, ParameterTuple
35from ..common.tensor import Tensor
36from ..ops.operations import HookBackward, Cast
37from ..ops.primitive import Primitive
38from ..parallel._tensor import _load_tensor_by_layout
39
40
41class Cell(Cell_):
42    """
43    Base class for all neural networks.
44
45    A 'Cell' could be a single neural network cell, such as conv2d, relu, batch_norm, etc. or a composition of
46    cells to constructing a network.
47
48    Note:
49        In general, the autograd algorithm will automatically generate the implementation of the gradient function,
50        but if back-propagation(bprop) method is implemented, the gradient function will be replaced by the bprop.
51        The bprop implementation will receive a tensor `dout` containing the gradient of the loss w.r.t.
52        the output, and a tensor `out` containing the forward result. The bprop needs to compute the
53        gradient of the loss w.r.t. the inputs, gradient of the loss w.r.t. Parameter variables are not supported
54        currently. The bprop method must contain the self parameter.
55
56    Args:
57        auto_prefix (bool): Recursively generate namespaces. Default: True.
58        flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
59                      Users can also customize network attributes by this parameter. Default: None.
60
61    Supported Platforms:
62        ``Ascend`` ``GPU`` ``CPU``
63
64    Examples:
65        >>> import mindspore.nn as nn
66        >>> import mindspore.ops as ops
67        >>> class MyCell(nn.Cell):
68        ...    def __init__(self):
69        ...        super(MyCell, self).__init__()
70        ...        self.relu = ops.ReLU()
71        ...
72        ...    def construct(self, x):
73        ...        return self.relu(x)
74    """
75    IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
76                   '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
77                   '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
78                   '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
79                   '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
80                   '_auto_parallel_compile_and_run', 'cell_type']
81
82    def __init__(self, auto_prefix=True, flags=None):
83        Cell_.__init__(self, self._cell_tag)
84        self._params = OrderedDict()
85        self._cells = OrderedDict()
86        self._params_list = OrderedDict()
87        self._tensor_list = OrderedDict()
88        self._primitives = OrderedDict()
89        self.training = False
90        self.requires_grad = False
91        self.pynative = False
92        self._attr_synced = False
93        self._param_prefix = ''
94        self._auto_prefix = auto_prefix
95        self._scope = None
96        self._phase = 'train'
97        self._parameter_layout_dict = {}
98        self._parallel_parameter_name_list = ()
99        self._parallel_parameter_merge_net_dict = {}
100        self._create_time = int(time.time() * 1e9)
101        self.arguments_key = ""
102        self.parameter_broadcast_done = False
103        init_pipeline()
104
105        # call gc to release GE session resources used by non-used cell objects
106        if os.getenv('GC_COLLECT_IN_CELL') == '1':
107            gc.collect()
108
109        self._construct_inputs_num = 0
110        self._construct_inputs_names = []
111        self._auto_parallel_mode = False
112        self._parallel_inputs_run = None
113        if flags:
114            self.add_flags(**flags)
115        self._backward_hook = None
116        self.enable_hook = False
117        self._bprop_debug = False
118        self.cell_type = None
119        self._auto_parallel_compile_and_run = False
120        self.cast = Cast()
121        self._has_config_recompute = False
122
123    def __getstate__(self):
124        base = Cell_.__getstate__(self)
125        return base, self.__dict__
126
127    def __setstate__(self, state):
128        base, dict_ = state
129        Cell_.__setstate__(self, base)
130        self.__dict__ = dict_
131        self._attr_synced = False
132
133    @property
134    def _cell_tag(self):
135        # `<class 'xxxxxxx'>` to `xxxxxxx`
136        return str(self.__class__)[8:-2]
137
138    @property
139    def create_time(self):
140        return self._create_time
141
142    @property
143    def cell_init_args(self):
144        return self._cell_init_args
145
146    @property
147    def param_prefix(self):
148        """
149        Param prefix is the prefix of current cell's direct child parameter.
150        """
151        return self._param_prefix
152
153    @property
154    def bprop_debug(self):
155        """
156        Get whether cell custom bprop debug is enabled.
157        """
158        return self._bprop_debug
159
160    @bprop_debug.setter
161    def bprop_debug(self, value):
162        """
163        Set whether to enable cell custom bprop debug.
164
165        Note:
166            When bprop is defined in cell, the bprop function will be executed
167            in python interpreter when bprop debug is true, and will be parsed
168            and add to graph when bprop debug is false.
169
170        Args:
171            value (bool): Specifies whether to enable bprop debug. Default: False.
172        """
173        if not isinstance(value, bool):
174            raise TypeError("The 'bprop debug' value must be a bool type.")
175        self._bprop_debug = value
176
177    def update_cell_prefix(self):
178        """
179        Update the all child cells' self.param_prefix.
180
181        After being invoked, it can get all the cell's children's name prefix by '_param_prefix'.
182        """
183        cells_name = self.cells_and_names()
184
185        for cell_name, cell in cells_name:
186            cell._param_prefix = cell_name
187
188    def update_cell_type(self, cell_type):
189        """
190        The current cell type is updated when a quantization aware training network is encountered.
191
192        After being invoked, it can set the cell type to 'cell_type'.
193        """
194        self.cell_type = cell_type
195
196    @cell_init_args.setter
197    def cell_init_args(self, value):
198        if not isinstance(value, str):
199            raise TypeError("The 'cell_init_args' must be a string type.")
200        self._cell_init_args = value
201
202    @property
203    def phase(self):
204        return self._phase
205
206    @phase.setter
207    def phase(self, value):
208        if not isinstance(value, str):
209            raise TypeError("The 'phase' must be a string type.")
210        self._phase = value
211
212    @property
213    def parameter_layout_dict(self):
214        """
215        `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
216        distributed operator information.
217        """
218        return self._parameter_layout_dict
219
220    @property
221    def cls_name(self):
222        return self.__class__.__name__
223
224    @parameter_layout_dict.setter
225    def parameter_layout_dict(self, value):
226        if not isinstance(value, dict):
227            raise TypeError("The 'parameter_layout_dict' must be a dict type.")
228        self._parameter_layout_dict = value
229
230    @property
231    def parallel_parameter_name_list(self):
232        return self._parallel_parameter_name_list
233
234    @parallel_parameter_name_list.setter
235    def parallel_parameter_name_list(self, value):
236        if not isinstance(value, list):
237            raise TypeError("The 'parallel_parameter_name_list' must be a list type.")
238        self._parallel_parameter_name_list = value
239
240    @property
241    def pipeline_stage(self):
242        return self._pipeline_stage
243
244    @pipeline_stage.setter
245    def pipeline_stage(self, value):
246        if isinstance(value, bool):
247            raise TypeError("'pipeline_stage' must be int type, but got bool.")
248        if not isinstance(value, int):
249            raise TypeError("'pipeline_stage' must be int type.")
250        if value < 0:
251            raise TypeError("'pipeline_stage' can not less than 0.")
252        self._pipeline_stage = value
253        for item in self.trainable_params():
254            item.add_pipeline_stage(value)
255
256    @property
257    def parallel_parameter_merge_net_dict(self):
258        return self._parallel_parameter_merge_net_dict
259
260    @parallel_parameter_merge_net_dict.setter
261    def parallel_parameter_merge_net_dict(self, value):
262        if not isinstance(value, dict):
263            raise TypeError("The 'parallel_parameter_merge_net_dict' must be a dict type.")
264        self._parallel_parameter_merge_net_dict = value
265
266    def get_func_graph_proto(self):
267        """Return graph binary proto."""
268        exec_id = self.phase + "." + str(self.create_time) + '.' + str(id(self))
269        return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)
270
271    def __getattr__(self, name):
272        if '_params' in self.__dict__:
273            params = self.__dict__['_params']
274            if name in params:
275                if context.get_context("mode") == context.PYNATIVE_MODE:
276                    return self.cast_param(params[name])
277                return params[name]
278        if '_cells' in self.__dict__:
279            cells = self.__dict__['_cells']
280            if name in cells:
281                return cells[name]
282        if '_tensor_list' in self.__dict__:
283            tensor_list = self.__dict__['_tensor_list']
284            if name in tensor_list:
285                return self.cast_param(tensor_list[name])
286        if '_params_list' in self.__dict__:
287            params_list = self.__dict__['_params_list']
288            if name in params_list:
289                para_list = params_list[name]
290                cast_list = list()
291                for para in para_list:
292                    cast_list.append(self.cast_param(para))
293                para_list = ParameterTuple(cast_list)
294                return para_list
295        raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
296
297    def __del__(self):
298        if context.get_context is not None and context.get_context("mode") == context.PYNATIVE_MODE:
299            _pynative_executor.del_cell(str(id(self)))
300        if hasattr(self, "_create_time"):
301            _cell_graph_executor.del_net_res(str(self._create_time))
302
303    def __delattr__(self, name):
304        if name in self._params:
305            del self._params[name]
306        elif name in self._cells:
307            del self._cells[name]
308        else:
309            if '_params_list' in self.__dict__ and name in self._params_list:
310                del self._params_list[name]
311            elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
312                del self._tensor_list[name]
313            object.__delattr__(self, name)
314        self._attr_synced = False
315
316    def _cast_mixed_precision_inputs(self, inputs, dst_type):
317        """Cast input for mixed precision"""
318        res = list()
319        for item in inputs:
320            if isinstance(item, tuple):
321                res.append(self._cast_mixed_precision_inputs(item, dst_type))
322            elif isinstance(item, float):
323                res.append(self.cast(item, dst_type))
324            elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
325                res.append(self.cast(item, dst_type))
326            else:
327                res.append(item)
328        return tuple(res)
329
330    def cast_inputs(self, inputs, dst_type):
331        """
332        Cast inputs to specified type.
333        """
334        res = list()
335        for item in inputs:
336            if isinstance(item, tuple):
337                res.append(self.cast_inputs(item, dst_type))
338            else:
339                res.append(self.cast(item, dst_type))
340        return tuple(res)
341
342    def _do_parameter_broadcast(self):
343        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
344            if not self.parameter_broadcast_done:
345                _pynative_executor.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
346                self.parameter_broadcast_done = True
347
348    def run_construct(self, cast_inputs, kwargs):
349        if self.enable_hook:
350            output = self._hook_construct(*cast_inputs)
351        else:
352            output = self.construct(*cast_inputs, **kwargs)
353        return output
354
355    def _check_construct_args(self, *inputs, **kwargs):
356        """Check the args needed by the function construct"""
357        if kwargs:
358            raise ValueError("For 'graph' mode, the outermost network does not support passing "
359                             "variable key-value pair parameters.")
360        positional_args = 0
361        default_args = 0
362        for value in inspect.signature(self.construct).parameters.values():
363            if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
364                return
365            if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
366                if value.default is inspect.Parameter.empty:
367                    positional_args += 1
368                else:
369                    default_args += 1
370
371        if len(inputs) < positional_args:
372            raise TypeError(
373                f"The function construct needs {positional_args} positional argument, but only provided {len(inputs)}.")
374
375        if len(inputs) > positional_args + default_args:
376            raise TypeError(
377                f"The function construct needs {positional_args} positional argument and {default_args} default "
378                f"argument, but provided {len(inputs)}")
379
380    class CellGuard:
381        def __enter__(self):
382            _pynative_executor.set_lazy_build(True)
383            _pynative_executor.enter_cell()
384
385        def __exit__(self, exc_type, exc_val, exc_tb):
386            _pynative_executor.exit_cell()
387            if _pynative_executor.is_top_cell():
388                _pynative_executor.set_lazy_build(False)
389
390    def __call__(self, *inputs, **kwargs):
391        if self.__class__.construct is Cell.construct:
392            logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
393                           f"will call the super class(Cell) 'construct'.")
394        if kwargs:
395            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
396            inputs = bound_args.args
397            kwargs = bound_args.kwargs
398
399        # Run in Graph mode.
400        if context.get_context("mode") == context.GRAPH_MODE:
401            self._check_construct_args(*inputs, **kwargs)
402            if self.enable_hook:
403                raise ValueError("The graph mode does not support hook function.")
404            out = self.compile_and_run(*inputs)
405            return out
406
407        # Run in PyNative mode.
408        if _pynative_executor.is_top_cell():
409            _pynative_executor.set_lazy_build(True)
410            # There many Casts in parameter_broadcast. Enable lazy_build and build faster.
411            self._do_parameter_broadcast()
412
413        for item in inputs:
414            if isinstance(item, numpy.ndarray):
415                raise TypeError("The cell inputs should not be numpy arrays.")
416        if self.requires_grad is True:
417            _pynative_executor.set_grad_flag(True)
418        _pynative_executor.new_graph(self, *inputs, **kwargs)
419        cast_inputs = list()
420        if hasattr(self, "_mindspore_flags"):
421            if self._mindspore_flags.get('fp16'):
422                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
423            if self._mindspore_flags.get('fp32'):
424                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
425        if not cast_inputs:
426            cast_inputs = inputs
427
428        with self.CellGuard():
429            try:
430                output = self.run_construct(cast_inputs, kwargs)
431            except Exception as err:
432                _pynative_executor.clear_res()
433                raise err
434
435        if _pynative_executor.is_top_cell():
436            _pynative_executor.execute_all_task()
437
438        if isinstance(output, Parameter):
439            output = output.data
440        _pynative_executor.end_graph(self, output, *inputs, **kwargs)
441        return output
442
443    def _add_attr(self, name, value):
444        if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
445            super(Cell, self)._add_attr(name, value)
446
447    def _sync_attr_for_compile(self):
448        """Sync the attr to c++ object."""
449        if self._attr_synced:
450            return
451        cells = self.__dict__.get('_cells')
452        for key in cells:
453            cell = cells[key]
454            cell._sync_attr_for_compile()
455            self._add_attr(key, cell)
456        params = self.__dict__.get('_params')
457        for key in params:
458            if '.' in key:
459                continue
460            param = params[key]
461            self._add_attr(key, param)
462        params_list = self.__dict__.get('_params_list')
463        for key in params_list:
464            params_list_item = params_list[key]
465            self._add_attr(key, params_list_item)
466        for key in self.__dict__:
467            value = self.__dict__[key]
468            self._add_attr(key, value)
469        self._attr_synced = True
470
471    def _set_attr_for_parameter(self, name, value):
472        """Set attr for parameter."""
473        cells = self.__dict__.get('_cells')
474        params = self.__dict__.get('_params')
475        if params is None:
476            raise AttributeError("Can not assign params before Cell.__init__() call.")
477        if name in self.__dict__:
478            if self.__dict__[name] is not None:
479                raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
480            del self.__dict__[name]
481        if cells and name in cells:
482            raise TypeError("The type of value should be Cell, but got Parameter.")
483        self.insert_param_to_cell(name, value)
484
485    def _set_attr_for_parameter_tuple(self, name, value):
486        """Set attr for parameter tuple."""
487        params = self.__dict__.get('_params')
488        params_list = self.__dict__.get('_params_list')
489        if params is None:
490            raise AttributeError("Can not assign params before Cell.__init__() call.")
491        for item in value:
492            self.insert_param_to_cell(item.name, item, check_name=False)
493        if context.get_context("mode") == context.PYNATIVE_MODE:
494            if name in self.__dict__:
495                del self.__dict__[name]
496            if name in params:
497                del params[name]
498            params_list[name] = value
499        else:
500            object.__setattr__(self, name, value)
501
502    def _set_attr_for_cell(self, name, value):
503        """Set attr for cell."""
504        cells = self.__dict__.get('_cells')
505        params = self.__dict__.get('_params')
506        if cells is None:
507            raise AttributeError("Can not assign cells before Cell.__init__() call.")
508        if name in self.__dict__:
509            del self.__dict__[name]
510        if params and name in params:
511            raise TypeError("The type of value should be Parameter, but got Cell.")
512        if self._auto_prefix:
513            value.update_parameters_name(name + '.')
514        cells[name] = value
515        if hasattr(self, '_cell_init_args'):
516            self.cell_init_args += str({name: value})
517
518    def __setattr__(self, name, value):
519        cells = self.__dict__.get('_cells')
520        params = self.__dict__.get('_params')
521        tensor_list = self.__dict__.get('_tensor_list')
522        if isinstance(value, Parameter):
523            self._set_attr_for_parameter(name, value)
524        elif isinstance(value, ParameterTuple):
525            self._set_attr_for_parameter_tuple(name, value)
526        elif isinstance(value, Cell):
527            self._set_attr_for_cell(name, value)
528        elif params and name in params:
529            if isinstance(value, Tensor) and self._params[name] is not None:
530                self._params[name].set_data(value)
531            elif value is not None:
532                raise TypeError(f"The type of value should be Parameter or ParameterTuple, "
533                                f"but got {type(value).__name__}.")
534            else:
535                self.insert_param_to_cell(name, None)
536        elif cells and name in cells:
537            if value is not None:
538                raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.")
539            self._cells[name] = None
540        elif isinstance(value, Tensor):
541            if context.get_context("mode") == context.PYNATIVE_MODE:
542                if name in self.__dict__:
543                    del self.__dict__[name]
544                tensor_list[name] = value
545            else:
546                object.__setattr__(self, name, value)
547        else:
548            if isinstance(value, Primitive):
549                value.set_prim_instance_name(name)
550                self._primitives[name] = value
551            object.__setattr__(self, name, value)
552        if name not in Cell.IGNORE_LIST:
553            self._attr_synced = False
554
555    def extend_repr(self):
556        """
557        Sets the extended representation of the Cell.
558
559        To print customized extended information, re-implement this method in your own cells.
560        """
561        return ''
562
563    def __str__(self):
564        return self.__repr__()
565
566    def __repr__(self):
567        extra_str = self.extend_repr()
568        info_str = self.__class__.__name__ + '<'
569        if self._cells:
570            sub_str = '\n'
571            if extra_str:
572                sub_str += '{}\n'.format(self.extend_repr())
573            for key, value in self._cells.items():
574                sub_str += '({}): {}\n'.format(key, repr(value))
575            sub_str = sub_str.replace('\n', '\n  ') + '>'
576            info_str += sub_str
577        else:
578            info_str += extra_str + '>'
579        return info_str
580
581    def load_parameter_slice(self, params):
582        """
583        Replace parameters with sliced tensors by parallel strategies.
584
585        Please refer to the usage in source code of `mindspore.common._CellGraphExecutor.compile`.
586
587        Args:
588            params (dict): The parameters dictionary used for initializing the data graph.
589        """
590        if params is None:
591            params = self.parameters_dict()
592        if isinstance(params, OrderedDict):
593            for key in params:
594                tensor = params[key].data
595                if key not in self.parameter_layout_dict:
596                    logger.info("The layout dict does not contain the key %s.", key)
597                    continue
598                if params[key].sliced:
599                    logger.debug("The param %s is already sliced.", key)
600                    continue
601                layout = self.parameter_layout_dict[key]
602                new_tensor = _load_tensor_by_layout(tensor, layout)
603                params[key].set_data(new_tensor, True)
604        else:
605            raise TypeError("Parameters need OrderedDict type, but got {}.".format(type(params)))
606
607    def _load_inputs(self, *inputs):
608        """
609        Slice inputs tensors by parallel strategies.
610
611        Args:
612            inputs (Function or Cell): inputs of construct method.
613        """
614        parallel_inputs_run = []
615        # judge if *args exists in input
616        if self.argspec[1] is not None:
617            prefix = self.argspec[1]
618            for i in range(len(inputs)):
619                key = prefix + str(i)
620                self._construct_inputs_names = self._construct_inputs_names + (key,)
621                self._construct_inputs_num = self._construct_inputs_num + 1
622        for i, tensor in enumerate(inputs):
623            key = self._construct_inputs_names[i]
624            # if input is not used, self.parameter_layout_dict may not contain the key
625            if key not in self.parameter_layout_dict:
626                logger.warning("Layout dict does not contain the key %s.", key)
627                parallel_inputs_run.append(tensor)
628            else:
629                layout = self.parameter_layout_dict[key]
630                new_tensor = _load_tensor_by_layout(tensor, layout)
631                parallel_inputs_run.append(new_tensor)
632        return tuple(parallel_inputs_run)
633
634    def set_parallel_input_with_inputs(self, *inputs):
635        """
636        Slice inputs tensors by parallel strategies, and set the sliced inputs to `_parallel_input_run`
637
638        Args:
639            inputs (tuple): inputs of construct method.
640        """
641        self._parallel_inputs_run = self._load_inputs(*inputs)
642
643    def _get_construct_inputs_number_and_name(self):
644        """Compute self._construct_inputs_names and self._construct_inputs_num"""
645        from mindspore._extends.parse.parser import get_parse_method_of_class
646
647        fn = get_parse_method_of_class(self)
648        self.argspec = inspect.getfullargspec(fn)
649        self._construct_inputs_num = fn.__code__.co_argcount
650        self._construct_inputs_names = fn.__code__.co_varnames
651
652        if self._construct_inputs_num <= 0:
653            raise ValueError(f"Num of inputs must be greater than 0, but got {self._construct_inputs_num}")
654        if self._construct_inputs_names[0] != 'self':
655            raise ValueError(f"First member of fn function must be self, but got {self._construct_inputs_names[0]}")
656        if self._construct_inputs_num - 1 > len(self._construct_inputs_names):
657            raise ValueError(f"Num of inputs must be greater than num of fn function members, num of inputs is \
658                {self._construct_inputs_names - 1}, num of fn function members is {len(self._construct_inputs_names)}")
659        self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
660        self._construct_inputs_num = self._construct_inputs_num - 1
661
662    def compile(self, *inputs):
663        """
664        Compiles cell.
665
666        Args:
667            inputs (tuple): Inputs of the Cell object.
668        """
669        _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
670
671    def compile_and_run(self, *inputs):
672        """
673        Compiles and runs cell.
674
675        Args:
676            inputs (tuple): Inputs of the Cell object.
677
678        Returns:
679            Object, the result of executing.
680        """
681        self._auto_parallel_compile_and_run = True
682        self.compile(*inputs)
683
684        new_inputs = []
685        for i in inputs:
686            if isinstance(i, Tensor):
687                new_inputs.append(i)
688            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
689                new_inputs.append(i)
690
691        if self._auto_parallel_mode:
692            if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag:
693                # get parallel inputs in sink mode, parallel inputs set in _cell_graph_executor.compile
694                parallel_inputs_run = self._parallel_inputs_run
695            else:
696                parallel_inputs_run = new_inputs
697            return _cell_graph_executor(self, *parallel_inputs_run, phase=self.phase)
698        return _cell_graph_executor(self, *new_inputs, phase=self.phase)
699
700    def auto_parallel_compile_and_run(self):
701        """
702        Whether or not to execute compile and run.
703
704        Returns:
705            bool, `_auto_parallel_compile_and_run` value.
706        """
707        return self._auto_parallel_compile_and_run
708
709    def exec_checkpoint_graph(self):
710        """Executes saving checkpoint graph operation."""
711        _cell_graph_executor(self, phase='save')
712
713    def insert_param_to_cell(self, param_name, param, check_name=True):
714        """
715        Adds a parameter to the current cell.
716
717        Inserts a parameter with given name to the cell. Please refer to the usage in
718        source code of `mindspore.nn.Cell.__setattr__`.
719
720        Args:
721            param_name (str): Name of the parameter.
722            param (Parameter): Parameter to be inserted to the cell.
723            check_name (bool): Determines whether the name input is compatible. Default: True.
724
725        Raises:
726            KeyError: If the name of parameter is null or contains dot.
727            AttributeError: If user did not call init() first.
728            TypeError: If the type of parameter is not Parameter.
729        """
730        if not param_name:
731            raise KeyError("The name of parameter should not be null.")
732        if check_name and '.' in param_name:
733            raise KeyError("The name of parameter should not contain \".\"")
734        if '_params' not in self.__dict__:
735            raise AttributeError("You need call init() first.")
736        if hasattr(self, param_name) and param_name not in self._params:
737            raise KeyError("Duplicated parameter name '{}'.".format(param_name))
738        if not isinstance(param, Parameter) and param is not None:
739            raise TypeError("The type of parameter should be 'Parameter' if not None.")
740        if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
741            param.name = param_name
742        self._params[param_name] = param
743
744    def cast_param(self, param):
745        """
746        Cast parameter according to auto mix precision level in pynative mode.
747
748        This interface is currently used in the case of auto mix precision and usually need not to be used explicitly.
749
750        Args:
751            param (Parameter): Parameters, the type of which should be cast.
752
753        Returns:
754            Parameter, the input parameter with type automatically cast.
755        """
756        if hasattr(self, "_mindspore_flags"):
757            if self._mindspore_flags.get('fp32'):
758                param.set_cast_dtype(mstype.float32)
759            elif self._mindspore_flags.get('fp16'):
760                param.set_cast_dtype(mstype.float16)
761            elif hasattr(param, "set_cast_dtype"):
762                # retest dtype
763                param.set_cast_dtype()
764        return param
765
766    def insert_child_to_cell(self, child_name, child_cell):
767        """
768        Adds a child cell to the current cell with a given name.
769
770        Args:
771            child_name (str): Name of the child cell.
772            child_cell (Cell): The child cell to be inserted.
773
774        Raises:
775            KeyError: Child Cell's name is incorrect or duplicated with the other child name.
776            TypeError: Child Cell's type is incorrect.
777        """
778        if not child_name or '.' in child_name:
779            raise KeyError("Child cell name is incorrect.")
780        if hasattr(self, child_name) and child_name not in self._cells:
781            raise KeyError("Duplicate child name '{}'.".format(child_name))
782        if not isinstance(child_cell, Cell) and child_cell is not None:
783            raise TypeError("Child cell type is incorrect.")
784        self._cells[child_name] = child_cell
785
786    def construct(self, *inputs, **kwargs):
787        """
788        Defines the computation to be performed. This method must be overridden by all subclasses.
789
790        Returns:
791            Tensor, returns the computed result.
792        """
793        return None
794
795    def remove_redundant_parameters(self):
796        """
797        Remove the redundant parameters.
798
799        This interface usually need not to be used explicitly.
800        """
801        cells = self.cells_and_names()
802        for _, cell in cells:
803            params = cell._params.items()
804            for param_name, param in list(params):
805                if param.name not in self.parallel_parameter_name_list:
806                    cell._params.pop(param_name)
807                    logger.info("remove the redundant parameter: %s", param.name)
808                    continue
809            cell_dict = cell.__dict__
810            for key in cell_dict:
811                if isinstance(cell_dict[key], ParameterTuple):
812                    param_tuple = cell_dict[key]
813                    new_param_tuple = []
814                    for param in param_tuple:
815                        if param.name not in self.parallel_parameter_name_list:
816                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
817                            continue
818                        new_param_tuple.append(param)
819                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
820
821    def init_parameters_data(self, auto_parallel_mode=False):
822        """
823        Initialize all parameters and replace the original saved parameters in cell.
824
825        Note:
826            trainable_params() and other similar interfaces may return different parameter instance after
827            `init_parameters_data`, do not save these result.
828
829        Args:
830            auto_parallel_mode (bool): If running in auto_parallel_mode.
831
832        Returns:
833            Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
834        """
835        replace = dict()
836
837        def _updata(param):
838            if param in replace:
839                return replace[param]
840            layout = None
841            set_sliced = False
842            if auto_parallel_mode:
843                set_sliced = True
844                if param.name not in self.parameter_layout_dict:
845                    logger.debug("Layout dict does not contain the key %s.", param.name)
846                else:
847                    layout = self.parameter_layout_dict[param.name]
848            new_p = param.init_data(layout, set_sliced=set_sliced)
849            replace[param] = new_p
850            return new_p
851
852        # replace all original usage.
853        cells = self.cells_and_names()
854        for _, cell in cells:
855            params = cell._params.items()
856            for param_name, param in params:
857                if not auto_parallel_mode:
858                    cell._params[param_name] = _updata(param)
859                    continue
860                if param.name in self.parallel_parameter_name_list:
861                    cell._params[param_name] = _updata(param)
862            cell_dict = cell.__dict__
863            for key in cell_dict:
864                if isinstance(cell_dict[key], ParameterTuple):
865                    param_tuple = cell_dict[key]
866                    new_param_tuple = []
867                    for param in param_tuple:
868                        if not auto_parallel_mode:
869                            new_param_tuple.append(_updata(param))
870                            continue
871                        if param.name in self.parallel_parameter_name_list:
872                            new_param_tuple.append(_updata(param))
873                        else:
874                            new_param_tuple.append(param)
875                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
876        return replace
877
878    def parameters_dict(self, recurse=True):
879        """
880        Gets parameters dictionary.
881
882        Gets the parameters dictionary of this cell.
883
884        Args:
885            recurse (bool): Whether contains the parameters of subcells. Default: True.
886
887        Returns:
888            OrderedDict, return parameters dictionary.
889        """
890        param_dict = OrderedDict()
891        for param in self.get_parameters(expand=recurse):
892            param_dict[param.name] = param
893        return param_dict
894
895    def parameters_broadcast_dict(self, recurse=True):
896        """
897        Gets the parameters broadcast dictionary of this cell.
898
899        Args:
900            recurse (bool): Whether contains the parameters of subcells. Default: True.
901
902        Returns:
903            OrderedDict, return parameters broadcast dictionary.
904        """
905        param_dict = OrderedDict()
906        for param in self.get_parameters(expand=recurse):
907            if param.layerwise_parallel is False:
908                param_dict[param.name] = param
909        if not param_dict:
910            return None
911        return param_dict
912
913    def update_parameters_name(self, prefix='', recurse=True):
914        """
915        Updates the names of parameters with given prefix string.
916
917        Adds the given prefix to the names of parameters.
918
919        Args:
920            prefix (str): The prefix string. Default: ''.
921            recurse (bool): Whether contains the parameters of subcells. Default: True.
922        """
923
924        Validator.check_str_by_regular(prefix)
925        for name, param in self.parameters_and_names(expand=recurse):
926            if prefix != '':
927                param.is_init = False
928            param.name = prefix + name
929
930    def trainable_params(self, recurse=True):
931        """
932        Returns all trainable parameters.
933
934        Returns a list of all trainable parameters.
935
936        Args:
937            recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
938
939        Returns:
940            List, the list of trainable parameters.
941        """
942        return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
943
944    def untrainable_params(self, recurse=True):
945        """
946        Returns all untrainable parameters.
947
948        Returns a list of all untrainable parameters.
949
950        Args:
951            recurse (bool): Whether contains the untrainable parameters of subcells. Default: True.
952
953        Returns:
954            List, the list of untrainable parameters.
955        """
956        return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
957
958    def get_parameters(self, expand=True):
959        """
960        Returns an iterator over cell parameters.
961
962        Yields parameters of this cell. If `expand` is true, yield parameters of this cell and all subcells.
963
964        Args:
965            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
966                           that are direct members of this cell. Default: True.
967
968        Returns:
969            Iteration, all parameters at the cell.
970
971        Examples:
972            >>> net = Net()
973            >>> parameters = []
974            >>> for item in net.get_parameters():
975            ...     parameters.append(item)
976        """
977        for _, param in self.parameters_and_names(expand=expand):
978            yield param
979
980    def check_names(self):
981        """
982        Check the names of cell parameters.
983        """
984        names = set("")
985        for value, param in self.parameters_and_names():
986            if param.name in names:
987                raise ValueError("The value of {} is {}, its name '{}' already exists.".
988                                 format(value, param, param.name))
989            names.add(param.name)
990
991    def parameters_and_names(self, name_prefix='', expand=True):
992        """
993        Returns an iterator over cell parameters.
994
995        Includes the parameter's name and itself.
996
997        Args:
998            name_prefix (str): Namespace. Default: ''.
999            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
1000                           that are direct members of this cell. Default: True.
1001
1002        Returns:
1003            Iteration, all the names and corresponding parameters in the cell.
1004
1005        Examples:
1006            >>> n = Net()
1007            >>> names = []
1008            >>> for m in n.parameters_and_names():
1009            ...     if m[0]:
1010            ...         names.append(m[0])
1011        """
1012        cells = []
1013        if expand:
1014            cells = self.cells_and_names(name_prefix=name_prefix)
1015        else:
1016            cells.append((name_prefix, self))
1017
1018        params_set = set()
1019        for cell_name, cell in cells:
1020            params = cell._params.items()
1021            for par_name, par in params:
1022                if par.inited_param is not None:
1023                    par = par.inited_param
1024                if par is not None and id(par) not in params_set:
1025                    params_set.add(id(par))
1026                    par_new_name = par_name
1027                    if cell_name:
1028                        par_new_name = cell_name + '.' + par_new_name
1029
1030                    yield par_new_name, par
1031
1032    def cells_and_names(self, cells=None, name_prefix=''):
1033        """
1034        Returns an iterator over all cells in the network.
1035
1036        Includes the cell's name and itself.
1037
1038        Args:
1039            cells (str): Cells to iterate over. Default: None.
1040            name_prefix (str): Namespace. Default: ''.
1041
1042        Returns:
1043            Iteration, all the child cells and corresponding names in the cell.
1044
1045        Examples:
1046            >>> n = Net()
1047            >>> names = []
1048            >>> for m in n.cells_and_names():
1049            ...     if m[0]:
1050            ...         names.append(m[0])
1051        """
1052        t_cells = cells if cells else set()
1053        if self in t_cells:
1054            return
1055
1056        t_cells.add(self)
1057        yield name_prefix, self
1058
1059        for name, cell in self._cells.items():
1060            if cell:
1061                cells_name_prefix = name
1062                if name_prefix:
1063                    cells_name_prefix = name_prefix + '.' + cells_name_prefix
1064                for ele in cell.cells_and_names(t_cells, cells_name_prefix):
1065                    yield ele
1066
1067    def cells(self):
1068        """
1069        Returns an iterator over immediate cells.
1070
1071        Returns:
1072            Iteration, all the child cells in the cell.
1073        """
1074        return self.name_cells().values()
1075
1076    def _set_scope(self, name):
1077        """Sets the name on the first time."""
1078        if self._scope is None:
1079            self._scope = name
1080        elif self._scope == 'recompute_':
1081            self._scope = self._scope + name
1082
1083    def _children_scope_recursive(self, parent_prefix='Default'):
1084        """Generates the scope of each layer of the network recursively."""
1085        reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope")
1086
1087        for name, cell in self.name_cells().items():
1088            yield parent_prefix + "/" + name + (("-" + cell.__class__.__name__)
1089                                                if reserve_class_name_in_scope else ""), cell
1090
1091        for name, cell in self.name_cells().items():
1092            for key, value in cell._children_scope_recursive(parent_prefix + "/" + name +
1093                                                             (("-" + cell.__class__.__name__)
1094                                                              if reserve_class_name_in_scope else "")):
1095                yield key, value
1096
1097    def get_scope(self):
1098        """
1099        Returns the scope of a cell object in one network.
1100
1101        Returns:
1102            String, scope of the cell.
1103        """
1104        return self._scope
1105
1106    def generate_scope(self):
1107        """Generate the scope for each cell object in the network."""
1108        for name, cell in self._children_scope_recursive():
1109            cell._set_scope(name)
1110
1111    def name_cells(self):
1112        """
1113        Returns an iterator over all cells in the network.
1114
1115        Include name of the cell and cell itself.
1116
1117        Returns:
1118            Dict[String, Cell], all the child cells and corresponding names in the cell.
1119        """
1120        value_set = set()
1121        cells = OrderedDict()
1122        for name, cell in self._cells.items():
1123            if cell is not None and cell not in value_set:
1124                value_set.add(cell)
1125                cells[name] = cell
1126        return cells
1127
1128    def add_flags(self, **flags):
1129        """
1130        Add customized attributes for cell.
1131
1132        This method is also called when the cell class is instantiated and the class parameter 'flag' is set to True.
1133        """
1134        if not hasattr(self, "_mindspore_flags"):
1135            self._mindspore_flags = {}
1136        self._mindspore_flags.update({**flags})
1137        self.__dict__.update({**flags})
1138        return self
1139
1140    def add_flags_recursive(self, **flags):
1141        """
1142        If a cell contains child cells, this method can recursively customize attributes of all cells.
1143        """
1144        self.add_flags(**flags)
1145        for cell in self.cells():
1146            cell.add_flags_recursive(**flags)
1147        return self
1148
1149    def _add_init_args(self, **args):
1150        if hasattr(self, '_cell_init_args'):
1151            self._cell_init_args += str({**args})
1152
1153    def get_flags(self):
1154        """
1155        Get the attributes of cell's flags.
1156        """
1157        if not hasattr(self, "_mindspore_flags"):
1158            self._mindspore_flags = {}
1159        return self._mindspore_flags
1160
1161    def to_float(self, dst_type):
1162        """
1163        Add cast on all inputs of cell and child cells to run with certain float type.
1164
1165        If `dst_type is mindspore.dtype.float16`, all the inputs of Cell including input, Parameter, Tensor
1166        as const will be cast to float16. Please refer to the usage in source code of
1167        `mindspore.train.amp.build_train_network`.
1168
1169        Note:
1170            Multiple calls will overwrite.
1171
1172        Args:
1173            dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
1174                dst_type can be `mindspore.dtype.float16` or `mindspore.dtype.float32`.
1175
1176        Returns:
1177            Cell, the cell itself.
1178
1179        Raises:
1180            ValueError: If dst_type is not float32 or float16.
1181        """
1182        if dst_type not in (mstype.float16, mstype.float32):
1183            raise ValueError("The dst_type should inside float32 or float16.")
1184        flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
1185        self.add_flags_recursive(**flags)
1186        self._add_init_args(**flags)
1187        return self
1188
1189    def set_boost(self, boost_type):
1190        """
1191        In order to improve the network performance, configure the network auto enable to
1192        accelerate the algorithm in the algorithm library.
1193
1194        If `boost_type is not in the algorithm library`, Please view the algorithm in the algorithm library
1195        through `algorithm library`.
1196
1197        Note:
1198            Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
1199
1200        Args:
1201            boost_type (str): accelerate algorithm.
1202
1203        Returns:
1204            Cell, the cell itself.
1205
1206        Raises:
1207            ValueError: If boost_type is not in the algorithm library.
1208        """
1209        if boost_type not in ("less_bn",):
1210            raise ValueError("The boost_type is not in the algorithm library.")
1211        flags = {"less_bn": boost_type == "less_bn"}
1212        self.add_flags_recursive(**flags)
1213        return self
1214
1215    def set_grad(self, requires_grad=True):
1216        """
1217        Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network require
1218        gradients. If true, the backward network needed to compute the gradients will be generated when the forward
1219        network is executed.
1220
1221        Args:
1222            requires_grad (bool): Specifies if the net need to grad, if it is
1223                true, the cell will construct backward network in pynative mode. Default: True.
1224
1225        Returns:
1226            Cell, the cell itself.
1227        """
1228        self.requires_grad = requires_grad
1229        return self
1230
1231    def set_train(self, mode=True):
1232        """
1233        Sets the cell to training mode.
1234
1235        The cell itself and all children cells will be set to training mode. Layers that have different constructions
1236        for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
1237        set to true, the training branch will be executed, otherwise another branch.
1238
1239        Args:
1240            mode (bool): Specifies whether the model is training. Default: True.
1241
1242        Returns:
1243            Cell, the cell itself.
1244        """
1245        if mode is False:
1246            self._phase = 'predict'
1247        else:
1248            self._phase = 'train'
1249        self.add_flags_recursive(training=mode)
1250        return self
1251
1252    def set_broadcast_flag(self, mode=True):
1253        """
1254        Set the cell to data_parallel mode.
1255
1256        The cell can be accessed as an attribute using the given name.
1257
1258        Args:
1259            mode (bool): Specifies whether the model is data_parallel. Default: True.
1260        """
1261        self.add_flags_recursive(broadcast_flag=mode)
1262        return self
1263
1264    def set_auto_parallel(self):
1265        """
1266        Set the cell to auto parallel mode.
1267
1268        Note:
1269            If a cell needs to use the auto parallel or semi auto parallel mode for training, evaluation or prediction,
1270            this interface needs to be called by the cell.
1271        """
1272        self._auto_parallel_mode = True
1273        self.add_flags(auto_parallel=True)
1274        self._get_construct_inputs_number_and_name()
1275
1276    def _hook_construct(self, *inputs):
1277        """Hook construct method to replace original construct method when hook function enabled."""
1278        inputs = self._backward_hook(*inputs)
1279        inputs = self.construct(inputs)
1280        outputs = self._backward_hook(inputs)
1281        return outputs
1282
1283    def register_backward_hook(self, fn):
1284        """
1285        Set the cell backward hook function. Note that this function is only supported in pynative mode.
1286
1287        Note:
1288            fn must be defined as the following code. `cell_name` is the name of registered cell.
1289            `grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
1290            next cell or primitive, which may be modified and returned.
1291            hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.
1292
1293        Args:
1294            fn (function): Specifies the hook function with grad as input.
1295
1296        """
1297        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
1298        self.enable_hook = True
1299
1300    def set_param_ps(self, recurse=True, init_in_server=False):
1301        """
1302        Set whether the trainable parameters are updated by parameter server and whether the
1303        trainable parameters are initialized on server.
1304
1305        Note:
1306            It only works when a running task is in the parameter server mode.
1307
1308        Args:
1309            recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
1310            init_in_server (bool): Whether trainable parameters updated by parameter server are
1311                initialized on server. Default: False.
1312        """
1313        params = self.trainable_params(recurse)
1314        for param in params:
1315            param.set_param_ps(init_in_server)
1316
1317    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
1318        """
1319        Set the way of parameter and server interaction.
1320
1321        Args:
1322            push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
1323            pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
1324            requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
1325        """
1326        params = self.parameters_and_names()
1327        for param in params:
1328            param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
1329
1330    def set_comm_fusion(self, fusion_type, recurse=True):
1331        """
1332        Set `comm_fusion` for all the parameters in the Net. Please refer to the description of
1333        `mindspore.common.parameter.comm_fusion`.
1334
1335        Note:
1336            The value of attribute will be overwritten when the function is called multiply.
1337
1338        Args:
1339            fusion_type (int): The value of `comm_fusion`.
1340            recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
1341        """
1342        Validator.check_non_negative_int(fusion_type)
1343        for param in self.trainable_params(recurse):
1344            param.comm_fusion = fusion_type
1345        return self
1346
    def _set_recompute_scope(self, mode):
        prefix = 'recompute_'
        if mode is True:
            if self._scope is None:
                self._scope = prefix
            elif not self._scope.startswith(prefix):
                self._scope = prefix + self._scope
        elif self._scope is not None and self._scope.startswith(prefix):
            self._scope = self._scope[len(prefix):]

    def _mp_comm_recompute(self, mp_comm_recompute=True):
        """
        Set whether the model parallel communication operators in the cell are recomputed.
        """
        for _, value in self._primitives.items():
            if value:
                value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
        for cell in self.cells():
            cell._mp_comm_recompute(mp_comm_recompute)

    def _parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute=False):
        """
        Set whether the parallel optimizer communication operators in the cell are recomputed.
        """
        for param in self.trainable_params():
            param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute

    def _recompute(self, mode=True, output_recompute=False):
        """
        Set the cell to be recomputed.
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in pynative mode currently.")
        Validator.check_bool(mode)
        Validator.check_bool(output_recompute)
        if not self._has_config_recompute:
            self._has_config_recompute = True
        else:
            raise RuntimeError("The recompute interface can be configured only once."
                               " When the parent cell is configured, the child cell should not be configured")
        self._set_recompute_scope(mode)
        if mode and not output_recompute:
            self.add_flags(output_no_recompute=True)
        for cell in self.cells():
            cell._recompute(mode, True)

    @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
    def recompute(self, **kwargs):
        """
        Set the cell recomputed. All the primitives in the cell will be set recomputed. If a primitive
        set to be recomputed feeds into some backward nodes for computing gradients, we do not store the
        intermediate activation computed in the forward pass, but recompute it in the backward pass.

        Note:
            - If the computation involves something like randomization or global variable, the equivalence
              is not guaranteed currently.
            - If the recompute api of a primitive in this cell is also called, the recompute mode of this
              primitive is subject to the recompute api of the primitive.
            - The interface can be configured only once.
              Therefore, when the parent cell is configured, the child cell should not be configured.
            - If there is still memory left after applying recomputation, configure 'mp_comm_recompute=False'
              to improve performance if necessary.
            - If memory is still insufficient after applying recomputation, configure
              'parallel_optimizer_comm_recompute=True' to save more memory if necessary.
              Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configuration.

        Args:
            mp_comm_recompute (bool): Specifies whether the model parallel communication operators
                in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
            parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
                introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
                Default: False.
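
        A minimal sketch, assuming graph mode; the `nn.Dense` block is only an illustrative cell.

        Examples:
            >>> import mindspore.nn as nn
            >>> from mindspore import context
            >>> context.set_context(mode=context.GRAPH_MODE)
            >>> block = nn.Dense(3, 4)
            >>> # Recompute the block's activations in the backward pass instead of storing them,
            >>> # and keep the model parallel communication operators out of recomputation.
            >>> block.recompute(mp_comm_recompute=False)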
1420        """
1421        self._recompute()
1422        if 'mp_comm_recompute' in kwargs.keys():
1423            self._mp_comm_recompute(kwargs['mp_comm_recompute'])
1424        if 'parallel_optimizer_comm_recompute' in kwargs.keys():
1425            if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
1426                raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
1427                                 "are not support recomputation in pipeline parallel.")
1428            self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
1429
1430        for key, _ in kwargs.items():
1431            if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
1432                raise ValueError("Recompute keyword %s is not recognized!" % key)
1433
    def infer_param_pipeline_stage(self):
        """
        Infer pipeline stages of all parameters in the cell.

        Note:
            - If a parameter does not belong to any cell which has been set pipeline_stage,
              the parameter should use add_pipeline_stage to add its pipeline_stage information.
            - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
              to add its stage information before using infer_param_pipeline_stage.

        Returns:
            The parameters belonging to the current stage in pipeline parallel.

        Raises:
            RuntimeError: If there is a parameter that does not belong to any stage.
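
        A minimal sketch, assuming a pipeline-parallel job (e.g. 8 devices with pipeline_stages=2,
        so each stage owns 4 devices and the current stage is rank_id // 4); `net` stands for any
        cell whose layers have already been assigned a pipeline_stage.

        Examples:
            >>> from mindspore import context
            >>> context.set_auto_parallel_context(pipeline_stages=2)
            >>> # Each rank keeps only the parameters whose stage list contains its own stage.
            >>> stage_params = net.infer_param_pipeline_stage()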
1450        """
1451        from mindspore.parallel._utils import _get_global_rank, _get_device_num
1452        stage_num = context.get_auto_parallel_context("pipeline_stages")
1453        device_num = _get_device_num()
1454        rank_id = _get_global_rank()
1455        per_stage_devices = device_num // stage_num
1456        current_stage = rank_id // per_stage_devices
1457        params = []
1458        for param in self.trainable_params():
1459            if not param._pipeline_stage_list:
1460                raise RuntimeError("The parameter {} does not belong to any stage, "
1461                                   "please check whether the cell where the param locates"
1462                                   " has been set pipeline_stage. "
1463                                   "Otherwise, the parameter should use add_pipeline_stage "
1464                                   "to add its stage information".format(param.name))
1465            if current_stage in param._pipeline_stage_list:
1466                params.append(param)
1467        return params
1468

class GraphKernel(Cell):
    """
    Base class for GraphKernel.

    A `GraphKernel` is a composite of basic primitives that can be compiled into a fused kernel automatically when
    enable_graph_kernel in context is set to True.

    This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.

    GraphKernel no longer supports user-defined cells; `GraphKernel` objects will be treated as
    normal `Cell` objects.

    Args:
        auto_prefix (bool): Recursively generate namespaces. Default: True.
        flags (dict): Set graph flags. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Relu(nn.GraphKernel):
        ...    def __init__(self):
        ...        super(Relu, self).__init__()
        ...        self.max = P.Maximum()
        ...
        ...    def construct(self, x):
        ...        return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
    """

    @deprecated("1.3", "Cell", True)
    def __init__(self, auto_prefix=True, flags=None):
        super(GraphKernel, self).__init__(auto_prefix, flags)

    def construct(self):
        raise NotImplementedError


class GraphCell(Cell):
    """
    Base class for running the graph loaded from a MindIR.

    This feature is still under development. Currently `GraphCell` does not support modifying the structure of
    the graph, and can only use inputs whose shape and type are the same as those used when exporting the MindIR.

    Args:
        graph (object): A compiled graph loaded from MindIR.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]
    """

    def __init__(self, graph):
        super(GraphCell, self).__init__(auto_prefix=True)
        if not isinstance(graph, FuncGraph):
            raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
        self.graph = graph

    def construct(self, *inputs):
        return self.graph(*inputs)

    def __call__(self, *inputs):
        self.phase = "graph_load_from_mindir"
        self._add_attr("graph_load_from_mindir", self.graph)
        return self.compile_and_run(*inputs)