1# Copyright 2020-2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""cell"""
16from __future__ import absolute_import
17
18import gc
19import inspect
20import os
21import time
22from collections import OrderedDict
23import numpy
24
25from mindspore._checkparam import args_type_check, check_hook_fn
26from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
27from mindspore import log as logger
28from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
29from mindspore.common.hook_handle import HookHandle
30from mindspore.context import ParallelMode
31from mindspore import context
32from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
33from mindspore import _checkparam as Validator
34from mindspore.common import dtype as mstype
35from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
36from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
37from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
38from mindspore.common.parameter import Parameter, ParameterTuple
39from mindspore.common.tensor import Tensor
40from mindspore.ops.operations import Cast
41from mindspore.ops.primitive import Primitive
42from mindspore.ops.operations import _inner_ops as inner
43from mindspore.parallel.shard import Shard
44from mindspore._check_jit_forbidden_api import jit_forbidden_register
45from mindspore.common._decorator import deprecated
46from mindspore.common._register_for_recompute import recompute_registry
47
48
49class Cell(Cell_):
50    """
51    The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
52    base class.
53
54    Layers in `mindspore.nn` are also the subclass of Cell, such as :class:`mindspore.nn.Conv2d`,
55    and :class:`mindspore.nn.ReLU`, etc. Cell will be compiled into a calculation
56    graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
57    PYNATIVE_MODE (dynamic graph mode).
58
    .. note::
        A Cell is in inference mode by default. For a class that inherits Cell,
        if training and inference have different structures, the subclass runs the inference branch by default.
        To switch to training mode, refer to `mindspore.nn.Cell.set_train` .

    .. warning::
        A subclass of Cell must not define a method named 'cast' and must not define an attribute
        named 'phase' or 'cells'; otherwise, an error will be raised.
67
68    Args:
69        auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
70                      affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
71                      automatically prefixed, otherwise not. In general, the backbone network should be set to
72                      ``True`` , otherwise the duplicate name problem will appear. The cell to train the backbone
73                      network, such as optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to
74                      ``False`` , otherwise the parameter name in backbone will be changed by mistake.
75                      Default: ``True`` .
76        flags (dict, optional): Network configuration information, currently it is used for the binding of network
77                      and dataset. Users can also customize network attributes by this parameter. Default: ``None`` .
78
79    Supported Platforms:
80        ``Ascend`` ``GPU`` ``CPU``
81
82    Examples:
83        >>> import mindspore.nn as nn
84        >>> from mindspore import ops
85        >>> class MyCell(nn.Cell):
86        ...     def __init__(self, forward_net):
87        ...         super(MyCell, self).__init__(auto_prefix=False)
88        ...         self.net = forward_net
89        ...         self.relu = ops.ReLU()
90        ...
91        ...     def construct(self, x):
92        ...         y = self.net(x)
93        ...         return self.relu(y)
94        >>>
95        >>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
96        >>> my_net = MyCell(inner_net)
97        >>> print(my_net.trainable_params())
        ... # If 'auto_prefix' is set to True, or not set, when calling the '__init__' method of the parent class,
        ... # the parameter's name will be 'net.weight'.
100        [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
101    """
102
103    IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
104                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
105                   '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
106                   '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
107                   '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
108    total_instance_count = 0
109
110    def __init__(self, auto_prefix=True, flags=None):
111        Cell_.__init__(self, self._cell_tag)
112        Cell.total_instance_count += 1
113        self.instance_count = Cell.total_instance_count
114        self._params = OrderedDict()
115        self._cells = OrderedDict()
116        self._params_list = OrderedDict()
117        self._primitives = OrderedDict()
118        self.training = False
119        self.requires_grad = False
120        self.pynative = False
121        self._attr_synced = False
122        self._param_prefix = ''
123        self._auto_prefix = auto_prefix
124        self._scope = None
125        self._phase = 'train'
126        self._parameter_layout_dict = {}
127        self._parallel_parameter_name_list = ()
128        self._parallel_parameter_merge_net_dict = {}
129        self._create_time = int(time.time() * 1e9)
130        self.arguments_key = ""
131        self.compile_cache = set()
132        self.phase_cache = dict()
133        cells_compile_cache[id(self)] = self.compile_cache
134        self.parameter_broadcast_done = False
135        self._id = 1
136        self.exist_names = set("")
137        self.exist_objs = set()
138        self.recompute_cell = None
139        self.sig = inspect.signature(self.construct)
140        init_pipeline()
141
142        # call gc to release GE session resources used by non-used cell objects
143        if os.getenv('GC_COLLECT_IN_CELL') == '1':
144            gc.collect()
145
146        if flags:
147            self.add_flags(**flags)
148        self._bprop_debug = False
149        self._forward_pre_hook = OrderedDict()
150        self._forward_hook = OrderedDict()
151        self._enable_forward_pre_hook = False
152        self._enable_forward_hook = False
153        self._enable_backward_hook = False
154        self._cell_backward_hook = None
155        self._is_recursion_hook = False
156        self.cell_type = None
157        self.cast = Cast()
158        self._has_config_recompute = False
159        self._user_parameters = []
160        self._dynamic_shape_inputs = None
161        self._compile_args = None
162        self.saved_dynamic_shape = None
163        self._jit_config_dict = dict()
164        self.grad_ops_label = False
165        self.ge_sync_data = False
166        self._is_check_and_refresh = False
167        self._amp_level = ""
168        self._init_flag = False
169
170    def __getstate__(self):
171        base = Cell_.__getstate__(self)
172        return base, self.__dict__
173
174    def __setstate__(self, state):
175        base, dict_ = state
176        Cell_.__setstate__(self, base)
177        self.__dict__ = dict_
178        self._attr_synced = False
179
180    def __bool__(self):
181        return True
182
183    @property
184    def _cell_tag(self):
185        # `<class 'xxxxxxx'>` to `xxxxxxx`
186        return str(self.__class__)[8:-2]
187
188    @property
189    def create_time(self):
190        return self._create_time
191
192    @property
193    def cell_init_args(self):
194        return self._cell_init_args
195
196    @property
197    def param_prefix(self):
198        """
199        Param prefix is the prefix of current cell's direct child parameter.
200
201        Examples:
202            >>> import mindspore as ms
203            >>> from mindspore import Tensor, nn
204            ...
205            >>> class Net(nn.Cell):
206            ...     def __init__(self):
207            ...         super(Net, self).__init__()
208            ...         self.dense = nn.Dense(2, 2)
209            ...
210            ...     def construct(self, x):
211            ...         x = self.dense(x)
212            ...         return x
213            >>> net = Net()
214            >>> net.update_cell_prefix()
215            >>> print(net.dense.param_prefix)
216            dense
217        """
218        return self._param_prefix
219
220    @property
221    def bprop_debug(self):
222        """
223        Get whether cell custom bprop debug is enabled.
224
225        Tutorial Examples:
226            - `Cell and Parameter - Custom Cell Reverse
227              <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
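
        A minimal usage sketch (illustrative only; the custom bprop itself is omitted here):

        Examples:
            >>> import mindspore.nn as nn
            >>> net = nn.Dense(2, 2)
            >>> net.bprop_debug = True
            >>> print(net.bprop_debug)
            True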
228        """
229        return self._bprop_debug
230
231    @bprop_debug.setter
232    def bprop_debug(self, value):
233        """
234        Set whether to enable cell custom bprop debug.
235
        Note:
            When bprop is defined in a cell, the bprop function is executed
            by the Python interpreter if bprop debug is true, and is parsed
            and added to the graph if bprop debug is false.
240
241        Args:
242            value (bool): Specifies whether to enable bprop debug. Default: ``False``.
243        """
244        if not isinstance(value, bool):
245            raise TypeError(f"For 'Cell', the property 'bprop_debug' must be bool type, but got type {type(value)}.")
246        self._bprop_debug = value
247
248    def update_cell_prefix(self):
249        """
250        Update the `param_prefix` of all child cells.
251
        After being invoked, the name prefix of every child cell can be obtained through its '_param_prefix' attribute.
253        """
254        cells_name = self.cells_and_names()
255
256        for cell_name, cell in cells_name:
257            cell._param_prefix = cell_name
258
259    def update_cell_type(self, cell_type):
260        """
261        The current cell type is updated when a quantization aware training network is encountered.
262
263        After being invoked, it can set the cell type to 'cell_type'.
264
265        Args:
266            cell_type(str): The type of cell to be updated, cell_type can be "quant" or "second-order".
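
        A minimal usage sketch (illustrative only):

        Examples:
            >>> import mindspore.nn as nn
            >>> net = nn.Dense(2, 2)
            >>> net.update_cell_type("quant")
            >>> print(net.cell_type)
            quant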
267        """
268        self.cell_type = cell_type
269
270    @cell_init_args.setter
271    def cell_init_args(self, value):
272        if not isinstance(value, str):
273            raise TypeError(f"For 'Cell', the property 'cell_init_args' must be string type, "
274                            f"but got type {type(value)}.")
275        self._cell_init_args = value
276
277    @property
278    def phase(self):
279        return self._phase
280
281    @phase.setter
282    def phase(self, value):
283        if not isinstance(value, str):
284            raise TypeError(f"For 'Cell', the property 'phase' must be string type, but got type {type(value)}.")
285        self._phase = value
286
287    @property
288    def parameter_layout_dict(self):
289        """
290        `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
291        distributed operator information.
292        """
293        return self._parameter_layout_dict
294
295    @property
296    def cls_name(self):
297        return self.__class__.__name__
298
299    @parameter_layout_dict.setter
300    def parameter_layout_dict(self, value):
301        if not isinstance(value, dict):
302            raise TypeError(f"For 'Cell', the property 'parameter_layout_dict' must be dict type, "
303                            f"but got type {type(value)}.")
304        self._parameter_layout_dict = value
305
306    @property
307    def parallel_parameter_name_list(self):
308        return self._parallel_parameter_name_list
309
310    @parallel_parameter_name_list.setter
311    def parallel_parameter_name_list(self, value):
312        if not isinstance(value, list):
313            raise TypeError(f"For 'Cell', the property 'parallel_parameter_name_list' must be list type, "
314                            f"but got type {type(value)}.")
315        self._parallel_parameter_name_list = value
316
317    @property
318    def pipeline_stage(self):
319        """
320        `pipeline_stage` represents the pipeline stage of current Cell.
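
        A minimal usage sketch (illustrative only; meaningful when pipeline parallelism is configured):

        Examples:
            >>> import mindspore.nn as nn
            >>> net = nn.Dense(2, 2)
            >>> net.pipeline_stage = 0
            >>> print(net.pipeline_stage)
            0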
321        """
322        return self._pipeline_stage
323
324    @pipeline_stage.setter
325    def pipeline_stage(self, value):
326        """
327        Set the `pipeline_stage` of a Cell.
328
329        Args:
330            value (int): The pipeline stage of a parameter.
331
332        Raises:
            TypeError: If `value` is not an int, or is a bool.
            ValueError: If `value` is less than 0.
335        """
336        if not isinstance(value, int) or isinstance(value, bool):
337            raise TypeError("For 'Cell', the property 'pipeline_stage' "
338                            "must be int type, but got type : {}".format(type(value)))
339
340        if value < 0:
341            raise ValueError("For 'Cell', the property 'pipeline_stage' "
342                             "can not be less than 0, but got {}".format(value))
343        self._pipeline_stage = value
344        for item in self.trainable_params():
345            item.add_pipeline_stage(value)
346
347    @property
348    def pipeline_segment(self):
349        return self._pipeline_segment
350
351    @pipeline_segment.setter
352    def pipeline_segment(self, value):
353        if not isinstance(value, int) or isinstance(value, bool):
354            raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
355                            "must be int type, but got type : {}".format(type(value)))
356
357        if value < 0:
358            raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
359                             "can not be less than 0, but got {}".format(value))
360        self._pipeline_segment = value
361
362    @property
363    def parallel_parameter_merge_net_dict(self):
364        return self._parallel_parameter_merge_net_dict
365
366    @parallel_parameter_merge_net_dict.setter
367    def parallel_parameter_merge_net_dict(self, value):
368        if not isinstance(value, dict):
369            raise TypeError(f"For 'Cell', the property 'parallel_parameter_merge_net_dict' must be dict type, "
370                            f"but got type {type(value)}.")
371        self._parallel_parameter_merge_net_dict = value
372
373    @property
374    def jit_config_dict(self):
375        return self._jit_config_dict
376
377    def get_func_graph_proto(self):
378        """Return graph binary proto."""
379        exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
380        return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)
381
382    def __getattr__(self, name):
383        if '_params' in self.__dict__:
384            params = self.__dict__['_params']
385            if name in params:
386                return params[name]
387        if '_cells' in self.__dict__:
388            cells = self.__dict__['_cells']
389            if name in cells:
390                return cells[name]
391        if '_params_list' in self.__dict__:
392            params_list = self.__dict__['_params_list']
393            if name in params_list:
394                return params_list[name]
395        raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
396
397    def __del__(self):
398        if isinstance(cells_compile_cache, dict):
399            # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
400            # here using pop(id(self), None) to avoid KeyError exception
401            cells_compile_cache.pop(id(self), None)
402        if hasattr(self, "compile_cache") and self.compile_cache:
403            _cell_graph_executor.del_net_res(self, self.compile_cache)
404        if isinstance(self, GraphCell):
405            _cell_graph_executor.dec_graph_cell_count()
406        Cell.total_instance_count -= 1
407
408    def __delattr__(self, name):
409        if name in self._params:
410            del self._params[name]
411        elif name in self._cells:
412            del self._cells[name]
413        elif '_params_list' in self.__dict__ and name in self._params_list:
414            del self._params_list[name]
415        else:
416            object.__delattr__(self, name)
417        self._attr_synced = False
418
419    def _cast_mixed_precision_inputs(self, inputs, dst_type):
420        """Cast input for mixed precision"""
421        res = list()
422        for item in inputs:
423            if isinstance(item, tuple):
424                res.append(self._cast_mixed_precision_inputs(item, dst_type))
425            elif isinstance(item, float):
426                res.append(self.cast(item, dst_type))
427            elif hasattr(item, "dtype") and item.dtype in \
428                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
429                res.append(self.cast(item, dst_type))
430            else:
431                res.append(item)
432        return tuple(res)
433
434    def cast_inputs(self, inputs, dst_type):
435        """
436        Cast inputs to specified type.
437
438        Args:
439            inputs (tuple[Tensor]): The cell inputs.
440            dst_type (mindspore.dtype): The specified data type.
441
        Returns:
            tuple[Tensor], the result with destination data type.
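
        A minimal usage sketch (illustrative only):

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn
            >>> net = nn.Cell()
            >>> out = net.cast_inputs((Tensor(np.ones((2, 2)), ms.float32),), ms.float16)
            >>> print(out[0].dtype)
            Float16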
444        """
445        res = list()
446        for item in inputs:
447            if isinstance(item, tuple):
448                res.append(self.cast_inputs(item, dst_type))
449            else:
450                res.append(self.cast(item, dst_type))
451        return tuple(res)
452
453    def _do_parameter_broadcast(self):
454        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
455            if not self.parameter_broadcast_done:
456                _pynative_executor.parameter_broadcast(self, self.phase)
457                self.parameter_broadcast_done = True
458
459    def run_construct(self, cast_inputs, kwargs):
460        """
461        Run the construct function.
462
463        Note:
464            This function will be removed in a future version. It is not recommended to call this function.
465
466        Args:
467            cast_inputs (tuple): The input objects of Cell.
468            kwargs (dict): Provide keyword arguments.
469
470        Returns:
471            output, the output object of Cell.
472        """
473        logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. "
474                       f"Calling this function is not recommended.")
475        output = self._run_construct(cast_inputs, kwargs)
476        return output
477
478    def _run_construct(self, cast_inputs, kwargs):
479        """Run the construct function"""
480        if self._enable_forward_pre_hook:
481            cast_inputs = self._run_forward_pre_hook(cast_inputs)
482        if self._enable_backward_hook:
483            output = self._backward_hook_construct(*cast_inputs, **kwargs)
484        elif hasattr(self, "_shard_fn"):
485            output = self._shard_fn(*cast_inputs, **kwargs)
486        else:
487            if self.recompute_cell is not None:
488                output = self.recompute_cell(*cast_inputs, **kwargs)
489            else:
490                output = self.construct(*cast_inputs, **kwargs)
491        if self._enable_forward_hook:
492            output = self._run_forward_hook(cast_inputs, output)
493        return output
494
495    def _check_construct_args(self, *args):
496        """Check the args needed by the function construct"""
497        positional_args = 0
498        default_args = 0
499        has_var = False
500        for value in inspect.signature(self.construct).parameters.values():
501            if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
502                has_var = True
503            if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
504                if value.default is inspect.Parameter.empty:
505                    positional_args += 1
506                else:
507                    default_args += 1
508
509        if has_var:
510            return
511
512        if len(args) < positional_args:
513            raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
514                            f"but got {len(args)}. When using set_inputs, please make sure that all networks "
515                            f"and loss functions are configured with set_inputs.")
516
517        if len(args) > positional_args + default_args:
518            construct_inputs_names = self.construct.__code__.co_varnames
519            if 'self' not in construct_inputs_names:
520                raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
521
522            raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
523                            f"{default_args} default argument, total {positional_args + default_args}, "
524                            f"but got {len(args)}.")
525
526    def _hook_fn_registered(self):
527        '''Hook function in graph mode'''
528        # Check super().__init__() in graph mode.
529        try:
530            if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
531                return True
532        except AttributeError as e:
533            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
534                                 f"Please use 'super().__init__()'.") from e
535        if not self._is_recursion_hook:
536            self._is_recursion_hook = True
537            for cell in self.cells():
538                if cell._hook_fn_registered():
539                    return True
540        return False
541
542    def _get_prims_recursively(self):
543        all_prims = list()
544        for _, value in self._primitives.items():
545            if value:
546                all_prims.append(value)
547
548        for cell in self.cells():
549            all_prims.extend(cell._get_prims_recursively())
550
551        return all_prims
552
553    def set_data_parallel(self):
554        """
555        For all primitive ops in this cell(including ops of cells that wrapped by this cell),
556        if parallel strategy is not specified, then instead of auto-searching, data parallel
557        strategy will be generated for those primitive ops.
558
559        Note:
560            Only effective while using auto_parallel_context = ParallelMode.AUTO_PARALLEL under graph mode.
561
562        Examples:
563            >>> import mindspore.nn as nn
564            >>> net = nn.Dense(3, 4)
565            >>> net.set_data_parallel()
566        """
567        if context._get_mode() == context.PYNATIVE_MODE:
568            raise ValueError("set_data_parallel: does not support PyNative mode.")
569
570        all_prims = self._get_prims_recursively()
571        for prim in all_prims:
572            prim.add_prim_attr("strategy_gen_mode", "data_parallel")
573
574    def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
575        """
        Defines the input and output layouts of this cell; the parallel strategies of the remaining ops are
        generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
        execution in graph mode. In Graph mode, use this method to specify the distribution strategy for a Cell,
        and the strategies for the other operators are set by sharding propagation.
        in_strategy and out_strategy define the input and output layouts respectively.
        Both should be tuples in which each element corresponds to the desired layout of
        the corresponding input/output, and None represents data_parallel;
        refer to the description of `mindspore.ops.Primitive.shard`.
        The parallel strategies of the remaining operators are derived from the strategies specified for the inputs and outputs.

        Note:
            If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
            "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
            If the input contains a Parameter, its strategy should be set in `in_strategy`.
590
591        Args:
592            in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
593                             defines the layout of the corresponding input and None represents a data parallel strategy.
594            out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
595                                               It is not in use right now. Default: ``None`` .
596            parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
597                                                defines the layout of the parameter like "param_name: layout".
598                                                The key is a parameter name of type 'str'.
599                                                The value is a 1-D integer tuple, indicating the corresponding layout.
600                                                If the parameter name is incorrect or the corresponding parameter
601                                                has been set, the parameter setting will be ignored.
602                                                Default: ``None`` .
603            device (string): Select a certain device target. It is not in use right now.
604                             Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
605            level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
606                         over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
607                         use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
608
609        Returns:
610            Function, return the cell construct function that will be executed under auto parallel process.
611
612        Examples:
613            >>> import mindspore.nn as nn
614            >>>
            >>> class Block(nn.Cell):
            ...   def __init__(self):
            ...     super(Block, self).__init__()
            ...     self.dense1 = nn.Dense(10, 10)
            ...     self.relu = nn.ReLU()
            ...     self.dense2 = nn.Dense(10, 10)
            ...   def construct(self, x):
            ...     x = self.relu(self.dense2(self.relu(self.dense1(x))))
            ...     return x
            >>>
            >>> class example(nn.Cell):
            ...   def __init__(self):
            ...     super(example, self).__init__()
            ...     self.block1 = Block()
            ...     self.block2 = Block()
            ...     self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
            ...                                           parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
            ...   def construct(self, x):
            ...     x = self.block1(x)
            ...     x = self.block2_shard(x)
            ...     return x
634        """
635        if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
            raise AssertionError(f"Cell shard only supports auto_parallel or semi_auto_parallel mode. "
                                 f"Please check the parallel mode in the parallel context.")
638
639        shard_fn = Shard()
640        fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
641        object.__setattr__(self, "_shard_fn", fn)
642        return fn
643
644    def auto_cast_inputs(self, inputs):
645        """
646        Auto cast inputs in mixed precision scenarios.
647
648        Args:
649            inputs (tuple): the inputs of construct.
650
651        Returns:
652            Tuple, the inputs after data type cast.
653        """
654        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
655        logger.warning(msg)
656        cast_inputs = inputs
657        mixed_type = self.get_mixed_precision_type()
658        if mixed_type == MixedPrecisionType.FP16:
659            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
660        if mixed_type == MixedPrecisionType.FP32:
661            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
662
663        return cast_inputs
664
665    def _init_check(self):
666        for param in self.get_parameters(expand=False):
667            if param.has_init:
668                param.init_data()
669
670    def _self_check(self):
671        if not self._is_check_and_refresh:
672            self.check_names_and_refresh_name()
673            self._is_check_and_refresh = True
674
675    def _predict(self, *args, **kwargs):
676        if not hasattr(self, "phase"):
677            return False, None
678        if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
679            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
680            res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
681            res = _convert_python_data(res)
682            return True, res
683        return False, None
684
685    def __call__(self, *args, **kwargs):
686        # Run in Graph mode.
687        if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
688            if kwargs:
689                bound_arguments = self.sig.bind(*args, **kwargs)
690                bound_arguments.apply_defaults()
691                args = bound_arguments.args
692                kwargs = bound_arguments.kwargs
693
694            predict_compiled, res = self._predict(*args, **kwargs)
695            if predict_compiled:
696                return res
697            self._check_construct_args(*args)
698
699            if self._hook_fn_registered():
                logger.warning(f"For 'Cell', hook functions are not supported in graph mode. If you want to use hook "
                               f"functions, please use context.set_context to set pynative mode.")
702            self._self_check()
703            out = self.compile_and_run(*args, **kwargs)
704            return out
705
706        # Run in PyNative mode.
707        self._self_check()
708        if not self._init_flag:
709            self._init_check()
710            self._init_flag = True
711
712        if self.requires_grad:
713            _pynative_executor.set_grad_flag(True)
714
715        try:
716            _pynative_executor.new_graph(self, *args, **kwargs)
717            output = self._run_construct(args, kwargs)
718            _pynative_executor.end_graph(self, output, *args, **kwargs)
719        except Exception as err:
720            _pynative_executor.clear_res()
721            raise err
722
723        return output
724
725    def _add_attr(self, name, value):
726        if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
727            super(Cell, self)._add_attr(name, value)
728
729    def _sync_attr_for_compile(self):
730        """Sync the attr to c++ object."""
731        if self._attr_synced:
732            return
733        cells = self.__dict__.get('_cells')
734        for key in cells:
735            cell = cells[key]
736            cell._sync_attr_for_compile()
737            self._add_attr(key, cell)
738        params = self.__dict__.get('_params')
739        for key in params:
740            if '.' in key:
741                continue
742            param = params[key]
743            self._add_attr(key, param)
744        params_list = self.__dict__.get('_params_list')
745        for key in params_list:
746            params_list_item = params_list[key]
747            self._add_attr(key, params_list_item)
748        for key in self.__dict__:
749            value = self.__dict__[key]
750            self._add_attr(key, value)
751        self._attr_synced = True
752
753    def _set_attr_for_parameter(self, name, value):
754        """Set attr for parameter."""
755        cells = self.__dict__.get('_cells')
756        params = self.__dict__.get('_params')
757        if params is None:
758            raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
759        if name in self.__dict__:
760            if self.__dict__[name] is not None:
761                raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
762            del self.__dict__[name]
763        if cells and name in cells:
764            raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
765        self.insert_param_to_cell(name, value)
766
767    def _set_attr_for_parameter_tuple(self, name, value):
768        """Set attr for parameter in ParameterTuple."""
769        params = self.__dict__.get('_params')
770        params_list = self.__dict__.get('_params_list')
771        if params is None:
772            raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
773        exist_names = set("")
774        exist_objs = set()
775        for item in value:
776            if item in exist_objs:
777                # If there are multiple identical objects, their names only check once.
778                continue
779            exist_objs.add(item)
780            if item.name == PARAMETER_NAME_DEFAULT:
781                logger.warning("For 'Cell', the parameter definition is deprecated.\n"
782                               "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
783                item.name = item.name + "$" + str(self._id)
784                self._id += 1
785            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
786            if item.name in exist_names:
787                raise ValueError("The value {} , its name '{}' already exists. "
788                                 "Please set a unique name for the parameter.".format(value, item.name))
789            exist_names.add(item.name)
790
791        if context._get_mode() == context.PYNATIVE_MODE:
792            if name in self.__dict__:
793                del self.__dict__[name]
794            if name in params:
795                del params[name]
796            params_list[name] = value
797        else:
798            object.__setattr__(self, name, value)
799
800    def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
801        """Set attr for parameter in list or tuple."""
802        for item in value:
803            if item in self.exist_objs:
804                # If there are multiple identical objects, their names only check once.
805                continue
806            self.exist_objs.add(item)
807            if item.name == PARAMETER_NAME_DEFAULT:
808                item.name = item.name + "$" + str(self._id)
809                self._id += 1
810            if item.name in self.exist_names:
811                raise ValueError("The value {} , its name '{}' already exists. "
812                                 "Please set a unique name for the parameter.".format(value, item.name))
813            self.exist_names.add(item.name)
814        object.__setattr__(self, name, value)
815
816    def _set_attr_for_cell(self, name, value):
817        """Set attr for cell."""
818        cells = self.__dict__.get('_cells')
819        params = self.__dict__.get('_params')
820        if cells is None:
821            raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
822        if name in self.__dict__:
823            del self.__dict__[name]
824        if params and name in params:
825            raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
826        if self._auto_prefix:
827            value.update_parameters_name(name + '.')
828        cells[name] = value
829        if hasattr(self, '_cell_init_args'):
830            self.cell_init_args += str({name: value})
831
832    def _set_attr_for_params(self, name, value):
833        if isinstance(value, Tensor) and self._params[name] is not None:
834            self._params[name].set_data(value)
835        elif value is not None:
836            raise TypeError(f"For 'Cell', the type of {name} must be Parameter or ParameterTuple, "
837                            f"but got {type(value).__name__}.")
838        else:
839            self.insert_param_to_cell(name, None)
840
841    def __setattr__(self, name, value):
842        cells = self.__dict__.get('_cells')
843        params = self.__dict__.get('_params')
844        if isinstance(value, Parameter):
845            self._set_attr_for_parameter(name, value)
846        elif isinstance(value, ParameterTuple):
847            self._set_attr_for_parameter_tuple(name, value)
848        elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
849            self._set_attr_for_parameter_in_list_or_tuple(name, value)
850        elif isinstance(value, Cell):
851            self._set_attr_for_cell(name, value)
852        elif params and name in params:
853            self._set_attr_for_params(name, value)
854        elif cells and name in cells:
855            if value is not None:
856                raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
857            self._cells[name] = None
858        else:
859            if isinstance(value, Primitive):
860                value.set_prim_instance_name(name)
861                self._primitives[name] = value
862            object.__setattr__(self, name, value)
863        if name not in Cell.IGNORE_LIST:
864            self._attr_synced = False
865
866    def extend_repr(self):
867        """
868        Expand the description of Cell.
869
870        To print customized extended information, re-implement this method in your own cells.
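
        A minimal sketch of overriding this method (illustrative only):

        Examples:
            >>> import mindspore.nn as nn
            >>> class MyCell(nn.Cell):
            ...     def __init__(self):
            ...         super(MyCell, self).__init__()
            ...     def extend_repr(self):
            ...         return 'my_info=1'
            >>> print(MyCell())
            MyCell<my_info=1>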
871        """
872        return ''
873
874    def __str__(self):
875        return self.__repr__()
876
877    def __repr__(self):
878        extra_str = self.extend_repr()
879        info_str = self.__class__.__name__ + '<'
880        if self._cells:
881            sub_str = '\n'
882            if extra_str:
883                sub_str += '{}\n'.format(self.extend_repr())
884            for key, value in self._cells.items():
885                sub_str += '({}): {}\n'.format(key, repr(value))
886            sub_str = sub_str.replace('\n', '\n  ') + '>'
887            info_str += sub_str
888        else:
889            info_str += extra_str + '>'
890        return info_str
891
892    def load_parameter_slice(self, params):
893        """
894        Replace parameters with sliced tensors by parallel strategies.
895
896        Note:
897            This interface is deprecated.
898        """
899        logger.warning("'load_parameter_slice' function is deprecated.")
900
901    def set_parallel_input_with_inputs(self, *inputs):
902        """
903        Slice inputs tensors by parallel strategies.
904
905        Note:
906            This interface is deprecated.
907        """
908        logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
909
910    def set_inputs(self, *inputs, **kwargs):
911        """
        Save the set inputs for the computation graph. The number of inputs should be the same as that of the dataset. When
        using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
        configured with set_inputs. The shape of the input Tensor can be either dynamic or static.

        .. note::
            There are two modes:

            - Full mode: the arguments are used as all of the compile inputs for graph compiling.
            - Incremental mode: the arguments are set to some of the Cell inputs and are substituted into the inputs
              at the corresponding positions for graph compiling.

            Only one of inputs or kwargs can be set: inputs for full mode, kwargs for incremental mode.
924
925        Args:
926            inputs (tuple): Full mode arguments.
927            kwargs (dict): Incremental mode arguments. The acceptable key is the name of parameter defined
928                in `self.construct`.
929
930        .. warning::
931            This is an experimental API that is subject to change or deletion.
932
933        Examples:
934            >>> import numpy as np
935            >>> import mindspore as ms
936            >>> from mindspore import nn, Tensor
937            >>>
938            >>> class ReluNet(nn.Cell):
939            ...     def __init__(self):
940            ...         super(ReluNet, self).__init__()
941            ...         self.relu = nn.ReLU()
942            ...     def construct(self, x):
943            ...         return self.relu(x)
944            >>>
945            >>> net = ReluNet()
946            >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
947            >>> net.set_inputs(input_dyn)
948            >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
949            >>> output = net(input)
950            >>>
951            >>> net2 = ReluNet()
952            >>> net2.set_inputs(x=input_dyn)
953            >>> output = net2(input)
954        """
955        if self.grad_ops_label:
956            logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
957                           f'generated.')
958        if kwargs and inputs:
959            raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!'
960                             % (inputs, kwargs))
961
962        if not kwargs:
963            self._dynamic_shape_inputs = inputs
964            self._check_construct_args(*inputs)
965            if context._get_mode() == context.PYNATIVE_MODE:
966                _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
967        else:
968            self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
969
970    def get_inputs(self):
971        """
972        Returns the dynamic_inputs of a cell object in one network.
973
974        Returns:
975            inputs (tuple), Inputs of the Cell object.
976
977        .. warning::
978            This is an experimental API that is subject to change or deletion.
979
980        Examples:
981            >>> import numpy as np
982            >>> import mindspore as ms
983            >>> from mindspore import nn, Tensor
984            >>>
985            >>> class ReluNet(nn.Cell):
986            ...     def __init__(self):
987            ...         super(ReluNet, self).__init__()
988            ...         self.relu = nn.ReLU()
989            ...     def construct(self, x):
990            ...         return self.relu(x)
991            >>>
992            >>> net = ReluNet()
993            >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
994            >>> net.set_inputs(input_dyn)
995            >>> get_inputs = net.get_inputs()
996            >>> print(get_inputs)
997            (Tensor(shape=[3, -1], dtype=Float32, value= ),)
998
999        """
1000
1001        return self._dynamic_shape_inputs
1002
1003    def _check_parameter_consistency(self, set_inputs, net_inputs):
1004        """Check consistency for parameter."""
1005        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
1006            if isinstance(set_input, Tensor):
1007                if not isinstance(net_input, Tensor):
1008                    raise TypeError(
                        f"For 'set_inputs' and tuple(list) in 'set_inputs', the type of the {index + 1}th input must "
1010                        f"be Tensor, but got {type(net_input)}.")
1011                if isinstance(set_input, Parameter) != isinstance(net_input, Parameter):
1012                    raise TypeError(
1013                        f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
1014                        f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.")
1015            elif isinstance(set_input, (tuple, list)):
1016                if not isinstance(net_input, (tuple, list)):
1017                    raise TypeError(
1018                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
1019                        f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
1020                self._check_parameter_consistency(set_input, net_input)
1021
1022    def _get_compile_args(self, args):
1023        """Get compile arguments."""
1024        # this is used only for test
1025        set_by_auto_dynamic = False
1026        if is_auto_dynamic():
1027            if self._dynamic_shape_inputs is None:
1028                set_by_auto_dynamic = True
1029            else:
1030                if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None:
1031                    set_by_auto_dynamic = True
1032        if set_by_auto_dynamic:
1033            self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
1034
1035        if self._dynamic_shape_inputs is not None:
1036            logger.debug("Compiled Graph with dynamic shape")
1037            compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs)
1038            _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
1039            self._check_parameter_consistency(compile_args, args)
1040            Validator.check_symbolic_shape(compile_args, args)
1041            self.saved_dynamic_shape = compile_args
1042            return compile_args
1043        return args
1044
1045    def compile(self, *args, **kwargs):
1046        """
        Compile the Cell into a computation graph; the inputs must be consistent with the inputs defined in construct.
1048
1049        Args:
1050            args (tuple): Args of the Cell object.
1051            kwargs (dict): Kwargs of the Cell object.
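
        A minimal usage sketch (illustrative only, assuming graph mode):

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> net = nn.Dense(2, 2)
            >>> net.compile(Tensor(np.ones((2, 2)), ms.float32))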
1052        """
1053        self._compile_args = self._get_compile_args(args)
1054        _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
1055                                     jit_config_dict=self._jit_config_dict, **kwargs)
1056
1057    def compile_and_run(self, *args, **kwargs):
1058        """
1059        Compile and run Cell, the input must be consistent with the input defined in construct.
1060
1061        Note:
1062            It is not recommended to call directly.
1063
1064        Args:
1065            args (tuple): Args of the Cell object.
1066            kwargs (dict): Kwargs of the Cell object.
1067
1068        Returns:
1069            Object, the result of executing.
1070        """
1071        self.compile(*args, **kwargs)
1072        self.add_flags(ge_sync_data=False)
1073        new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
1074        return _cell_graph_executor(self, *new_args, phase=self.phase)
1075
1076    def auto_parallel_compile_and_run(self):
1077        """
1078        Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
1079
1080        Note:
1081            This interface is deprecated.
1082        """
1083        logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
1084
1085    def exec_checkpoint_graph(self):
1086        """Executes GE saving checkpoint graph operation."""
1087        logger.warning("'exec_checkpoint_graph' function is deprecated.")
1088        self.add_flags(ge_sync_data=True)
1089        _cell_graph_executor(self, phase='save')
1090
1091    def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
1092        """
1093        Adds a parameter to the current cell.
1094
1095        Inserts a parameter with given name to the cell. The method is currently used in
1096        `mindspore.nn.Cell.__setattr__`.
1097
1098        Args:
1099            param_name (str): Name of the parameter.
1100            param (Parameter): Parameter to be inserted to the cell.
1101            check_name_contain_dot (bool): Determines whether the name input is compatible. Default: ``True`` .
1102
1103        Raises:
1104            KeyError: If the name of parameter is null or contains dot.
1105            TypeError: If the type of parameter is not Parameter.
1106
1107        Examples:
1108            >>> import mindspore as ms
1109            >>> from mindspore import Tensor, nn, Parameter
1110            ...
1111            >>> class Net(nn.Cell):
1112            ...     def __init__(self):
1113            ...         super(Net, self).__init__()
1114            ...         self.relu = nn.ReLU()
1115            ...
1116            ...     def construct(self, x):
1117            ...         x = self.relu(x)
1118            ...         return x
1119            >>> net = Net()
1120            >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
1121            >>> print(net.bias)
1122            Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
1123        """
        if not param_name:
            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
        if check_name_contain_dot and '.' in param_name:
            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain '.'")
        if '_params' not in self.__dict__:
            raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() first.")
        if hasattr(self, param_name) and param_name not in self._params:
            raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network. "
                           f"Cannot insert another parameter with the same name.")
1133        if not isinstance(param, Parameter) and param is not None:
1134            raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
1135                            f"but got {type(param)}.")
1136        if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
1137            param.name = param_name
1138        self._params[param_name] = param
1139
1140    def cast_param(self, param):
1141        """
        Cast a parameter according to the auto mixed precision level in PyNative mode.

        This interface is currently used for auto mixed precision and usually does not need to be used explicitly.
1145
1146        Args:
1147            param (Parameter): Parameters, the type of which should be cast.
1148
1149        Returns:
1150            Parameter, the input parameter with type automatically cast.
1151        """
1152        msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
1153        logger.warning(msg)
1154        mixed_type = self.get_mixed_precision_type()
1155        if mixed_type != MixedPrecisionType.NOTSET:
1156            if mixed_type == MixedPrecisionType.FP32:
1157                param.set_cast_dtype(mstype.float32)
1158            elif mixed_type == MixedPrecisionType.FP16:
1159                param.set_cast_dtype(mstype.float16)
1160        elif hasattr(param, "set_cast_dtype"):
1161            # retest dtype
1162            param.set_cast_dtype()
1163        return param
1164
1165    def insert_child_to_cell(self, child_name, child_cell):
1166        """
1167        Adds a child cell to the current cell with a given name.
1168
1169        Args:
1170            child_name (str): Name of the child cell.
1171            child_cell (Cell): The child cell to be inserted.
1172
1173        Raises:
1174            KeyError: Child Cell's name is incorrect or duplicated with the other child name.
1175            TypeError: If type of `child_name` is not str.
1176            TypeError: Child Cell's type is incorrect.
1177
1178        Examples:
1179            >>> import mindspore as ms
1180            >>> from mindspore import Tensor, nn
1181            ...
1182            >>> net1 = nn.ReLU()
1183            >>> net2 = nn.Dense(2, 2)
1184            >>> net1.insert_child_to_cell("child", net2)
1185            >>> print(net1)
1186            ReLU<
1187              (child): Dense<input_channels=2, output_channels=2, has_bias=True>
1188              >
1189        """
1190        if not isinstance(child_name, str):
1191            raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
1192                            f"but got {type(child_name)}.")
        if not child_name or '.' in child_name:
            raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' cannot be None and "
                           "cannot contain '.'")
        if hasattr(self, child_name) and child_name not in self._cells:
            raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network. "
                           f"Cannot insert another child cell with the same name.")
1199        if not isinstance(child_cell, Cell) and child_cell is not None:
1200            raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, "
1201                            f"but got type {type(child_cell)}.")
1202        self._cells[child_name] = child_cell
1203
1204    def construct(self, *args, **kwargs):
1205        """
1206        Defines the computation to be performed. This method must be overridden by all subclasses.
1207
1208        Note:
            It is currently not supported for the inputs to contain both tuple and non-tuple types at the same time.
1210
1211        Args:
1212            args (tuple): Tuple of variable parameters.
1213            kwargs (dict): Dictionary of variable keyword parameters.
1214
1215        Returns:
1216            Tensor, returns the computed result.
1217        """
1218        raise AttributeError("For 'Cell', the method 'construct' is not defined.")
1219
1220    def remove_redundant_parameters(self):
1221        """
1222        Remove the redundant parameters.
1223
        This interface usually does not need to be used explicitly.
1225        """
1226        cells = self.cells_and_names()
1227        for _, cell in cells:
1228            params = cell._params.items()
1229            for param_name, param in list(params):
1230                if param.name not in self.parallel_parameter_name_list:
1231                    cell._params.pop(param_name)
1232                    logger.info("remove the redundant parameter: %s", param.name)
1233                    continue
1234            cell_dict = cell.__dict__
1235            for key in cell_dict:
1236                if isinstance(cell_dict[key], ParameterTuple):
1237                    param_tuple = cell_dict[key]
1238                    new_param_tuple = []
1239                    for param in param_tuple:
1240                        if param.name not in self.parallel_parameter_name_list:
1241                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
1242                            continue
1243                        new_param_tuple.append(param)
1244                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
1245
1246    def init_parameters_data(self, auto_parallel_mode=False):
1247        """
1248        Initialize all parameters and replace the original saved parameters in cell.
1249
1250        Note:
1251            trainable_params() and other similar interfaces may return different parameter instances after
1252            `init_parameters_data`; do not cache these results.
1253
1254        Args:
1255            auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
1256
1257        Returns:
1258            Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
1259
1260        Examples:
1261            >>> import mindspore as ms
1262            >>> from mindspore import Tensor, nn
1263            ...
1264            >>> class Net(nn.Cell):
1265            ...     def __init__(self):
1266            ...         super(Net, self).__init__()
1267            ...         self.dense = nn.Dense(2, 2)
1268            ...
1269            ...     def construct(self, x):
1270            ...         x = self.dense(x)
1271            ...         return x
1272            >>> net = Net()
1273            >>> print(net.init_parameters_data())
1274            {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
1275             Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
1276             Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
1277             Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
1278        """
1279        replace = dict()
1280
1281        def _update(param):
1282            if param in replace:
1283                return replace.get(param)
1284            new_p = param.init_data(None, set_sliced=False)
1285            replace[param] = new_p
1286            return new_p
1287
1288        # replace all original usage.
1289        cells = self.cells_and_names()
1290        for _, cell in cells:
1291            params = cell._params.items()
1292            for param_name, param in params:
1293                if not auto_parallel_mode:
1294                    cell._params[param_name] = _update(param)
1295                    continue
1296                if param.name in self.parallel_parameter_name_list:
1297                    cell._params[param_name] = _update(param)
1298            cell_dict = cell.__dict__
1299            for key in cell_dict:
1300                if isinstance(cell_dict[key], ParameterTuple):
1301                    param_tuple = cell_dict[key]
1302                    new_param_tuple = []
1303                    for param in param_tuple:
1304                        if not auto_parallel_mode:
1305                            new_param_tuple.append(_update(param))
1306                            continue
1307                        if param.name in self.parallel_parameter_name_list:
1308                            new_param_tuple.append(_update(param))
1309                        else:
1310                            new_param_tuple.append(param)
1311                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
1312        return replace
1313
1314    def parameters_dict(self, recurse=True):
1315        """
1316        Gets the parameters dictionary of this cell.
1317
1318        Args:
1319            recurse (bool): Whether to include the parameters of subcells. Default: ``True`` .
1320
1321        Returns:
1322            OrderedDict, return parameters dictionary.
1323
1324        Examples:
1325            >>> import mindspore as ms
1326            >>> from mindspore import Tensor, nn, Parameter
1327            ...
1328            >>> class Net(nn.Cell):
1329            ...     def __init__(self):
1330            ...         super(Net, self).__init__()
1331            ...         self.dense = nn.Dense(2, 2)
1332            ...
1333            ...     def construct(self, x):
1334            ...         x = self.dense(x)
1335            ...         return x
1336            >>> net = Net()
1337            >>> print(net.parameters_dict())
1338            OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
1339            requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
1340            requires_grad=True))])
1341        """
1342        param_dict = OrderedDict()
1343        for param in self.get_parameters(expand=recurse):
1344            param_dict[param.name] = param
1345        return param_dict
1346
1347    def parameters_broadcast_dict(self, recurse=True):
1348        """
1349        Gets the parameters broadcast dictionary of this cell.
1350
1351        Args:
1352            recurse (bool): Whether to include the parameters of subcells. Default: ``True`` .
1353
1354        Returns:
1355            OrderedDict, return parameters broadcast dictionary.
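
        Examples:
            A minimal sketch (assuming a single ``nn.Dense`` layer whose parameters are not layerwise-parallel):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> broadcast_dict = net.parameters_broadcast_dict()
            >>> print(sorted(broadcast_dict.keys()))
            ['bias', 'weight']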
1356        """
1357        param_dict = OrderedDict()
1358        for param in self.get_parameters(expand=recurse):
1359            if param.layerwise_parallel is False:
1360                param_dict[param.name] = param
1361        if not param_dict:
1362            return None
1363        return param_dict
1364
1365    def update_parameters_name(self, prefix='', recurse=True):
1366        """
1367        Adds the `prefix` string to the names of parameters.
1368
1369        Args:
1370            prefix (str): The prefix string. Default: ``''`` .
1371            recurse (bool): Whether to include the parameters of subcells. Default: ``True`` .
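
        Examples:
            A minimal sketch (the prefix ``'backbone.'`` below is illustrative):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> net.update_parameters_name(prefix='backbone.')
            >>> print(sorted([p.name for p in net.get_parameters()]))
            ['backbone.bias', 'backbone.weight']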
1372        """
1373
1374        Validator.check_str_and_none_by_regular(prefix)
1375        for name, param in self.parameters_and_names(expand=recurse):
1376            if prefix != '':
1377                param.is_init = False
1378            param.name = prefix + name
1379
1380    def _update_local_parameters_name(self, prefix='', recurse=True):
1381        """
1382        Updates the names of local parameters with given prefix string.
1383
1384        Adds the given prefix to the names of local parameters.
1385
1386        Local parameters are the parameters that are not provided by the user.
1387
1388        Args:
1389            prefix (str): The prefix string. Default: ''.
1390            recurse (bool): Whether to include the parameters of subcells. Default: ``True``.
1391        """
1392
1393        Validator.check_str_by_regular(prefix)
1394        for name, param in self.parameters_and_names(expand=recurse):
1395            if name in self._user_parameters:
1396                continue
1397            if prefix != '':
1398                param.is_init = False
1399            param.name = prefix + name
1400
1401    @jit_forbidden_register
1402    def trainable_params(self, recurse=True):
1403        """
1404        Returns all trainable parameters.
1405
1406        Returns a list of all trainable parameters.
1407
1408        Args:
1409            recurse (bool): Whether to include the trainable parameters of subcells. Default: ``True`` .
1410
1411        Returns:
1412            List, the list of trainable parameters.
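
        Examples:
            A minimal sketch (a single ``nn.Dense`` layer holds a trainable weight and bias):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> print(len(net.trainable_params()))
            2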
1413
1414        Tutorial Examples:
1415            - `Model Training - Optimizer
1416              <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
1417        """
1418        return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
1419
1420    @jit_forbidden_register
1421    def untrainable_params(self, recurse=True):
1422        """
1423        Returns all untrainable parameters.
1424
1425        Returns a list of all untrainable parameters.
1426
1427        Args:
1428            recurse (bool): Whether to include the untrainable parameters of subcells. Default: ``True`` .
1429
1430        Returns:
1431            List, the list of untrainable parameters.
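
        Examples:
            A minimal sketch (``nn.BatchNorm2d`` keeps moving statistics that are not trainable):

            >>> from mindspore import nn
            >>> net = nn.BatchNorm2d(3)
            >>> print(len(net.untrainable_params()))
            2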
1432        """
1433        return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
1434
1435    @jit_forbidden_register
1436    def get_parameters(self, expand=True):
1437        """
1438        Returns an iterator over cell parameters.
1439
1440        Yields parameters of this cell. If `expand` is ``True`` , yields parameters of this cell and all subcells.
1441        For more details about subcells, please see the example below.
1442
1443        Args:
1444            expand (bool): If ``True`` , yields parameters of this cell and all subcells. Otherwise, only yields
1445                           parameters that are direct members of this cell. Default: ``True`` .
1446
1447        Returns:
1448            Iteration, all parameters of the cell.
1449
1450        Examples:
1451            >>> import mindspore as ms
1452            >>> from mindspore import nn, ops, Tensor
1453            >>> import numpy as np
1454            >>> class TestNet(nn.Cell):
1455            ...     def __init__(self):
1456            ...         super().__init__()
1457            ...         self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
1458            ...         self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
1459            ...     def construct(self, x):
1460            ...         x += self.my_w1
1461            ...         x = ops.reshape(x, (16,)) - self.my_w2
1462            ...         return x
1463            >>> class TestNet2(nn.Cell):
1464            ...     def __init__(self):
1465            ...         super().__init__()
1466            ...         self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
1467            ...         # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
1468            ...         # also be gathered.
1469            ...         self.subcell = TestNet()
1470            ...     def construct(self, x):
1471            ...         x += self.my_t1
1472            ...         x = self.subcell(x)
1473            ...         return x
1474            >>> net = TestNet2()
1475            >>> print([p for p in net.get_parameters(expand=True)])
1476            [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
1477            shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
1478            requires_grad=True)]
1479        """
1480        for _, param in self.parameters_and_names(expand=expand):
1481            yield param
1482
1483    # pylint: disable=missing-docstring
1484    def check_names_and_refresh_name(self):
1485        if not hasattr(self, "_params"):
1486            return
1487        all_name = [i.name for i in dict(self.parameters_and_names()).values()]
1488        if len(set(all_name)) < len(all_name):
1489            self.update_parameters_name()
1490            self.check_names()
1491
1492    def check_names(self):
1493        """
1494        Check the names of cell parameters.
1495        """
1496        names = set()
1497        for value, param in self.parameters_and_names():
1498            if param.name in names:
1499                raise ValueError("The value of {} is {}, its name '{}' already exists. "
1500                                 "Please set a unique name for the parameter.".format(value, param, param.name))
1501            names.add(param.name)
1502
1503    def parameters_and_names(self, name_prefix='', expand=True):
1504        """
1505        Returns an iterator over cell parameters.
1506
1507        Includes the parameter's name and itself.
1508
1509        Args:
1510            name_prefix (str): Namespace. Default: ``''`` .
1511            expand (bool): If ``True`` , yields parameters of this cell and all subcells. Otherwise, only yields parameters
1512                           that are direct members of this cell. Default: ``True`` .
1513
1514        Returns:
1515            Iteration, all the names and corresponding parameters in the cell.
1516
1517        Examples:
1518            >>> from mindspore import nn
1519            >>> n = nn.Dense(3, 4)
1520            >>> names = []
1521            >>> for m in n.parameters_and_names():
1522            ...     if m[0]:
1523            ...         names.append(m[0])
1524
1525        Tutorial Examples:
1526            - `Building a Network - Model Parameters
1527              <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
1528        """
1529        cells = []
1530        if expand:
1531            cells = self.cells_and_names(name_prefix=name_prefix)
1532        else:
1533            cells.append((name_prefix, self))
1534
1535        params_set = set()
1536        for cell_name, cell in cells:
1537            params = cell._params.items()
1538            for par_name, par in params:
1539                if par is not None and par.inited_param is not None:
1540                    par = par.inited_param
1541                if par is not None and id(par) not in params_set:
1542                    params_set.add(id(par))
1543                    par_new_name = par_name
1544                    if cell_name:
1545                        par_new_name = cell_name + '.' + par_new_name
1546
1547                    yield par_new_name, par
1548
1549    def cells_and_names(self, cells=None, name_prefix=''):
1550        """
1551        Returns an iterator over all cells in the network, including the cell's name and itself.
1552
1553        Args:
1554            cells (set): A set of cells that have already been traversed, used internally to avoid visiting a cell twice. Default: ``None`` .
1555            name_prefix (str): Namespace. Default: ``''`` .
1556
1557        Returns:
1558            Iteration, all the child cells and corresponding names in the cell.
1559
1560        Examples:
1561            >>> from mindspore import nn
1562            >>> class Net(nn.Cell):
1563            ...     def __init__(self):
1564            ...         super(Net, self).__init__()
1565            ...         self.conv = nn.Conv2d(3, 64, 3)
1566            ...     def construct(self, x):
1567            ...         out = self.conv(x)
1568            ...         return out
1569            >>> names = []
1570            >>> n = Net()
1571            >>> for m in n.cells_and_names():
1572            ...     if m[0]:
1573            ...         names.append(m[0])
1574        """
1575        t_cells = cells if cells else set()
1576        if self in t_cells:
1577            return
1578
1579        t_cells.add(self)
1580        yield name_prefix, self
1581
1582        for name, cell in self._cells.items():
1583            if cell:
1584                cells_name_prefix = name
1585                if name_prefix:
1586                    cells_name_prefix = name_prefix + '.' + cells_name_prefix
1587                for ele in cell.cells_and_names(t_cells, cells_name_prefix):
1588                    yield ele
1589
1590    def cells(self):
1591        """
1592        Returns an iterator over immediate cells.
1593
1594        Returns:
1595            Iteration, the immediate cells in the cell.
1596
1597        Examples:
1598            >>> import mindspore as ms
1599            >>> from mindspore import Tensor, nn
1600            ...
1601            >>> class Net(nn.Cell):
1602            ...     def __init__(self):
1603            ...         super(Net, self).__init__()
1604            ...         self.dense = nn.Dense(2, 2)
1605            ...
1606            ...     def construct(self, x):
1607            ...         x = self.dense(x)
1608            ...         return x
1609            >>> net = Net()
1610            >>> print(net.cells())
1611            odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
1612        """
1613        return self.name_cells().values()
1614
1615    def _set_scope(self, name):
1616        """Sets the name on the first time."""
1617        if self._scope is None:
1618            self._scope = name
1619        elif self._scope == 'recompute_':
1620            self._scope = self._scope + name
1621
1622    def _children_scope_recursive(self, parent_prefix='Default'):
1623        """Generates the scope of each layer of the network recursively."""
1624        reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope")
1625
1626        for name, cell in self.name_cells().items():
1627            class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else ""
1628            yield parent_prefix + "/" + name + class_name, cell
1629
1630        for name, cell in self.name_cells().items():
1631            class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else ""
1632            for key, value in cell._children_scope_recursive(parent_prefix + "/" + name + class_name):
1633                yield key, value
1634
1635    def get_scope(self):
1636        """
1637        Returns the scope of a cell object in one network.
1638
1639        Returns:
1640            String, scope of the cell.
1641        """
1642        return self._scope
1643
1644    def generate_scope(self):
1645        """Generate the scope for each cell object in the network."""
1646        for name, cell in self._children_scope_recursive():
1647            cell._set_scope(name)
1648
1649    def name_cells(self):
1650        """
1651        Returns an iterator over all immediate cells in the network.
1652
1653        Includes the name of the cell and the cell itself.
1654
1655        Returns:
1656            Dict, all the child cells and corresponding names in the cell.
1657
1658        Examples:
1659            >>> import mindspore as ms
1660            >>> from mindspore import Tensor, nn
1661            ...
1662            >>> class Net(nn.Cell):
1663            ...     def __init__(self):
1664            ...         super(Net, self).__init__()
1665            ...         self.dense = nn.Dense(2, 2)
1666            ...
1667            ...     def construct(self, x):
1668            ...         x = self.dense(x)
1669            ...         return x
1670            >>> net = Net()
1671            >>> print(net.name_cells())
1672            OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
1673        """
1674        value_set = set()
1675        cells = OrderedDict()
1676        for name, cell in self._cells.items():
1677            if cell is not None and cell not in value_set:
1678                value_set.add(cell)
1679                cells[name] = cell
1680        return cells
1681
1682    def _add_mixed_precision_flag(self, **flags):
1683        """Add mixed precision flag to current cell"""
1684        if "fp16" in flags and flags.get("fp16", False):
1685            Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
1686        if "fp32" in flags and flags.get("fp32", False):
1687            Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
1688        if "bf16" in flags and flags.get("bf16", False):
1689            Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)
1690
1691    def apply(self, fn):
1692        """
1693        Applies fn recursively to every subcell (as returned by .cells()) as well as self.
1694        Typical use includes initializing the parameters of a model.
1695
1696        Args:
1697            fn (function): function to be applied to each subcell.
1698
1699        Returns:
1700            Cell, self.
1701
1702        Examples:
1703            >>> import mindspore.nn as nn
1704            >>> from mindspore.common.initializer import initializer, One
1705            >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
1706            >>> def func(cell):
1707            ...     if isinstance(cell, nn.Dense):
1708            ...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
1709            >>> net.apply(func)
1710            SequentialCell<
1711              (0): Dense<input_channels=2, output_channels=2, has_bias=True>
1712              (1): Dense<input_channels=2, output_channels=2, has_bias=True>
1713              >
1714            >>> print(net[0].weight.asnumpy())
1715            [[1. 1.]
1716             [1. 1.]]
1717        """
1718        for cell in self.cells():
1719            cell.apply(fn)
1720        fn(self)
1721        return self
1722
1723    def add_flags(self, **flags):
1724        """
1725        Add customized attributes for cell.
1726
1727        This method is also called when the cell class is instantiated and the class argument 'flags' is provided.
1728
1729        Args:
1730            flags (dict): Network configuration information, currently it is used for the binding of network and
1731                dataset. Users can also customize network attributes by this parameter.
1732
1733        Examples:
1734            >>> import mindspore as ms
1735            >>> from mindspore import Tensor, nn
1736            ...
1737            >>> class Net(nn.Cell):
1738            ...     def __init__(self):
1739            ...         super(Net, self).__init__()
1740            ...         self.relu = nn.ReLU()
1741            ...
1742            ...     def construct(self, x):
1743            ...         x = self.relu(x)
1744            ...         return x
1745            >>> net = Net()
1746            >>> net.add_flags(sink_mode=True)
1747            >>> print(net.sink_mode)
1748            True
1749        """
1750        if not hasattr(self, "_func_graph_flags"):
1751            self._func_graph_flags = {}
1752        self._func_graph_flags.update({**flags})
1753        if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
1754            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
1755                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
1756        self.__dict__.update({**flags})
1757        self._add_mixed_precision_flag(**flags)
1758        return self
1759
1760    def add_flags_recursive(self, **flags):
1761        """
1762        If a cell contains child cells, this method can recursively customize attributes of all cells.
1763
1764        Args:
1765            flags (dict): Network configuration information, currently it is used for the binding of network and
1766                dataset. Users can also customize network attributes by this parameter.
1767
1768        Examples:
1769            >>> import mindspore as ms
1770            >>> from mindspore import Tensor, nn
1771            ...
1772            >>> class Net(nn.Cell):
1773            ...     def __init__(self):
1774            ...         super(Net, self).__init__()
1775            ...         self.relu = nn.ReLU()
1776            ...
1777            ...     def construct(self, x):
1778            ...         x = self.relu(x)
1779            ...         return x
1780            >>> net = Net()
1781            >>> net.add_flags_recursive(sink_mode=True)
1782            >>> print(net.sink_mode)
1783            True
1784        """
1785        self.add_flags(**flags)
1786        for cell in self.cells():
1787            cell.add_flags_recursive(**flags)
1788        return self
1789
1790    def _add_init_args(self, **args):
1791        if hasattr(self, '_cell_init_args'):
1792            self._cell_init_args += str({**args})
1793
1794    def get_flags(self):
1795        """
1796        Get the self-defined attributes of the cell, which can be added by the `add_flags` method.
1797
1798        Examples:
1799            >>> import mindspore as ms
1800            >>> from mindspore import Tensor, nn
1801            ...
1802            >>> class Net(nn.Cell):
1803            ...     def __init__(self):
1804            ...         super(Net, self).__init__()
1805            ...         self.relu = nn.ReLU()
1806            ...
1807            ...     def construct(self, x):
1808            ...         x = self.relu(x)
1809            ...         return x
1810            >>> net = Net()
1811            >>> net.add_flags(sink_mode=True)
1812            >>> print(net.get_flags())
1813            {'sink_mode': True}
1814        """
1815        if not hasattr(self, "_func_graph_flags"):
1816            self._func_graph_flags = {}
1817        return self._func_graph_flags
1818
1819    def to_float(self, dst_type):
1820        """
1821        Add cast on all inputs of the cell and its child cells to run with a certain float type.
1822
1823        If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will
1824        be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`.
1825
1826        Note:
1827            Multiple calls will overwrite the previous setting.
1828
1829        Args:
1830            dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
1831                dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.
1832
1833        Returns:
1834            Cell, the cell itself.
1835
1836        Raises:
1837            ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.
1838
1839        Supported Platforms:
1840            ``Ascend`` ``GPU`` ``CPU``
1841
1842        Examples:
1843            >>> import mindspore.nn as nn
1844            >>> from mindspore import dtype as mstype
1845            >>>
1846            >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
1847            >>> net.to_float(mstype.float16)
1848            Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
1849            padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
1850        """
1851        if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
1852            raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
1853                             "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
1854        flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
1855                 'bf16': dst_type == mstype.bfloat16}
1856        self._add_init_args(**flags)
1857        self.add_flags_recursive(**flags)
1858        return self
1859
1860    def set_boost(self, boost_type):
1861        """
1862        In order to improve network performance, configure the network to automatically enable
1863        acceleration algorithms from the algorithm library.
1864
1865        If `boost_type` is not in the algorithm library, please check the available algorithms in the
1866        `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
1867
1868        Note:
1869            Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
1870
1871        Args:
1872            boost_type (str): The acceleration algorithm to enable.
1873
1874        Returns:
1875            Cell, the cell itself.
1876
1877        Raises:
1878            ValueError: If boost_type is not in the algorithm library.
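
        Examples:
            A minimal sketch enabling the ``less_bn`` algorithm on an illustrative layer:

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> net = net.set_boost("less_bn")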
1879        """
1880        if boost_type not in ("less_bn",):
1881            raise ValueError("For 'set_boost', the argument 'boost_type' must be 'less_bn', "
1882                             "but got {}.".format(boost_type))
1883        flags = {"less_bn": boost_type == "less_bn"}
1884        self.add_flags_recursive(**flags)
1885        return self
1886
1887    def set_grad(self, requires_grad=True):
1888        """
1889        Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
1890        gradients. If ``True`` , the backward network needed to compute the gradients will be generated when the forward
1891        network is executed.
1892
1893        Args:
1894            requires_grad (bool): Specifies whether the network needs gradients. If it is
1895                ``True`` , the cell will construct the backward network in pynative mode. Default: ``True`` .
1896
1897        Returns:
1898            Cell, the cell itself.
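
        Examples:
            A minimal sketch (typically used together with gradient operations in PyNative mode):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> net = net.set_grad()
            >>> print(net.requires_grad)
            True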
1899        """
1900        self.requires_grad = requires_grad
1901        return self
1902
1903    def set_train(self, mode=True):
1904        """
1905        Sets the cell to training mode.
1906
1907        The cell itself and all children cells will be set to training mode. Layers that have different constructions
1908        for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
1909        set to ``True`` , the training branch will be executed, otherwise the other branch.
1910
1911        Note:
1912            When `Model.train()` is executed, the framework will call `Cell.set_train(True)`.
1913            When `Model.eval()` is executed, the framework will call `Cell.set_train(False)`.
1914
1915        Args:
1916            mode (bool): Specifies whether the model is training. Default: ``True`` .
1917
1918        Returns:
1919            Cell, the cell itself.
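
        Examples:
            A minimal sketch toggling training mode on an illustrative layer:

            >>> from mindspore import nn
            >>> net = nn.BatchNorm2d(3)
            >>> net = net.set_train()
            >>> print(net.training)
            True
            >>> net = net.set_train(False)
            >>> print(net.training)
            False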
1920
1921        Tutorial Examples:
1922            - `Model Training - Implementing Training and Evaluation
1923              <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
1924        """
1925        if mode:
1926            self._phase = 'train'
1927        else:
1928            self._phase = 'predict'
1929        self.add_flags_recursive(training=mode)
1930        return self
1931
1932    def set_broadcast_flag(self, mode=True):
1933        """
1934        Set parameter broadcast mode for this cell.
1935
1936        Args:
1937            mode (bool): Specifies whether the mode is parameter broadcast. Default: ``True`` .
1938        """
1939        self.add_flags_recursive(broadcast_flag=mode)
1940        return self
1941
1942    def set_auto_parallel(self):
1943        """
1944        Set the cell to auto parallel mode.
1945
1946        Note:
1947            This interface is deprecated.
1948        """
1949        logger.warning("'set_auto_parallel' function is deprecated.")
1950
1951    def set_jit_config(self, jit_config):
1952        """
1953        Set jit config for cell.
1954
1955        Args:
1956            jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
1957
1958        Examples:
1959            >>> import mindspore as ms
1960            >>> from mindspore import Tensor, nn
1961            ...
1962            >>> class Net(nn.Cell):
1963            ...     def __init__(self):
1964            ...         super(Net, self).__init__()
1965            ...         self.relu = nn.ReLU()
1966            ...
1967            ...     def construct(self, x):
1968            ...         x = self.relu(x)
1969            ...         return x
1970            >>> net = Net()
1971            >>> jitconfig = ms.JitConfig()
1972            >>> net.set_jit_config(jitconfig)
1973        """
1974        if self._jit_config_dict:
1975            logger.warning("For Cell, jit config can only be set once, ignore this setting.")
1976        else:
1977            self._jit_config_dict = jit_config.jit_config_dict
1978
1979    def flatten_weights(self, fusion_size=0):
1980        """
1981        Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type.
1982
1983        Note:
1984            By default, parameters with the same data type will use a single contiguous memory chunk. But for
1985            some models with a huge number of parameters, splitting a large memory chunk into several smaller
1986            memory chunks has the potential for performance gains. If this is the case, 'fusion_size' can be used
1987            to limit the maximum memory chunk size.
1988
1989        Args:
1990            fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
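
        Examples:
            A minimal sketch (the 64MB chunk size below is an illustrative value, not a recommendation):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> net.flatten_weights(fusion_size=64 * 1024 * 1024)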
1991        """
1992        if fusion_size < 0:
1993            raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
1994        Tensor._flatten_tensors(self.trainable_params(), fusion_size)  # pylint: disable=W0212
1995
1996    def register_forward_pre_hook(self, hook_fn):
1997        """
1998        Register forward pre hook function for Cell object.
1999
2000        Note:
2001            - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2002            - 'hook_fn' must be defined as the following code.
2003              `cell` is the object of registered Cell. `inputs` is the forward
2004              input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
2005              forward input objects.
2006            - It should have the following signature:
2007              hook_fn(cell, inputs) -> new input objects or none.
2008            - In order to prevent failures when switching to graph mode, it is not recommended to call it in the
2009              `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
2010              called in the `construct` function of the Cell object, a hook function will be added at each run time of
2011              Cell object.
2012
2013        Args:
2014            hook_fn (function): Python function. Forward pre hook function.
2015
2016        Returns:
2017            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2018            `handle.remove()` .
2019
2020        Raises:
2021            TypeError: If the `hook_fn` is not a function of python.
2022
2023        Supported Platforms:
2024            ``Ascend`` ``GPU`` ``CPU``
2025
2026        Examples:
2027            >>> import numpy as np
2028            >>> import mindspore as ms
2029            >>> from mindspore import Tensor, nn, ops
2030            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2031            >>> def forward_pre_hook_fn(cell, inputs):
2032            ...     print("forward inputs: ", inputs)
2033            ...
2034            >>> class Net(nn.Cell):
2035            ...     def __init__(self):
2036            ...         super(Net, self).__init__()
2037            ...         self.mul = nn.MatMul()
2038            ...         self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn)
2039            ...
2040            ...     def construct(self, x, y):
2041            ...         x = x + x
2042            ...         x = self.mul(x, y)
2043            ...         return x
2044            >>> grad = ops.GradOperation(get_all=True)
2045            >>> net = Net()
2046            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
2047            forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
2048                            dtype=Float32, value= [ 1.00000000e+00]))
2049            >>> print(output)
2050            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2051            value= [ 2.00000000e+00]))
2052        """
2053        if not check_hook_fn("register_forward_pre_hook", hook_fn):
2054            return HookHandle()
2055        self._enable_forward_pre_hook = True
2056        _pynative_executor.set_hook_changed(self)
2057        if not hasattr(self, '_forward_pre_hook_key'):
2058            self._forward_pre_hook_key = -1
2059        self._forward_pre_hook_key += 1
2060        self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
2061        handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
2062        return handle
2063
2064    def _run_forward_pre_hook(self, inputs):
2065        """
2066        Running forward pre hook function registered on Cell object.
2067
2068        Args:
2069            inputs: The input objects of cell object.
2070
2071        Returns:
2072            - **outputs** - New input objects or none.
2073
2074        Supported Platforms:
2075            ``Ascend`` ``GPU`` ``CPU``
2076        """
2077        for fn in self._forward_pre_hook.values():
2078            ret = fn(self, inputs)
2079            if ret is not None:
2080                if not isinstance(ret, tuple):
2081                    inputs = (ret,)
2082                else:
2083                    inputs = ret
2084        return inputs
2085
2086    def register_forward_hook(self, hook_fn):
2087        """
2088        Register the forward hook function for the Cell object.
2089
2090        Note:
2091            - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2092            - 'hook_fn' must be defined as the following code.
2093              `cell` is the object of registered Cell. `inputs` is the forward
2094              input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
2095              modify the forward output object by returning new forward output object.
2096            - It should have the following signature:
2097              hook_fn(cell, inputs, output) -> new output object or none.
2098            - In order to prevent failures when switching to graph mode, it is not recommended to call it in the
2099              `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
2100              called in the `construct` function of the Cell object, a hook function will be added at each run time of
2101              Cell object.
2102
2103        Args:
2104            hook_fn (function): Python function. Forward hook function.
2105
2106        Returns:
2107            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2108            `handle.remove()` .
2109
2110        Raises:
2111            TypeError: If the `hook_fn` is not a function of python.
2112
2113        Supported Platforms:
2114            ``Ascend`` ``GPU`` ``CPU``
2115
2116        Examples:
2117            >>> import numpy as np
2118            >>> import mindspore as ms
2119            >>> from mindspore import Tensor, nn, ops
2120            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2121            >>> def forward_hook_fn(cell, inputs, output):
2122            ...     print("forward inputs: ", inputs)
2123            ...     print("forward output: ", output)
2124            ...
2125            >>> class Net(nn.Cell):
2126            ...     def __init__(self):
2127            ...         super(Net, self).__init__()
2128            ...         self.mul = nn.MatMul()
2129            ...         self.handle = self.mul.register_forward_hook(forward_hook_fn)
2130            ...
2131            ...     def construct(self, x, y):
2132            ...         x = x + x
2133            ...         x = self.mul(x, y)
2134            ...         return x
2135            >>> grad = ops.GradOperation(get_all=True)
2136            >>> net = Net()
2137            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
2138            forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
2139                            dtype=Float32, value= [ 1.00000000e+00]))
2140            forward output: 2.0
2141            >>> print(output)
2142            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2143            value= [ 2.00000000e+00]))
2144        """
2145        if not check_hook_fn("register_forward_hook", hook_fn):
2146            return HookHandle()
2147        self._enable_forward_hook = True
2148        _pynative_executor.set_hook_changed(self)
2149        if not hasattr(self, '_forward_hook_key'):
2150            self._forward_hook_key = -1
2151        self._forward_hook_key += 1
2152        self._forward_hook[self._forward_hook_key] = hook_fn
2153        handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
2154        return handle
2155
2156    def _run_forward_hook(self, inputs, output):
2157        """
2158        Running forward hook function registered on Cell object.
2159
2160        Args:
2161            inputs: The input objects of Cell object.
2162            output: The output object of Cell object.
2163
2164        Returns:
2165            - **output** - New output object or none.
2166
2167        Supported Platforms:
2168            ``Ascend`` ``GPU`` ``CPU``
2169        """
2170        for fn in self._forward_hook.values():
2171            ret = fn(self, inputs, output)
2172            if ret is not None:
2173                output = ret
2174        return output
2175
2176    def register_backward_hook(self, hook_fn):
2177        """
2178        Register the backward hook function.
2179
2180        Note:
2181            - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2182            - The 'hook_fn' must be defined as the following code.
2183              `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
2184              gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
2185              primitive, which may be modified by returning a new output gradient.
2186            - The 'hook_fn' should have the following signature:
2187              hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
2188            - The 'hook_fn' is executed in the python environment. In order to prevent failures when switching to
2189              graph mode, it is not recommended to call it in the `construct` function of Cell object. In the pynative
2190              mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
2191              a hook function will be added at each run time of Cell object.
2192
2193        Args:
2194            hook_fn (function): Python function. Backward hook function.
2195
2196        Returns:
2197            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2198            `handle.remove()` .
2199
2200        Raises:
2201            TypeError: If the `hook_fn` is not a function of python.
2202
2203        Supported Platforms:
2204            ``Ascend`` ``GPU`` ``CPU``
2205
2206        Examples:
2207            >>> import numpy as np
2208            >>> import mindspore as ms
2209            >>> from mindspore import Tensor, nn, ops
2210            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2211            >>> def backward_hook_fn(cell_id, grad_input, grad_output):
2212            ...     print("backward input: ", grad_input)
2213            ...     print("backward output: ", grad_output)
2214            ...
2215            >>> class Net(nn.Cell):
2216            ...     def __init__(self):
2217            ...         super(Net, self).__init__()
2218            ...         self.relu = nn.ReLU()
2219            ...         self.handle = self.relu.register_backward_hook(backward_hook_fn)
2220            ...
2221            ...     def construct(self, x):
2222            ...         x = x + x
2223            ...         x = self.relu(x)
2224            ...         return x
2225            >>> grad = ops.GradOperation(get_all=True)
2226            >>> net = Net()
2227            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
2228            backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
2229            backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
2230            >>> print(output)
2231            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
2232        """
2233        if not check_hook_fn("register_backward_hook", hook_fn):
2234            return HookHandle()
2235        if self._cell_backward_hook is None:
2236            self._enable_backward_hook = True
2237            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
2238            backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
2239            handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
2240        else:
2241            backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
2242            handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
2243        return handle
2244
2245    def _backward_hook_construct(self, *inputs, **kwargs):
2246        """
2247        Backward hook construct method to replace original construct method.
2248
2249        Args:
2250            inputs: The input objects of Cell object.
2251            kwargs (dict): Dictionary of variable keyword parameters.
2252
2253        Returns:
2254            - **outputs** - The output objects of Cell object.
2255
2256        Supported Platforms:
2257            ``Ascend`` ``GPU`` ``CPU``
2258        """
2259        if len(inputs) > 1:
2260            inputs = self._cell_backward_hook(inputs)
2261        else:
2262            inputs = self._cell_backward_hook(*inputs)
2263            inputs = (inputs,)
2264        if self.recompute_cell is not None:
2265            if isinstance(inputs, tuple):
2266                outputs = self.recompute_cell(*inputs, **kwargs)
2267            else:
2268                outputs = self.recompute_cell(inputs, **kwargs)
2269        else:
2270            if isinstance(inputs, tuple):
2271                outputs = self.construct(*inputs, **kwargs)
2272            else:
2273                outputs = self.construct(inputs, **kwargs)
2274        outputs = self._cell_backward_hook(outputs)
2275        return outputs
2276
2277    def set_param_ps(self, recurse=True, init_in_server=False):
2278        """
2279        Set whether the trainable parameters are updated by parameter server and whether the
2280        trainable parameters are initialized on server.
2281
2282        Note:
2283            It only works when a running task is in the parameter server mode.
2284            It is only supported in graph mode.
2285
2286        Args:
2287            recurse (bool): Whether to also set the trainable parameters of subcells. Default: ``True`` .
2288            init_in_server (bool): Whether trainable parameters updated by parameter server are
2289                initialized on server. Default: ``False`` .
2290        """
2291        params = self.trainable_params(recurse)
2292        for param in params:
2293            param.set_param_ps(init_in_server)
2294
2295    @deprecated("1.8", "set_param_fl")
2296    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
2297        params = self.parameters_and_names()
2298        for param in params:
2299            param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
2300
2301    def set_comm_fusion(self, fusion_type, recurse=True):
2302        """
2303        Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
2304        :class:`mindspore.Parameter.comm_fusion`.
2305
2306        Note:
2307            The value of the attribute will be overwritten when the function is called multiple times.
2308
2309        Args:
2310            fusion_type (int): The value of `comm_fusion`.
2311            recurse (bool): Whether to also set the trainable parameters of subcells. Default: ``True`` .
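
        Examples:
            A minimal sketch (the fusion index ``2`` below is illustrative):

            >>> from mindspore import nn
            >>> net = nn.Dense(2, 2)
            >>> net = net.set_comm_fusion(2)
            >>> print(net.weight.comm_fusion)
            2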
2312        """
2313        Validator.check_non_negative_int(fusion_type)
2314        for param in self.trainable_params(recurse):
2315            param.comm_fusion = fusion_type
2316        return self
2317
2318    def _set_recompute_scope(self, mode):
2319        prefix = 'recompute_'
2320        if mode:
2321            if self._scope is None:
2322                self._scope = prefix
2323            elif not self._scope.startswith(prefix):
2324                self._scope = prefix + self._scope
2325        elif self._scope is not None and self._scope.startswith(prefix):
2326            self._scope = self._scope[len(prefix):]
2327
2328    def _mp_comm_recompute(self, mp_comm_recompute=True):
2329        """
2330        Set whether the model parallel communication operators in the cell are recomputed.
2331        """
2332        for _, value in self._primitives.items():
2333            if value:
2334                value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
2335        for cell in self.cells():
2336            cell._mp_comm_recompute(mp_comm_recompute)
2337
2338    def _parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute=False):
2339        """
2340        Set whether the parallel optimizer communication operators in the cell are recomputed.
2341        """
2342        for param in self.trainable_params():
2343            param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
2344
2345    def _recompute_slice_activation(self, slice_activation=False):
2346        """
2347        Slice the cell output which would remain in memory.
2348        """
2349        for _, value in self._primitives.items():
2350            if value:
2351                value.add_prim_attr("slice_activation", slice_activation)
2352        for cell in self.cells():
2353            cell._recompute_slice_activation(slice_activation)
2354
2355    def _recompute(self, mode=True, output_recompute=False):
2356        """
2357        Set the cell recomputed.
2358        """
2359        Validator.check_bool(mode)
2360        Validator.check_bool(output_recompute)
2361        if not self._has_config_recompute:
2362            self._has_config_recompute = True
2363        else:
2364            raise RuntimeError("The recompute interface can be configured only once."
2365                               " When the parent cell is configured, the child cell should not be configured.")
2366        self._set_recompute_scope(mode)
2367        if mode and not output_recompute:
2368            self.add_flags(output_no_recompute=True)
2369        for cell in self.cells():
2370            cell._recompute(mode, True)
2371
2372    @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
2373    def recompute(self, **kwargs):
2374        """
2375        Set the cell recomputed. All the primitives in the cell except the outputs will be set recomputed.
2376        If a primitive set recomputed feeds into some backward nodes for computing gradients, rather than
2377        storing the intermediate activation computed in the forward pass, we will recompute it in the backward pass.
2378
2379        Note:
2380
2381            - If the computation involves something like randomization or global variable, the equivalence
2382              is not guaranteed currently.
2383            - If the recompute api of a primitive in this cell is also called, the recompute mode of this
2384              primitive is subject to the recompute api of the primitive.
2385            - The interface can be configured only once.
2386              Therefore, when the parent cell is configured, the child cell should not be configured.
2387            - The outputs of cell are excluded from recomputation by default, which is based on our configuration
2388              experience to reduce memory footprint. If a cell has only one primitive and the primitive is wanted
2389              to be set recomputed, use the recompute api of the primitive.
2390            - When memory remains sufficient after applying recomputation, configure 'mp_comm_recompute=False'
2391              to improve performance if necessary.
2392            - When memory is still not enough after applying recomputation, configure
2393              'parallel_optimizer_comm_recompute=True' to save more memory if necessary.
2394              Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configuration.
2395
2396        Args:
2397            mp_comm_recompute (bool): Specifies whether the model parallel communication operators
2398                in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True`` .
2399            parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
2400                introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
2401                Default: ``False`` .
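
        Examples:
            A minimal sketch marking an illustrative sub-cell as recomputed in graph mode:

            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.block = nn.Dense(2, 2)
            ...         self.block.recompute()
            ...     def construct(self, x):
            ...         return self.block(x)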
2402        """
2403        if context.get_context("mode") == context.PYNATIVE_MODE:
2404            self.recompute_cell = recompute_registry.get()(self.construct)
2405            return
2406        self._recompute()
2407        if 'mp_comm_recompute' in kwargs.keys():
2408            self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
2409        if 'parallel_optimizer_comm_recompute' in kwargs.keys():
2410            if (kwargs.get('parallel_optimizer_comm_recompute', False) and
2411                    context.get_auto_parallel_context("pipeline_stages") > 1):
2412                logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
2413                               "do not support recomputation in pipeline parallel.")
2414            elif context.get_auto_parallel_context("pipeline_stages") == 1:
2415                self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
2416        if 'recompute_slice_activation' in kwargs:
2417            self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
2418
2419        for key, _ in kwargs.items():
2420            if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
2421                raise ValueError("For 'recompute', keyword '%s' is not recognized! "
2422                                 "The valid keyword arguments are 'mp_comm_recompute', "
2423                                 "'parallel_optimizer_comm_recompute' and 'recompute_slice_activation'." % key)
2424
2425    @deprecated("2.3", "infer_param_pipeline_stage")
2426    def infer_param_pipeline_stage(self):
2427        """
2428        Infer pipeline stages of all parameters in the cell.
2429
2430        Note:
2431            - The interface is deprecated from version 2.3 and will be removed in a future version.
2432
2433        Returns:
2434            The params belong to current stage in pipeline parallel.
2435
2436        Raises:
2437            RuntimeError: If there is a parameter does not belong to any stage.
2438        """
2439        from mindspore.parallel._utils import _get_global_rank, _get_device_num
2440        logger.warning("This interface may be deleted in the future.")
2441        stage_num = context.get_auto_parallel_context("pipeline_stages")
2442        device_num = _get_device_num()
2443        rank_id = _get_global_rank()
2444        per_stage_devices = device_num // stage_num
2445        current_stage = rank_id // per_stage_devices
2446        params = []
2447        for param in self.trainable_params():
2448            if not param._pipeline_stage_list:  # pylint: disable=W0212
2449                raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
2450                                   "please check whether the cell where the param locates has been set "
2451                                   "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
2452                                   "to add its stage information".format(param.name))
2453            if current_stage in param._pipeline_stage_list:
2454                params.append(param)
2455        return params
2456
2457    def place(self, role, rank_id):
2458        """
2459        Set the label for all operators in this cell.
2460        This label tells the MindSpore compiler on which process this cell should be launched.
2461        Each process's label consists of the input `role` and `rank_id`.
2462        By assigning different labels to different cells, which will then be launched on different processes,
2463        users can launch a distributed training or predicting job.

        Note:
            - This method is effective only after
              `mindspore.communication.init()` is called for dynamic cluster building.

        Args:
            role (str): The role of the process on which this cell will be launched.
                        Only 'MS_WORKER' is supported for now.
            rank_id (int): The rank id of the process on which this cell will be launched.
                           The rank is unique in processes with the same role.

        Examples:
            >>> from mindspore import context
            >>> import mindspore.nn as nn
            >>> context.set_context(mode=context.GRAPH_MODE)
            >>> fc = nn.Dense(2, 3)
            >>> fc.place('MS_WORKER', 0)
        """
        all_ops = self._get_prims_recursively()
        for op in all_ops:
            op.place(role, rank_id)

    def _mixed_precision_cast(self, inputs):
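        """Cast inputs to the configured mixed-precision dtype; return them unchanged when no type is set."""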
        mixed_type = self.get_mixed_precision_type()
        if mixed_type == MixedPrecisionType.NOTSET:
            return inputs
        if mixed_type == MixedPrecisionType.FP16:
            cast_type = mstype.float16
        elif mixed_type == MixedPrecisionType.BF16:
            cast_type = mstype.bfloat16
        else:
            cast_type = mstype.float32
        cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
        return cast_inputs

    def _get_attr_from_cell(self, network):
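        """Copy the jit config and AMP level attributes from another Cell onto this one, when present."""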
        if not isinstance(network, Cell):
            return
        if hasattr(network, "jit_config_dict"):
            self._jit_config_dict = network.jit_config_dict
        if hasattr(network, "_amp_level"):
            self._amp_level = getattr(network, "_amp_level")


class GraphCell(Cell):
    """
    Base class for running the graph loaded from a MindIR.

    This feature is still under development. Currently `GraphCell` does not support modifying the structure of the
    graph, and the input data must have the same shape and type as the data used when the MindIR was exported.

    Args:
        graph (FuncGraph): A compiled graph loaded from MindIR.
        params_init (dict): Parameters to be initialized in the graph.
            The key is the parameter name, whose type is str, and the value is a Tensor or Parameter.
            If a parameter with that name exists in the graph, its value is updated;
            otherwise the entry is ignored. Default: ``None`` .
        obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "Dynamic obfuscation" is
            used for model protection; refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
            a func_graph loaded from a MindIR file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
            provided. `obf_random_seed` should be in (0, 9223372036854775807]. Default: ``None`` .

    Raises:
        TypeError: If the `graph` is not a FuncGraph.
        TypeError: If the `params_init` is not a dict.
        TypeError: If the key of the `params_init` is not a str.
        TypeError: If the value of the `params_init` is neither a Tensor nor a Parameter.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore import context
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> ms.export(net, input, file_name="net", file_format="MINDIR")
        >>> graph = ms.load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]
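        >>> # A params_init sketch (illustrative; the parameter name "weight" is assumed to match
        >>> # the name recorded in the exported MindIR):
        >>> new_weight = ms.Parameter(Tensor(np.zeros([1, 1, 3, 3]).astype(np.float32)))
        >>> net_with_init = nn.GraphCell(graph, params_init={"weight": new_weight})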
    """

    def __init__(self, graph, params_init=None, obf_random_seed=None):
        super(GraphCell, self).__init__(auto_prefix=True)
        if not isinstance(graph, FuncGraph):
            raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
                            f"but got type {type(graph)}.")
        self.graph = graph
        self.obf_random_seed = obf_random_seed
        if obf_random_seed is not None:
            if not isinstance(obf_random_seed, int):
                raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
            int_64_max = 9223372036854775807
            if obf_random_seed <= 0 or obf_random_seed > int_64_max:
                raise ValueError(
                    "'obf_random_seed' must be greater than 0 and less than or equal to the int64 max ({}), "
                    "but got {}.".format(int_64_max, obf_random_seed))
            self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
        params_init = {} if params_init is None else params_init
        if not isinstance(params_init, dict):
            raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
        for name, value in params_init.items():
            if not isinstance(name, str) or not isinstance(value, Tensor):
                raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "
                                "and the value must be Tensor or Parameter, "
                                f"but got the key type: {type(name)}, and the value type: {type(value)}")

        params_dict = update_func_graph_hyper_params(self.graph, params_init)
        for name, param in params_dict.items():
            self._params[name] = param
        _cell_graph_executor.inc_graph_cell_count()

    def construct(self, *inputs):
        return self.graph(*inputs)

    def __call__(self, *args, **kwargs):
        self.phase = "graph_load_from_mindir"
        self._add_attr("graph_load_from_mindir", self.graph)
        if not self.obf_random_seed:
            return self.compile_and_run(*args, **kwargs)
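        # The graph was obfuscated when it was exported; append the branch-control tensor derived
        # from the obfuscation seed so the obfuscated graph takes its correct branches at runtime.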
        append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
        return self.compile_and_run(*args, append_input, **kwargs)


def _check_param_list_tuple(value):
    """
    Check whether every element of a list or tuple is a Parameter.
    :param value: list or tuple to check.
    :return: True if all elements are Parameter, otherwise False.
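    Example (illustrative): _check_param_list_tuple([Parameter(Tensor(1.0), name="w")]) returns True.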
    """
    for item in value:
        if not isinstance(item, Parameter):
            return False
    return True
