# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
from __future__ import absolute_import

import types
import sys
import os
import time
import ast
import inspect
import importlib
import hashlib
import contextlib
from collections import OrderedDict, namedtuple
from functools import wraps
import numpy as np
import mindspore as ms
from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
from mindspore.common.jit_config import JitConfig
from mindspore.common.tensor import Tensor as PythonTensor
from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
    PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
    _ms_memory_recycle, _bind_device_ctx, jit_mode_pi_enable, jit_mode_pi_compile
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
    _is_in_auto_parallel_mode
from mindspore import _checkparam as Validator
from mindspore._checkparam import is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
    get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature

# Store the compiled pipeline cache of functions decorated with ms_function/jit.
ms_compile_cache = set()
# Store the compiled pipeline cache of cells.
cells_compile_cache = {}
# Store the compilation creation-time information of functions.
function_phases = dict()

BROADCAST_PHASE = "_broadcast_"
_PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"

ARG_SPECIFIED = "arg_specified_infos"
TOTAL_ARG_LEN = "total_arg_length"


def _check_recompile_args(compile_args, kwargs):
    """Check whether the compile arguments may cause graph recompilation."""

    def _check_constant_tensor_arg(arg):
        if hasattr(arg, "__ms_mutable__"):
            return False
        if isinstance(arg, (list, tuple)):
            return any(_check_constant_tensor_arg(x) for x in arg)
        return isinstance(arg, Tensor)

    for v in kwargs.values():
        compile_args += (v,)
    for arg in compile_args:
        if not isinstance(arg, (tuple, list)):
            continue
        if _check_constant_tensor_arg(arg):
            logger.warning(f"Constant value tensors are detected in a tuple or list, which might cause recompiling "
                           f"when the tensor values change. You can use mutable(Tensor) or mutable(tuple(Tensor)) "
                           f"to mark the tensor values as variable and avoid recompiling. The tuple or list arg "
                           f"is: {arg} .")
            return

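# A minimal usage sketch (illustrative, not part of the module API): wrapping a tuple of
# Tensors with mutable() keeps their values variable, so changing the values does not
# trigger the recompilation warning emitted above.
#
#     from mindspore import Tensor, jit, mutable
#
#     @jit
#     def add_pair(pair):
#         return pair[0] + pair[1]
#
#     pair = mutable((Tensor(1.0), Tensor(2.0)))   # values may change without recompiling
#     out = add_pair(pair)
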
def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time, echo_function_name):
    """Warn when the function has already been compiled."""
    ignore_dirs = ["mindspore/ops", "mindspore/nn"]
    if any(x in full_function_name for x in ignore_dirs):
        return

    if full_function_name in function_phases:
        warning_times = 1
        if len(function_phases[full_function_name]) >= warning_times \
                and create_time not in function_phases[full_function_name]:
            if isinstance(obj, ms.nn.Cell):
                tips = f"Please try to create the {echo_function_name} instance only once to avoid recompiling."
                logger.info(f"The {echo_function_name} has been compiled again. "
                            f"{tips} ")
            else:
                tips = "Try to decorate the function with @jit(hash_args=...) " \
                       "or @jit(compile_once=True) to reduce the compile time. " \
                       "For more details, get instructions about `jit` at " \
                       "https://www.mindspore.cn/search?inputValue=jit."
                logger.warning(f"The {echo_function_name} has been compiled again. "
                               f"{tips} ")
        else:
            _check_recompile_args(compile_args, kwargs)
    else:
        function_phases[full_function_name] = set()
    function_phases[full_function_name].add(create_time)


def _ms_adapter_tensor_as_parameter_output(data):
    """Check whether the data is an output from a parameter which is an ms_adapter tensor.
       Pylint: disable=unidiomatic-typecheck.
    """
    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")

def _convert_python_data(data):
    """
    Convert C++ data to Python.

    Args:
        data : The data to be converted.

    Returns:
        The data converted from C++ type to Python type.
    """
    if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
        return ms_adapter_registry.tensor(data)
    if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
        return data.tensor
    if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
        return PythonTensor(data, internal=True)
    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
        return PythonCSRTensor(csr_tensor=data)
    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
        return PythonCOOTensor(coo_tensor=data)
    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
        return PythonRowTensor(row_tensor=data)
    if data.__class__ is tuple:
        # Handle namedtuple since its type is tuple.
        if hasattr(data, "_fields"):
            type_name = data.__class__.__name__
            data_dict = data._asdict()
            fields = data_dict.keys()
            return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
        return tuple(_convert_python_data(x) for x in data)
    if data.__class__ is list:
        # Keep the list object unchanged for inplace operations.
        for i in range(len(data)):
            data[i] = _convert_python_data(data[i])
        return data
    if data.__class__ is dict:
        # Keep the dict object unchanged.
        keys = tuple(data.keys())
        for key in keys:
            data[_convert_python_data(key)] = _convert_python_data(data.pop(key))
        return data
    return data

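# A minimal sketch (hypothetical placeholder values) of what _convert_python_data is
# expected to do with nested containers: the container structure is preserved while each
# leaf C++ Tensor is rewrapped as a Python-level mindspore.Tensor.
#
#     from collections import namedtuple
#
#     Point = namedtuple("Point", ["x", "y"])
#     raw = Point(x=some_cpp_tensor_x, y=some_cpp_tensor_y)  # leaves are _c_expression.Tensor
#     converted = _convert_python_data(raw)
#     # converted is still a Point, but converted.x and converted.y are mindspore.Tensor objects.
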
def _wrap_func(fn):
    """
    Wrapper function, convert the return data to a tensor or a tuple of tensors.

    Args:
        fn (Function): The function to be wrapped.

    Returns:
        Function, a new function whose return value is converted to a suitable format.
    """

    @wraps(fn)
    def wrapper(*arg, **kwargs):
        results = fn(*arg, **kwargs)
        return _convert_python_data(results)

    return wrapper


def _check_all_tensor(sequence):
    for element in sequence:
        if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple)
                                                                                    and _check_all_tensor(element)):
            return False
    return True

def _handle_func_args(func, *args, **kwargs):
    """Handle the *args and **kwargs inputs of the function."""
    if not isinstance(func, (types.FunctionType, types.MethodType)):
        raise RuntimeError('fn {} is not a function or method'.format(func))
    if kwargs:
        bound_arguments = inspect.signature(func).bind(*args, **kwargs)
        bound_arguments.apply_defaults()
        args = bound_arguments.args
        kwargs = bound_arguments.kwargs

    positional_args = 0
    default_args = 0
    has_var = False
    for value in inspect.signature(func).parameters.values():
        if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
            has_var = True
        if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
            if value.default is inspect.Parameter.empty:
                positional_args += 1
            else:
                default_args += 1

    if has_var:
        return args, kwargs

    if len(args) < positional_args:
        raise TypeError(f"Function {func.__name__} needs {positional_args} positional arguments, but got {len(args)}.")
    if len(args) > positional_args + default_args:
        raise TypeError(f"Function {func.__name__} needs {positional_args} positional arguments and {default_args} "
                        f"default arguments, {positional_args + default_args} in total, but got {len(args)}.")
    return args, kwargs

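# A small illustration (standard-library behavior) of the binding step above: keyword
# arguments that map to positional parameters are normalized into `args` via
# inspect.signature(...).bind(...).apply_defaults().
#
#     import inspect
#
#     def f(a, b, c=3):
#         return a + b + c
#
#     bound = inspect.signature(f).bind(1, b=2)
#     bound.apply_defaults()
#     # bound.args == (1, 2, 3), bound.kwargs == {}
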
sys_path = list(sys.path)
# Get the entry script path.
entry_script_path = None
if sys.argv and sys.argv[0] != '':
    entry_script_path = os.path.realpath(sys.argv[0])
    entry_script_path_dir = os.path.split(entry_script_path)[0]
    if entry_script_path_dir in sys_path:
        sys_path.remove(entry_script_path_dir)


def _in_sys_path(file_path):
    for path in sys_path:
        if file_path.startswith(path):
            return True
    return False

def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
    """Get the dependency files of the network"""
    with open(file_path) as fh:
        root = ast.parse(fh.read(), file_path)
    for node in ast.iter_child_nodes(root):
        module_name = ""
        if isinstance(node, ast.ImportFrom):
            if node.module is not None:
                module_name = node.module
            module_name = "." * node.level + module_name
        elif not isinstance(node, ast.Import):
            continue
        # Do not care about the files in the mindspore package.
        if module_name.startswith("mindspore"):
            continue

        for n in node.names:
            if n.name.startswith("mindspore"):
                continue
            if module_name == "":
                whole_module = n.name
            else:
                whole_module = module_name
                if n.name is not None:
                    whole_module += "." + n.name
            try:
                module_spec = importlib.util.find_spec(whole_module, pkg)
            except (ModuleNotFoundError, ValueError):
                whole_module = whole_module[0:whole_module.rfind('.')]
                module_spec = importlib.util.find_spec(whole_module, pkg)
            if module_spec is None:
                continue
            module = importlib.util.module_from_spec(module_spec)
            if hasattr(module, '__file__'):
                dep_file_path = module.__file__
            else:
                continue
            # Exclude the installed modules.
            if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
                logger.debug(f"dependent file path: {dep_file_path}")
                compile_cache_dep_files.append(dep_file_path)
                __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)


def _get_compile_cache_dep_files():
    """Get the dependency files of the network"""
    if entry_script_path is None:
        logger.warning("Can not get the entry script file path.")
        return []
    compile_cache_dep_files = []
    logger.debug(f"entry script file path: {entry_script_path}")
    compile_cache_dep_files.append(entry_script_path)
    __get_compile_cache_dep_files(entry_script_path, compile_cache_dep_files, None)
    return compile_cache_dep_files

def _restore_mutable_attr(args_list, compile_args):
    """Restore the mutable attr for every arg."""
    new_compile_args = ()
    for idx, arg in enumerate(args_list):
        if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \
                not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")):
            if hasattr(arg, "__ms_dynamic_len__"):
                new_compile_args += (mutable(compile_args[idx], getattr(arg, "__ms_dynamic_len__")),)
            else:
                new_compile_args += (mutable(compile_args[idx], False),)
        else:
            new_compile_args += (compile_args[idx],)
    return new_compile_args


def _get_parameter_layout():
    graph_executor = GraphExecutor_.get_instance()
    layout = dict()
    for phase in ms_compile_cache:
        layout.update(graph_executor.get_parameter_layout(phase))
    return layout

def _handle_arg(obj, arg, compile_arg):
    """Handle an arg for runtime. If the arg needs to be passed to the runtime, return it; otherwise return None."""
    if isinstance(arg, PythonTensor):
        if arg.has_init:
            arg.init_data()
        if not arg.const_arg:
            return arg
    elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
        return arg
    elif compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and getattr(compile_arg, "__ms_mutable__"):
        # mutable([]) will be eliminated by FuncGraphSpecializer, and an empty list is not supported by the backend.
        if isinstance(arg, list) and not arg:
            return None
        return arg
    elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
        return arg
    elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
            _check_all_tensor(arg):
        return arg
    return None

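# Filtering sketch (illustrative, not part of the API): by the rules above, plain Python
# scalars are treated as compile-time constants and dropped from the runtime inputs
# (unless grad_for_scalar is enabled), while ordinary Tensors are passed through.
#
#     from mindspore import Tensor
#
#     args = (Tensor(1.0), 2)
#     runtime_args = [a for a in (_handle_arg(None, a, None) for a in args) if a is not None]
#     # runtime_args keeps only the Tensor; the int 2 is baked into the compiled graph.
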
def _handle_arg_predict(obj, arg, compile_arg):
    """Handle an arg for runtime. If the arg needs to be passed to the runtime, return it; otherwise return None."""
    if arg is None:
        return None

    if isinstance(arg, (int, float)):
        return None

    if isinstance(arg, (list, tuple)):
        if compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
                getattr(compile_arg, "__ms_mutable__"):
            # mutable([]) will be eliminated by FuncGraphSpecializer, and an empty list is not supported by backend.
            if isinstance(arg, list) and not arg:
                return None
            return arg
        if hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
                _check_all_tensor(arg):
            return arg
        return None
    return arg

def _get_args_for_run(obj, args, kwargs, compile_args):
    """Get the actual input args and kwargs for runtime."""
    new_args = []
    for arg, compile_arg in zip(args, compile_args):
        new_arg = _handle_arg(obj, arg, compile_arg)
        if new_arg is not None:
            new_args.append(new_arg)

    for _, value in kwargs.items():
        new_value = _handle_arg(obj, value, None)
        if new_value is not None:
            new_args.append(new_value)

    return new_args


def _get_args_for_run_predict(obj, args, kwargs, compile_args):
    """Get the actual input args and kwargs for runtime."""
    new_args = []
    for arg, compile_arg in zip(args, compile_args):
        new_arg = _handle_arg_predict(obj, arg, compile_arg)
        if new_arg is not None:
            new_args.append(new_arg)

    for _, value in kwargs.items():
        new_value = _handle_arg_predict(obj, value, None)
        if new_value is not None:
            new_args.append(new_value)

    return new_args

def _is_args_fullmode(args, is_init=True):
    """Check whether the arguments are in full mode (rather than incremental mode).

    Args:
        args (Union[list, tuple, dict, Tensor]): Given arguments.
        is_init (bool): Whether the check happens in the argument initialization phase.

    Raises:
        RuntimeError: If the necessary keys and values for incremental mode are missing.

    Returns:
        bool: Full mode or not.
    """
    if not isinstance(args, dict):
        return True
    if not is_init and (args.get(ARG_SPECIFIED, None) is None or args.get(TOTAL_ARG_LEN, None) is None):
        raise RuntimeError(
            "The incremental inputs should be processed (with \"%s\" and \"%s\"), but got %s." %
            (ARG_SPECIFIED, TOTAL_ARG_LEN, str(args)))
    return False

def _process_dyn_args(fn, dyn_args):
    """Process the dynamic arguments and return the data needed by the later processing.

    Args:
        fn (Function): The root function to compile.
        dyn_args (Union[dict, list, tuple, None]): Given arguments for dynamic compilation.
            None means nothing is dynamic, a list or tuple is a full-mode setting, and a dict is an
            incremental configuration.

    Returns:
        The processed arguments for dynamic compilation (a tuple for full mode or a dict for
        incremental mode). None means nothing is dynamic.
    """
    if dyn_args is None:
        # Nothing should be done for None.
        return dyn_args

    if isinstance(dyn_args, dict) and ARG_SPECIFIED in dyn_args:
        return dyn_args

    args_sig = inspect.signature(fn)
    if _is_args_fullmode(dyn_args):
        if not isinstance(dyn_args, (list, tuple)):
            temp_dyn_args = (dyn_args,)
        else:
            temp_dyn_args = dyn_args

        # If dyn_args is in full mode, it should be applied directly.
        args_sig_parameters = list(args_sig.parameters.values())
        if not args_sig_parameters:
            return ()

        # fn may be a Cell's construct whose first input is 'self'.
        if args_sig_parameters[0].name == "self" and (len(temp_dyn_args) + 1) == len(args_sig_parameters):
            bound_args = args_sig.bind(None, *temp_dyn_args)
            bound_args.apply_defaults()
            return bound_args.args[1:]

        bound_args = args_sig.bind(*temp_dyn_args)
        bound_args.apply_defaults()
        return bound_args.args

    # The dyn_args is not in full mode; the real compilation arguments will be assembled by later processing...
    arg_names = []
    args_sig_parameters = list(args_sig.parameters.values())
    for arg_p in args_sig_parameters:
        if arg_p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
            arg_names.append(arg_p.name)
        else:
            raise TypeError("Dynamic arguments are not accepted for VAR_POSITIONAL or VAR_KEYWORD parameters!")

    offset = -1 if fn.__name__ == 'construct' and args_sig_parameters[0].name == "self" else 0
    meet_index = set()

    def _check_index_valid(index):
        if index >= len(arg_names):
            raise ValueError("For dict mode, valid index is \"0\"-\"%d\", but got %s!" % (len(arg_names) - 1, index))
        if index in meet_index:
            raise ValueError("For dict mode, there is more than one specified key for the real index: %d!" % index)
        meet_index.add(index)

    arg_handler_infos = []
    for k, v in dyn_args.items():
        if not isinstance(k, str):
            raise TypeError("For dict mode, only string keys are accepted, but got %s!" % k)
        if k in arg_names:
            cur_id = arg_names.index(k)
            _check_index_valid(cur_id)
            arg_handler_infos.append([cur_id + offset, v])
        else:
            raise ValueError("For dict mode, the valid keys are %s, but got %s!" % (arg_names, k))
    return {ARG_SPECIFIED: arg_handler_infos, TOTAL_ARG_LEN: len(args_sig_parameters)}

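# Sketch of the two input_signature forms handled above (illustrative shapes only):
# full mode passes dynamic-shape stand-ins for every argument, while incremental (dict)
# mode only overrides the named arguments and is resolved to positional indices here.
#
#     from mindspore import Tensor, jit
#
#     dyn = Tensor(shape=[None, 3], dtype=ms.float32)   # dynamic first dimension
#
#     @jit(input_signature=(dyn, dyn))                  # full mode: both inputs dynamic
#     def add(x, y):
#         return x + y
#
#     @jit(input_signature={"y": dyn})                  # incremental mode: only `y` dynamic
#     def add2(x, y):
#         return x + y
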
def _generate_dyn_compile_args(compile_args, dyn_args):
    """Generate the dynamic compile arguments."""
    if not dyn_args:
        return compile_args
    if _is_args_fullmode(dyn_args, False):
        if not isinstance(dyn_args, (list, tuple)):
            return (dyn_args,)
        return dyn_args
    arg_specified_infos = dyn_args.get(ARG_SPECIFIED, None)
    if arg_specified_infos is None:
        raise RuntimeError("For dict mode, a key with \"%s\" should exist, but got %s!" %
                           (ARG_SPECIFIED, str(dyn_args)))
    new_compile_args = list(compile_args)
    for index, arg in arg_specified_infos:
        new_compile_args[index] = arg
    return tuple(new_compile_args)

class _MindsporeFunctionExecutor:
    """
    Represents a function compiled by the graph compiler.

    _MindsporeFunctionExecutor will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        ms_create_time (TimeStamp): The time when the function was created.
        input_signature (Function): User-defined signature used to verify the inputs.
        obj (Object): If the function is a method, obj is the owner of the function;
             otherwise, obj is None.

    Returns:
        The result of pipeline running in graph mode.
    """

    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
        init_pipeline()
        if not isinstance(fn, (types.FunctionType, types.MethodType)):
            raise RuntimeError('fn {} is not a function or method'.format(fn))

        self.fn = fn
        self.input_signature = input_signature
        self.obj = None
        if obj and hasattr(obj, fn.__name__):
            self.obj = obj
        self.shard_parent_obj = obj
        self.enable_tuple_broaden = False
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = ms_create_time
        self._compile_args = None
        self.jit_config_dict = jit_config.jit_config_dict if jit_config else None

    @_wrap_func
    def __call__(self, *args, **kwargs):
        args_list = args
        if self.obj is not None:
            args_list = args_list[1:]
        phase = ""
        try:
            if context.get_context("mode") == context.PYNATIVE_MODE:
                _pynative_executor.set_jit_compile_status(True, phase)
                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
                _pynative_executor.set_jit_compile_status(False, phase)
            else:
                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
        except Exception as err:
            _pynative_executor.clear_res()
            raise err

        if context.get_context("precompile_only"):
            return None

        new_inputs = self._generate_run_args(args_list, kwargs)
        output = self._graph_executor(tuple(new_inputs), phase)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            output = _pynative_executor.grad_jit(output, *new_inputs)

        return output

    def compile(self, method_name, *args, **kwargs):
        """Return the pipeline phase for the given args."""
        # Check whether a hook function is registered on the Cell object.
        if self.obj and hasattr(self.obj, "_hook_fn_registered"):
            if self.obj._hook_fn_registered():
                logger.warning(f"For 'Cell', hook functions are not supported when using the 'jit' decorator. "
                               f"If you want to use hook functions, please use context.set_context to set "
                               f"pynative mode and remove the 'jit' decorator.")
        # Choose dynamic shape tensors or actual input tensors as compile args.
        compile_args = self._generate_compile_args(args)
        key_id = self._get_key_id()
        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
                                                                              self.input_signature)

        # Restore the mutable attr for every arg.
        compile_args = _restore_mutable_attr(args, compile_args)
        self._compile_args = compile_args
        generate_name, echo_function_name = self._get_generate_name()
        # The full function name.
        full_function_name = generate_name
        create_time = ''

        # Add key with obj.
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.info(
                    f'The module of `self.obj`: `{self.obj.__module__}` is not the same as the module of `self.fn`: '
                    f'`{self.fn.__module__}`')
            self.obj.__parse_method__ = method_name
            if isinstance(self.obj, ms.nn.Cell):
                generate_name = generate_name + '.' + str(self.obj.create_time)
                create_time = str(self.obj.create_time)
            else:
                generate_name = generate_name + '.' + str(self._create_time)
                create_time = str(self._create_time)

            generate_name = generate_name + '.' + str(id(self.obj))
            full_function_name = generate_name
        else:
            # Different instances of the same class may reuse the same memory (i.e. the same obj id) at
            # different times. To avoid unexpected phase matches, add create_time to generate_name.
            generate_name = generate_name + '.' + str(self._create_time)
            create_time = str(self._create_time)

        self.enable_tuple_broaden = False
        if hasattr(self.obj, "enable_tuple_broaden"):
            self.enable_tuple_broaden = self.obj.enable_tuple_broaden

        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
        phase = generate_name + '.' + str(key)

        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)

        if phase in ms_compile_cache:
            # Resources should be released when CompileInner will not be executed, such as cur_convert_input_
            # generated in generate_arguments_key.
            self._graph_executor.clear_compile_arguments_resource()
            return phase

        _check_recompile(self.obj, compile_args, kwargs, full_function_name, create_time, echo_function_name)

        # If compile cache is enabled, get the dependency file list and set it to the graph executor.
        self._set_compile_cache_dep_files()
        if self.jit_config_dict:
            self._graph_executor.set_jit_config(self.jit_config_dict)
        else:
            jit_config_dict = JitConfig().jit_config_dict
            self._graph_executor.set_jit_config(jit_config_dict)

        if self.obj is None:
            # Set an attribute to fn as an identifier.
            if isinstance(self.fn, types.MethodType):
                setattr(self.fn.__func__, "__jit_function__", True)
            else:
                setattr(self.fn, "__jit_function__", True)
            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
            if isinstance(self.fn, types.MethodType):
                delattr(self.fn.__func__, "__jit_function__")
            else:
                delattr(self.fn, "__jit_function__")
        else:
            if isinstance(self.obj, ms.nn.Cell):
                self._graph_executor.set_weights_values(self.obj.parameters_dict())
            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)

        if not is_compile:
            raise RuntimeError("Executor compile failed.")
        ms_compile_cache.add(phase)

        return phase

    @staticmethod
    def _optimizer_state_init(opt_states):
        """Set data for all optimizer states in case it is executed in graph mode."""
        prefix_list = ["moments", "accum", "moment1", "moment2", "lamb_m", "lamb_v", "mean_grad",
                       "mean_square", "prev"]
        for opt_param in opt_states:
            prefix = opt_param.name[:opt_param.name.find(".")]
            if opt_param.has_init and (prefix in prefix_list or opt_param.name == "global_step"):
                opt_param.init_data()

    def _get_key_id(self):
        """Get the key id."""
        if isinstance(self.obj, ms.nn.Cell):
            key_id = str(id(self.obj)) + str(self.obj.create_time)
        else:
            key_id = str(id(self.obj)) + str(self._create_time)

        if _pynative_executor.grad_flag():
            key_id = key_id + ".grad"
        return key_id

    def _get_generate_name(self):
        """Get the generated name."""
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + str(
            self.fn.__code__.co_firstlineno)
        echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
                             + "\", line " + str(self.fn.__code__.co_firstlineno)
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        if _is_pynative_parallel():
            generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
        return generate_name, echo_function_name

    def _set_compile_cache_dep_files(self):
        # If compile cache is enabled, get the dependency file list.
        enable_compile_cache = context.get_context("enable_compile_cache")
        if enable_compile_cache is None:
            enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
        if enable_compile_cache is True or enable_compile_cache == "1":
            self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

    def _generate_compile_args(self, args_list):
        """Choose dynamic shape tensors or actual input tensors as compile args."""
        # Case: If the shape of the input args is dynamic, get dynamic shape tensors from the context and compile.
        compile_args = _pynative_executor.get_dynamic_input(args_list)
        # Case: The `set_inputs()` of the Cell object has been set; use these dynamic shape args as compile args.
        if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
            compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
            if len(compile_args) != len(args_list):
                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
                                 f"dynamic shape tensors: {len(compile_args)}.")
            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
            Validator.check_symbolic_shape(compile_args, args_list)

        # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
        if self.input_signature is not None:
            compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
            dyn_shape = any(is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor))
            Validator.check_symbolic_shape(self.input_signature, args_list)
            if dyn_shape:
                # Check whether `sens` has been added to args_list.
                if len(compile_args) == len(args_list) - 1:
                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
                                   f"of input_signature args '{len(compile_args)}'. The last actual arg may "
                                   f"be 'sens', so it is added to the compile args.")
                    compile_args.append(args_list[-1])
                compile_args = tuple(compile_args)
                self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
                if self.obj is not None:
                    _pynative_executor.set_dynamic_input(self.obj, *compile_args)
                else:
                    _pynative_executor.set_dynamic_input(self.fn, *compile_args)
            else:
                if not verify_inputs_signature(compile_args, args_list):
                    raise ValueError("The input args are incompatible with the args in `input_signature`!")
        return compile_args

    def _generate_run_args(self, args_list, kwargs):
        """
        Generate the input args which are required for running.

        Args:
            args_list (Tuple): Actual input args.
            kwargs (Dict): Actual input kwargs.

        Returns:
            new_inputs, the new input args which are required for running.
        """
        return _get_args_for_run(self, args_list, kwargs, self._compile_args)

# The attributes used to identify a given object.
attr_op = {"__str__": lambda x: x.__str__(),
           "__hash__": lambda x: str(x.__hash__()),
           "__module__": lambda x: x.__module__,
           "__name__": lambda x: x.__name__,
           "__qualname__": lambda x: x.__qualname__,
           "__len__": lambda x: str(x.__len__()),
           "__code__": lambda x: x.__code__.co_filename + str(x.__code__.co_firstlineno)
           }


def _get_obj_id(input_obj):
    """Get the hash id of a single object."""
    obj_id = ".".join(
        (map(lambda x: attr_op.get(x)(input_obj) if hasattr(input_obj, x) and getattr(input_obj, x) else "", attr_op)))
    return obj_id + str(id(input_obj))


def _get_jit_hash(hash_input):
    """Get the hash value of a single object or a list/tuple of objects."""
    if isinstance(hash_input, (list, tuple)):
        return ".".join(map(_get_obj_id, hash_input))
    return _get_obj_id(hash_input)

def _update_graph_executor_config(jit_config):
    """Update GraphExecutor jit_config"""
    if isinstance(jit_config, JitConfig):
        jit_config = jit_config.jit_config_dict
    if not isinstance(jit_config, dict):
        return
    valid_config = dict()
    for k, v in jit_config.items():
        valid_config[str(k)] = str(v)
    GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict)

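# Usage sketch (illustrative option values): jit_config may be given either as a
# mindspore.JitConfig object or as a plain dict; both are normalized to a string-keyed
# dict before being handed to the graph executor.
#
#     from mindspore import JitConfig
#
#     _update_graph_executor_config(JitConfig(jit_level="O1"))
#     _update_graph_executor_config({"jit_level": "O1"})   # equivalent plain-dict form
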
def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
    """
    Create a callable MindSpore graph from a Python function.

    This allows the MindSpore runtime to apply optimizations based on the graph.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: ``None`` .
        mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .

            - `PSJit <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html>`_ :
              Parse the Python AST to build the graph.
            - `PIJit <https://www.mindspore.cn/docs/en/master/design/dynamic_graph_and_static_graph.html>`_ :
              Parse the Python bytecode to build the graph at runtime.

        input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
            shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
            input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep
            the same as `input_signature`. Otherwise, TypeError will be raised. There are two modes for
            `input_signature`:

            - Full mode: The argument is a Tuple, List or Tensor, and it will be used as the complete set of compile
              inputs for graph compiling.
            - Incremental mode: The argument is a Dict, and it only sets some of the graph inputs, which will be
              substituted into the inputs at the corresponding positions for graph compiling.

            Default: ``None`` .

        hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
            like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
            will trigger recompilation. Default: ``None`` .
        jit_config (JitConfig): Jit config for compile. Default: ``None`` .
        compile_once (bool): If ``True``, the function is compiled only once even if it is created many times,
            which may be wrong if the free variables are changed. If ``False``, the function is recompiled when
            it is created again.
            Default: ``False`` .

    Note:
        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
        will not accept `**kwargs`.

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore import jit
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> # create a callable MindSpore graph by calling decorator @jit
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = jit(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @jit
        >>> @jit
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
        >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
        ...                       Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_sig(x, y)
        ...
        >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
        ... def tensor_add_with_sig_1(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out1 = tensor_add_with_sig_1(x, y)
        ...
        ... # Set hash_args as fn, otherwise the cache of the compiled closure_fn will not be reused.
        ... # While fn differs between calls, recompilation will be triggered.
        >>> def func(x):
        ...     return ops.exp(x)
        ...
        >>> def closure_fn(x, fn):
        ...     @jit(hash_args=fn)
        ...     def inner_fn(a):
        ...         return fn(a)
        ...     return inner_fn(x)
        ...
        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
        >>> for i in range(10):
        ...     closure_fn(inputs, func)
        ...
        ... # Set compile_once = True, otherwise train_step will be compiled again.
        >>> def train(x):
        ...     @jit(compile_once=True)
        ...     def train_step(x):
        ...         return ops.exp(x)
        ...     for i in range(10):
        ...         train_step(x)
        ...
        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
        >>> for i in range(10):
        ...     train(inputs)
    """

    def wrap_mindspore(func):
        if not isinstance(compile_once, bool):
            logger.warning(f"The parameter `compile_once` of jit should be a bool, "
                           f"but got {type(compile_once)}.")
        if hash_args:
            hash_obj = _get_jit_hash(hash_args)
        elif compile_once:
            hash_obj = 0
        else:
            hash_obj = int(time.time() * 1e9)

        dyn_args = _process_dyn_args(func, input_signature)

        @wraps(func)
        def staging_specialize(*args, **kwargs):
            if os.getenv("MS_JIT") == '0':
                return func(*args, **kwargs)

            args, kwargs = _handle_func_args(func, *args, **kwargs)

            process_obj = None
            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
                process_obj = args[0]
            # Only the function or cell instance wrapped by shard will fall into this branch.
            if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
                process_obj = hash_args
            out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
            return out

        return staging_specialize

    def pi_wrap_mindspore(decorated):
        func = decorated
        if isinstance(func, ms.nn.Cell):
            func = func.construct
        if isinstance(func, type) and issubclass(func, ms.nn.Cell):
            func = func.construct
        if isinstance(func, types.MethodType):
            func = func.__func__
        if not isinstance(func, types.FunctionType):
            logger.warning("Only functions and mindspore.nn.Cell instances are supported.")
            return decorated

        # Generators, coroutines, awaitables and functions that return them are unsupported.
        UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE |
                                 inspect.CO_ASYNC_GENERATOR | inspect.CO_ITERABLE_COROUTINE)
        if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE:
            return decorated

        _update_graph_executor_config(jit_config)
        config = dict()
        if isinstance(jit_config, JitConfig):
            config.update(jit_config.jit_config_dict)
        elif jit_config is not None:
            config.update(jit_config)
        jit_mode_pi_enable()

        if jit_mode_pi_compile(func, config, input_signature) is False:
            logger.warning('Adding fn {} to compile failed.'.format(func))

        return decorated

    wrap_func = wrap_mindspore
    if mode == "PIJit":
        wrap_func = pi_wrap_mindspore

    if fn is not None:
        return wrap_func(fn)
    return wrap_func

def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
    """
    Create a callable MindSpore graph from a Python function.

    This allows the MindSpore runtime to apply optimizations based on the graph.

    Note:
        - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
          will not accept `**kwargs`.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: ``None`` .
        input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
            will be supplied to this function. The shape and dtype of actual inputs of `fn` should
            keep the same as `input_signature`. Otherwise, TypeError will be raised. Default: ``None`` .
        hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
            like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
            will trigger recompilation. Default: ``None`` .
        jit_config (JitConfig): Jit config for compile. Default: ``None`` .

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore import ms_function
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> # create a callable MindSpore graph by calling ms_function
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = ms_function(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function
        >>> @ms_function
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_sig(x, y)
        ...
        ... # Set hash_args as fn, otherwise the cache of the compiled `closure_fn` will not be reused.
        ... # While fn differs between calls, recompilation will be triggered.
        >>> def func(x):
        ...     return ops.exp(x)
        ...
        >>> def closure_fn(x, fn):
        ...     @ms_function(hash_args=fn)
        ...     def inner_fn(a):
        ...         return fn(a)
        ...     return inner_fn(x)
        ...
        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
        >>> for i in range(10):
        ...     closure_fn(inputs, func)
    """

    logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
                   "Please use 'mindspore.jit' instead.")
    return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)

def _core(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    By default, the 'core' flag is set to True, so this decorator can be used to
    set flags on a graph.

    Args:
        fn (Function): Function to add the flag to. Default: ``None``.
        flags (dict): Flags to set, such as `core`, which indicates that this is a core function,
                      or other flags. Default: ``None``.

    Returns:
        Function, the function with the core flag.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    # Need to set the attr and access it on the C++ side.
    def deco(fn):
        fn._func_graph_flags = {
            'core': True,
            **flags,
        }
        return fn

    if fn is not None:
        ret = deco(fn)
    else:
        ret = deco
    return ret

def _add_flags(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    Note:
        Only supports bool values.

    Args:
        fn (Function): Function or cell to add the flag to. Default: ``None``.
        flags (dict): Flags passed as kwargs. Default: ``None``.

    Returns:
        Function, the function with added flags.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def deco(fn):
        # Need to set the attr and access it on the C++ side.
        if not hasattr(fn, "_func_graph_flags"):
            fn._func_graph_flags = {}

        fn._func_graph_flags.update({**flags})
        return fn

    ret = deco
    if fn is not None:
        ret = deco(fn)
    return ret

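# A minimal usage sketch: the decorator records its kwargs on the function object in
# `_func_graph_flags`, which the graph compiler reads later. The `no_recursive` flag used
# here is the same one set by the `_no_recursive` helper below.
#
#     @_add_flags(no_recursive=True)
#     def my_func(x):
#         return x
#
#     # my_func._func_graph_flags == {'no_recursive': True}
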
def _no_recursive(callable_obj):
    """
    Method or function decorator for ignoring the recursive check.

    This allows MindSpore to skip the recursive-call check for the function or method.

    Args:
        callable_obj (Union(method, function)): The function or method to call.

    Returns:
        Function or method with the no_recursive flag.

    Raises:
        TypeError: If `_no_recursive` is used on an object that is not a Cell subclass, method or function.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    is_cell_subclass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell)
    if not is_cell_subclass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj):
        raise TypeError(f"Decorator no_recursive is used for callable object, but got {callable_obj}.")
    _add_flags(callable_obj, no_recursive=True)
    return callable_obj


def ms_class(cls):
    """
    Class decorator for user-defined classes.

    This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.

    Note:
        `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.

    Args:
        cls (Class): User-defined class.

    Returns:
        Class.

    Raises:
        TypeError: If `ms_class` is used for non-class types or nn.Cell.
        AttributeError: If the private attributes or magic methods of the class decorated with `ms_class` are called.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> from mindspore import ms_class
        ...
        >>> @ms_class
        ... class UserDefinedNet:
        ...     def __init__(self):
        ...         self.value = 10
        ...
        ...     def func(self, x):
        ...         return 2 * x
        ...
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.net = UserDefinedNet()
        ...
        ...     def construct(self, x):
        ...         out = self.net.value + self.net.func(x)
        ...         return out
        ...
        >>> net = Net()
        >>> out = net(5)
        >>> print(out)
        20
    """

    logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
                   "Please use 'mindspore.jit_class' instead.")

    # Check if cls is of type class.
    if not inspect.isclass(cls):
        raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
    # Check if cls is nn.Cell.
    if issubclass(cls, ms.nn.Cell):
        raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
    logger.info(f'Found ms_class: {cls}.')
    setattr(cls, '__ms_class__', True)
    return cls

def jit_class(cls):
    """
    Class decorator for user-defined classes.

    This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.

    Args:
        cls (Class): User-defined class.

    Returns:
        Class.

    Raises:
        TypeError: If `jit_class` is used for non-class types or nn.Cell.
        AttributeError: If the private attributes or magic methods of the class decorated with `jit_class` are called.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> from mindspore import jit_class
        ...
        >>> @jit_class
        ... class UserDefinedNet:
        ...     def __init__(self):
        ...         self.value = 10
        ...
        ...     def func(self, x):
        ...         return 2 * x
        ...
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.net = UserDefinedNet()
        ...
        ...     def construct(self, x):
        ...         out = self.net.value + self.net.func(x)
        ...         return out
        ...
        >>> net = Net()
        >>> out = net(5)
        >>> print(out)
        20
    """
    from mindspore import nn
    # Check if cls is of type class.
    if not inspect.isclass(cls):
        raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
    # Check if cls is nn.Cell.
    if issubclass(cls, nn.Cell):
        raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
    setattr(cls, '__ms_class__', True)
    return cls

1281def set_adapter_config(config):
1282    """
1283    Register configuration information for MSAdapter.
1284
1285    Args:
1286        config (dict): Configuration information.
1287    """
1288    if not isinstance(config, dict):
1289        raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
1290    for key, value in config.items():
1291        if key == "Tensor":
1292            ms_adapter_registry.register_tensor(value)
1293        elif key == "Parameter":
1294            ms_adapter_registry.register_parameter(value)
1295        elif key == "convert_object_map":
1296            ms_adapter_registry.register_convert_map(value)
1297        elif key == "convert_adapter_tensor_map":
1298            ms_adapter_registry.register_convert_adapter_tensor_map(value)
1299        else:
1300            raise ValueError(f"Unsupported key in adapter config: {key}")
1301
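# The sketch below is illustrative only: it shows how an MSAdapter-style package might call
# set_adapter_config. The AdapterTensor/AdapterParameter names are hypothetical and not part of
# MindSpore; the dict keys, however, are the ones handled above.
#
#     from mindspore.common.api import set_adapter_config
#     set_adapter_config({
#         "Tensor": AdapterTensor,        # custom Tensor subclass provided by the adapter
#         "Parameter": AdapterParameter,  # custom Parameter subclass provided by the adapter
#     })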
1302
1303def _function_forbid_reuse(func):
1304    if not inspect.isfunction(func):
1305        raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
1306    setattr(func, '__function_forbid_reuse__', True)
1307    return func
1308
1309
1310def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
1311    """Build broadcast graph."""
1312    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
1313    if not broadcast_params_dict:
1314        broadcast_params_dict = {}
1315    broadcast_params = []
1316    for param in broadcast_params_dict.values():
1317        broadcast_params.append(Tensor(param.asnumpy()))
1318    _broadcast_net = _BroadCastCell(broadcast_params)
1319    _broadcast_net.phase = broadcast_phase
1320    broadcasted_params = _broadcast_net()
1321    for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
1322        broadcast_params_dict.get(param_name).set_data(param)
1323
1324
1325def _get_auto_split_param_names(parameter_layout_dict):
1326    auto_split_param_names = []
1327    for key, value in parameter_layout_dict.items():
1328        for dim in value[1]:
1329            if dim != -1:
1330                auto_split_param_names.append(key)
1331                break
1332    return auto_split_param_names
1333
1334
1335def _parameter_broadcast(obj):
1336    """
1337    Parameter broadcast.
    When the parallel mode is 'semi_auto_parallel' or 'auto_parallel', it only broadcasts the parameters that have
    not been split.
1340    """
1341    auto_split_param_names = []
1342    if _is_in_auto_parallel_mode():
1343        auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict)
1344
1345    broadcast_params_dict = obj.parameters_broadcast_dict()
1346    if auto_split_param_names and broadcast_params_dict:
1347        broadcast_params_dict = OrderedDict()
1348        for param_name, param in obj.parameters_broadcast_dict().items():
1349            if param_name not in auto_split_param_names:
1350                broadcast_params_dict[param_name] = param
1351    broadcast_phase = "_broadcast_subgraph"
1352    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
1353
1354
1355class _no_grad(contextlib.ContextDecorator):
1356    """
    Context manager that disables gradient calculation. On entering the context, gradient calculation is
    disabled; on exiting, the previous state is restored.
    Currently, it can only be used in PyNative mode. It can also be used as a decorator.
1360    """
1361
1362    def __init__(self):
1363        self.prev_state = False
1364
1365    def __enter__(self):
1366        if context.get_context("mode") == context.GRAPH_MODE:
            raise RuntimeError("The no_grad feature currently only supports PyNative mode, but got Graph mode.")
1368        self.prev_state = _pynative_executor.enable_grad()
1369        _pynative_executor.set_enable_grad(False)
1370
1371    def __exit__(self, exc_type, exc_val, exc_tb):
1372        _pynative_executor.set_enable_grad(self.prev_state)
1373        return False
1374
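# A minimal usage sketch for _no_grad (illustrative only; `net` and `x` are placeholders). It works
# only in PyNative mode and can be used either as a context manager or as a decorator:
#
#     with _no_grad():
#         out = net(x)           # gradients are not recorded inside this block
#
#     @_no_grad()
#     def evaluate(x):
#         return net(x)          # gradients are not recorded inside this function either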
1375
1376class _PyNativeExecutor:
1377    """
1378    A pynative executor used to compile/manage/run single op.
1379
1380    The main functions include: construct op graph, compile op graph, auto grad and run op graph.
1381
1382    Args:
1383        obj (Object): The python network that will be run in pynative mode.
1384        args (Tuple(Tensor...)): The inputs of network in tuple form.
1385
1386    Returns:
1387        gradients (Tuple(Tensor...)): The gradients of network parameters and inputs.
1388
1389    Supported Platforms:
1390        ``Ascend`` ``GPU`` ``CPU``
1391    """
1392
1393    def __init__(self):
1394        self._executor = PyNativeExecutor_.get_instance()
1395        self._executor.set_py_exe_path(sys.executable)
1396        self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
1397
1398    @staticmethod
1399    def parameter_broadcast(obj, phase):
1400        """
1401        Run broadcast for parameter.
1402
1403        Args:
1404            obj (Cell): The cell instance.
1405            phase (str): The phase of cell instance.
1406
1407        Return:
1408            None.
1409        """
1410        if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1411            _parameter_broadcast(obj)
1412
1413    def real_run_op(self, *args):
1414        """
1415        Run single op.
1416
1417        Args:
1418            args (tuple): Op prim and input arguments.
1419
1420        Return:
1421            Tensor, result of run op.
1422        """
1423        return self._executor.real_run_op(*args)
1424
1425    def run_op_async(self, *args):
1426        """
1427        Run single op async.
1428
1429        Args:
1430            args (tuple): Op prim and input arguments.
1431
1432        Return:
1433            StubNode, result of run op.
1434        """
1435        return self._executor.run_op_async(*args)
1436
1437    def new_graph(self, obj, *args, **kwargs):
1438        """
1439        Initialize resources for building forward and backward graph.
1440
1441        Args:
1442            obj (Function/Cell): The function or cell instance.
1443            args (tuple): Function or cell input arguments.
1444            kwargs (dict): keyword arguments.
1445
1446        Return:
1447            None.
1448        """
1449        self._executor.new_graph(obj, *args, *(kwargs.values()))
1450
1451    def end_graph(self, obj, output, *args, **kwargs):
1452        """
1453        Clean resources after building forward and backward graph.
1454
1455        Args:
1456            obj (Function/Cell): The function or cell instance.
1457            output (Tensor/tuple/list): Function or cell output object.
1458            args (tuple): Function or cell input arguments.
1459            kwargs (dict): keyword arguments.
1460
1461        Return:
1462            None.
1463        """
1464        self._executor.end_graph(obj, output, *args, *(kwargs.values()))
1465
1466    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
1467        """
        Check whether the forward graph needs to be constructed.
1469
1470        Args:
            grad (GradOperation): The GradOperation object.
            obj (Function/Cell): The function or cell instance.
            weights (ParameterTuple): The weights of the cell instance.
            grad_hash_id (tuple): The id of objects which contribute to cache of compiled graph in pynative mode.
1474            args (tuple): Function or cell input arguments.
1475            kwargs (dict): keyword arguments.
1476
1477        Return:
            bool, specifies whether the forward graph needs to be constructed.
1479        """
1480        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
1481
1482    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
1483        """
1484        Get grad graph.
1485
1486        Args:
1487            obj (Function/Cell): The function or cell instance.
            grad (GradOperation): The GradOperation object.
1489            weights (ParameterTuple): The weights of cell instance.
1490            grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
1491              If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1492            args (tuple): Function or cell input arguments.
1493            kwargs (dict): keyword arguments.
1494
1495        Return:
1496            None.
1497        """
1498        return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
1499
1500    def clear_res(self):
1501        """
1502        Clean resource for _PyNativeExecutor.
1503
1504        Return:
1505            None.
1506        """
1507        return self._executor.clear_res()
1508
1509    def sync(self):
1510        """
        Synchronize the device stream.
1512
1513        Return:
1514            None.
1515        """
1516        self._executor.sync()
1517
1518    def grad_jit(self, output, *args):
1519        """
        Build the grad graph for the function or cell decorated by jit.
1521
1522        Args:
            output (tuple): The output object of the function or cell decorated by jit.
            args (tuple): Input arguments of the function or cell decorated by jit.
1525
1526        Return:
1527            None.
1528        """
1529        return self._executor.grad_jit(output, *args)
1530
1531    def grad_flag(self):
1532        """
1533        The flag of building grad graph.
1534
1535        Return:
1536            bool, whether building grad graph.
1537        """
1538        return self._executor.grad_flag()
1539
1540    def set_grad_flag(self, flag):
1541        """
1542        Set the flag of building grad graph.
1543
1544        Args:
            flag (bool): Specifies whether to build the grad graph.
1546
1547        Return:
1548            None.
1549        """
1550        self._executor.set_grad_flag(flag)
1551
1552    def set_async_for_graph(self, flag):
1553        """
1554        Set the flag for graph async run.
1555
1556        Args:
            flag (bool): Specifies whether to enable async graph run.
1558
1559        Return:
1560            None.
1561        """
1562        self._executor.set_async_for_graph(flag)
1563
1564    def enable_grad(self):
1565        """
        Get the global flag of whether gradient calculation is needed.

        Return:
            bool, whether gradient calculation is needed.
1570        """
1571        return self._executor.enable_grad()
1572
1573    def set_enable_grad(self, flag):
1574        """
1575        Set the flag of calculating gradient.
1576
1577        Args:
            flag (bool): Specifies whether to calculate gradient.
1579
1580        Return:
1581            None.
1582        """
1583        self._executor.set_enable_grad(flag)
1584
1585    def set_jit_compile_status(self, status, phase):
1586        """
        Set whether jit is compiling.
1588
1589        Args:
            status (bool): The jit compile status.
            phase (str): The phase of the cell/function instance.

        Return:
1593            None.
1594        """
1595        self._executor.set_jit_compile_status(status, phase)
1596
1597    def set_is_run_recompute(self, status):
1598        """
        Set whether grad is running in recompute status.
1600
1601        Args:
            status (bool): Whether grad is in recompute status.

        Return:
1604            None.
1605        """
1606        self._executor.set_is_run_recompute(status)
1607
1608    def set_dynamic_input(self, obj, *args):
1609        """
1610        Set dynamic shape tensor of input arguments.
1611
1612        Args:
1613            obj (Function/Cell): The function or cell instance.
1614            args (tuple): Function or cell dynamic input arguments.
1615
1616        Return:
1617            None.
1618        """
1619        self._executor.set_dynamic_input(obj, *args)
1620
1621    def get_dynamic_input(self, *actual_args):
1622        """
1623        Get dynamic shape arguments according to actual input arguments.
1624
1625        Args:
1626            actual_args(tuple): Actual input arguments of Function or Cell.
1627
1628        Return:
1629            dynamic_shape_args(tuple): Dynamic shape arguments of Function or Cell.
1630        """
1631        return self._executor.get_dynamic_input(*actual_args)
1632
1633    def is_first_cell(self):
1634        """
1635        The flag of first cell instance.
1636
1637        Return:
            bool, specifies whether it is the first cell.
1639        """
1640
1641        return self._executor.is_first_cell()
1642
1643    def set_hook_changed(self, cell):
1644        """
1645        The flag of registering or removing a hook function on Cell instance.
1646
1647        Args:
1648            cell (Cell): The cell instance.
1649
1650        Return:
1651            None.
1652        """
1653        self._executor.set_hook_changed(cell)
1654
1655    def constant_folding(self, *args):
1656        """
        Get the value by value inference (constant folding).
1658
1659        Args:
1660            args (tuple): Op prim and input arguments.
1661
1662        Return:
            Tensor, the value obtained by op value inference.
1664        """
1665        return self._executor.constant_folding(*args)
1666
1667
1668class _CellGraphExecutor:
1669    """
1670    An executor used to compile/manage/run graph for a Cell.
1671
1672    Including data_graph, train_graph, eval_graph and predict graph.
1673
1674    Args:
        obj (Function/Cell): The function or cell instance to be compiled.
1676        args (tuple): Function or cell input arguments.
1677
1678    Returns:
1679        Graph, return the result of pipeline running.
1680    """
1681
1682    def __init__(self):
1683        # create needed graph by lazy mode
1684        self.is_init = False
1685        self.enable_tuple_broaden = False
1686        self.obfuscate_config = None  # used for model's dynamic obfuscation
1687        self._graph_executor = GraphExecutor_.get_instance()
1688        self._graph_executor.set_py_exe_path(sys.executable)
1689        self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
1690
1691    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
1692                     input_indexs, phase='dataset', need_run=True):
1693        """
1694        Initialization interface for calling data subgraph.
1695
1696        Args:
1697            queue_name (str): The name of tdt queue on the device.
1698            dataset_size (int): The size of dataset.
1699            batch_size (int): The size of batch.
1700            dataset_types (list): The output types of element in dataset.
1701            dataset_shapes (list): The output shapes of element in dataset.
1702            input_indexs (list): The index of data with net.
1703            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.
1704
1705        Returns:
1706            bool, specifies whether the data subgraph was initialized successfully.
1707        """
1708        if not init_exec_dataset(queue_name=queue_name,
1709                                 size=dataset_size,
1710                                 batch_size=batch_size,
1711                                 types=dataset_types,
1712                                 shapes=dataset_shapes,
1713                                 input_indexs=input_indexs,
1714                                 phase=phase,
1715                                 need_run=need_run):
            raise RuntimeError("Failed to initialize the dataset subgraph!")
1717        self._graph_executor.set_queue_name(queue_name)
1718        return True
1719
1720    def set_queue_name(self, queue_name):
1721        """
        When a model shares a dataset with others, the queue_name saved in the dataset needs to be set.

        Args:
            queue_name (str): The name of the dataset queue.

        Return:
            None.
1725        """
1726        self._graph_executor.set_queue_name(queue_name)
1727
1728    def get_queue_name(self, dataset_phase):
1729        """
        Get the cached queue name for the graph loaded from the compile cache.

        Args:
            dataset_phase (str): The phase of the dataset graph.

        Return:
            str, the cached queue name.
1732        """
1733        return self._graph_executor.get_queue_name(dataset_phase)
1734
1735    @staticmethod
1736    def _set_dataset_mode(obj):
1737        """set dataset mode."""
1738        # decide whether to sink based on the sink_mode flag which is set in connect_network_with_dataset
1739        if 'sink_mode' in obj.get_flags().keys() and obj.get_flags()['sink_mode'] is True:
1740            _set_dataset_mode_config('sink')
1741        else:
1742            _set_dataset_mode_config('normal')
1743
1744    @staticmethod
1745    def _use_vm_mode():
1746        enable_ge = context.get_context("enable_ge")
1747        enable_debug_runtime = context.get_context("enable_debug_runtime")
1748        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
1749        return not enable_ge or (enable_debug_runtime and exe_mode)
1750
1751    def _build_data_graph(self, obj, phase):
1752        self._graph_executor.build_data_graph(obj.parameters_dict(), phase)
1753
1754    def _set_compile_cache_dep_files(self, phase):
1755        # If enable compile cache, get the dependency files list
1756        enable_compile_cache = context.get_context("enable_compile_cache")
1757        if enable_compile_cache is None:
1758            enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
1759        if enable_compile_cache is True or enable_compile_cache == "1":
1760            self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
1761
1762    def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
1763        """
1764        Compiles graph.
1765
1766        Args:
            obj (Function/Cell): The function or cell instance to be compiled.
1768            phase (str): The name of compile phase. Default: 'predict'.
1769            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
1770            jit_config_dict (dict): Jit config for compile. Default: ``None``.
1771            args (tuple): Args of the Cell object.
1772            kwargs (dict): Kwargs of the Cell object.
1773
1774        Return:
1775            Str, the full phase of the cell.
1776            Bool, if the graph has been compiled before, return False, else return True.
1777        """
1778        obj.__parse_method__ = 'construct'
1779        if not hasattr(obj, obj.__parse_method__):
1780            raise AttributeError(
                'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
1782        key_id = str(id(obj)) + str(obj.create_time)
1783        args = get_auto_dynamic_shape_args(args, key_id)
1784
1785        self.enable_tuple_broaden = False
1786        if hasattr(obj, "enable_tuple_broaden"):
1787            self.enable_tuple_broaden = obj.enable_tuple_broaden
1788        logger.debug(f"Convert the network: {do_convert}.")
1789        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
1790        key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
1791        obj.arguments_key = str(key)
1792        raw_phase = phase
1793        phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1794        obj.phase_cache[raw_phase] = phase
1795        update_auto_dynamic_shape_phase(args, key_id, phase)
1796        obj.current_phase = phase
1797        if phase in obj.compile_cache and self.has_compiled(phase):
1798            logger.debug("%r graph has existed.", phase)
            # Resources should be released when CompileInner won't be executed, such as cur_convert_input_
            # generated in generate_arguments_key.
1801            self._graph_executor.clear_compile_arguments_resource()
1802            return phase, False
1803
1804        full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
1805        echo_function_name = obj.__class__.__name__
1806        _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
1807
1808        obj.check_names()
1809        _check_full_batch()
1810        self._set_dataset_mode(obj)
1811        self._set_compile_cache_dep_files(phase)
1812
1813        self._graph_executor.set_weights_values(obj.parameters_dict())
1814        if jit_config_dict:
1815            self._graph_executor.set_jit_config(jit_config_dict)
1816        else:
1817            jit_config_dict = JitConfig().jit_config_dict
1818            self._graph_executor.set_jit_config(jit_config_dict)
1819        result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
1820        obj.compile_cache.add(phase)
1821        if not result:
1822            raise RuntimeError("Executor compile failed.")
1823        graph = self._graph_executor.get_func_graph(phase)
1824
1825        if graph is None:
1826            raise RuntimeError("Compile graph failed for phase {}.".format(phase))
1827
1828        auto_parallel_mode = _is_in_auto_parallel_mode()
1829        if not auto_parallel_mode:
1830            replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
1831            self._update_param_node_default_input(phase, replace)
1832        elif 'skip_auto_parallel_compile' not in obj.get_flags().keys():
1833            obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
1834            obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
1835        if "export.air" in phase:
1836            self._build_data_graph(obj, phase)
1837        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
1838            _parameter_broadcast(obj)
1839        return phase, True
1840
1841    def _update_param_node_default_input(self, phase, replace):
1842        new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
1843        return self._graph_executor.updata_param_node_default_input(phase, new_param)
1844
1845    def _get_shard_strategy(self, obj):
1846        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1847        return self._graph_executor.get_strategy(real_phase)
1848
1849    def _get_num_parallel_ops(self, obj):
1850        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1851        return self._graph_executor.get_num_parallel_ops(real_phase)
1852
1853    def _get_allreduce_fusion(self, obj):
1854        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1855        return self._graph_executor.get_allreduce_fusion(real_phase)
1856
1857    def __call__(self, obj, *args, phase='predict'):
1858        if context.get_context("precompile_only") or _is_role_sched():
1859            return None
1860        return self.run(obj, *args, phase=phase)
1861
1862    def has_compiled(self, phase='predict'):
1863        """
        Check whether the graph of the specific phase has been compiled.
1865
1866        Args:
1867            phase (str): The phase name. Default: 'predict'.
1868
1869        Returns:
1870            bool, specifies whether the specific graph has been compiled.
1871        """
1872        return self._graph_executor.has_compiled(phase)
1873
1874    def flops_collection(self, phase='train'):
1875        """
        Collect the flops information of the compiled graph for the specific phase.

        Args:
            phase (str): The phase name. Default: 'train'.

        Returns:
            The flops collection result of the specific graph.
1883        """
1884        return self._graph_executor.flops_collection(phase)
1885
1886    @_wrap_func
1887    def _exec_pip(self, obj, *args, phase=''):
1888        """Execute the generated pipeline."""
1889        fn = obj.construct
1890        obj.__parse_method__ = fn.__name__
1891        return self._graph_executor(args, phase)
1892
1893    def run(self, obj, *args, phase='predict'):
1894        """
1895        Run the specific graph.
1896
1897        Args:
1898            obj (Cell): The cell object.
1899            args (tuple): Args of the Cell object.
1900            phase (str): The phase name. Default: 'predict'.
1901
1902        Returns:
1903            Tensor/Tuple, return execute result.
1904        """
1905        if phase == 'save':
1906            exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1907            return self._graph_executor((), exe_phase)
1908
1909        phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1910        if self.has_compiled(phase_real):
1911            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph does not exist.'.format(phase_real))
1913
1914    def del_net_res(self, obj, net_id):
1915        """Clear the memory resource of a network."""
1916        self._graph_executor.del_net_res(obj, net_id)
1917
1918    def inc_graph_cell_count(self):
1919        """Increase the count of GraphCell instance."""
1920        self._graph_executor.inc_graph_cell_count()
1921
1922    def dec_graph_cell_count(self):
1923        """Decrease the count of GraphCell instance."""
1924        self._graph_executor.dec_graph_cell_count()
1925
1926    def _get_branch_control_input(self):
1927        if ('obf_ratio' not in self.obfuscate_config.keys()) or (
1928                'obf_random_seed' not in self.obfuscate_config.keys()):
1929            raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
1930        obf_random_seed = self.obfuscate_config.get('obf_random_seed')
1931        if obf_random_seed == 0:
1932            branch_control_input = 0
1933        else:
1934            branch_control_input = _generate_branch_control_input(obf_random_seed)
1935        return branch_control_input
1936
1937    def _get_func_graph(self, obj, exec_id, use_prefix=False):
1938        """Get func graph from pipeline."""
1939        if use_prefix:
1940            exec_id = exec_id + '.' + obj.arguments_key
1941        if self._graph_executor.has_compiled(exec_id) is False:
1942            return None
1943        if self.obfuscate_config is not None:
            raise ValueError('Getting the func graph with obfuscate_config is currently not supported.')
1945        return self._graph_executor.get_func_graph(exec_id)
1946
1947    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
1948        """Get graph proto from pipeline."""
1949        if use_prefix:
1950            exec_id = exec_id + '.' + obj.arguments_key
1951        if self._graph_executor.has_compiled(exec_id) is False:
1952            return None
1953        if self.obfuscate_config is not None:
1954            branch_control_input = self._get_branch_control_input()
1955            return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
1956                                                                       self.obfuscate_config['obf_ratio'],
1957                                                                       branch_control_input)
1958        return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
1959
1960    def get_optimize_graph_proto(self, obj):
1961        """Return optimize graph binary proto."""
1962        exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
1963        if self._graph_executor.has_compiled(exec_id) is False:
1964            return None
1965        graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
1966        if isinstance(graph_proto, str) and graph_proto == "":
            logger.warning("Cannot get the optimize graph proto. Trying to find the function graph instead.")
1968            graph_proto = obj.get_func_graph_proto()
1969        return graph_proto
1970
1971    def export(self, file_name, graph_id, enc_key=None, encrypt_func=None):
1972        """
1973        Export graph.
1974
1975        Args:
            file_name (str): File name of the model to export.
            graph_id (str): Id of the graph to be exported.
            enc_key: Key used to encrypt the exported model. Default: None.
            encrypt_func: Customized encryption function. Default: None.
1978        """
1979        self._graph_executor.export_graph(file_name, graph_id, encrypt_func, enc_key)
1980
1981
1982def ms_memory_recycle():
1983    """
1984    Recycle memory used by MindSpore.
    When training multiple neural network models in one process, the memory used by MindSpore can be very large
    because MindSpore caches runtime memory for every model.
    To recycle this cached memory, users can call this function after training one model.
1988
1989    Examples:
1990        >>> import mindspore as ms
1991        >>> ms.ms_memory_recycle()
1992    """
1993    if ms_compile_cache:
1994        _cell_graph_executor.del_net_res(None, ms_compile_cache)
1995        ms_compile_cache.clear()
1996    for cell_cache in cells_compile_cache.values():
1997        if cell_cache:
1998            _cell_graph_executor.del_net_res(None, cell_cache)
1999            cell_cache.clear()
2000    _ms_memory_recycle()
2001
2002
2003def _generate_branch_control_input(obf_random_seed):
2004    """Generate append network input for dynamic obfuscation in random seed mode."""
2005    seed_max = 2 ** 32 - 1
2006    int_max = 2 ** 31 - 1
2007    np.random.seed(obf_random_seed % seed_max)
2008    # generate a string as hash function inputs
2009    word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
2010    repo_len = len(word_repo)
2011    sha_string = ''
2012    string_len = 1024 * 1024
2013    for _ in range(string_len):
2014        rand_index = np.random.randint(0, repo_len)
2015        sha_string += word_repo[rand_index]
2016    # get hash result
2017    sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest()  # len is 64
2018    branch_control_input = 1
2019    hex_base = 16
2020    for item in sha_result:
2021        if int(item, hex_base) > 0:
2022            branch_control_input *= int(item, hex_base)
2023    branch_control_input %= int_max
2024    return branch_control_input
2025
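# Illustrative note for _generate_branch_control_input: the seed below is arbitrary. The mapping is
# deterministic (numpy's legacy RNG is re-seeded from obf_random_seed), and the result always falls
# in [0, 2 ** 31 - 1):
#
#     branch_input = _generate_branch_control_input(obf_random_seed=20240101)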
2026
2027def _bind_device_context():
2028    """Bind device context to current thread"""
2029    _bind_device_ctx()
2030
2031
2032def flops_collection(phase='train'):
2033    """
    Collect the flops information of the compiled graph for the specific phase.

    Args:
        phase (str): The phase name of the compiled graph. Default: 'train'.

    Returns:
        The flops collection result of the compiled graph for the specific phase.
2042    """
2043    return _cell_graph_executor.flops_collection(phase)
2044
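# A minimal usage sketch for flops_collection (illustrative only; it assumes the graph of the given
# phase has already been compiled and that this module is importable as mindspore.common.api):
#
#     from mindspore.common.api import flops_collection
#     flops_info = flops_collection(phase='train')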
2045
2046_cell_graph_executor = _CellGraphExecutor()
2047_pynative_executor = _PyNativeExecutor()
2048
2049__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
2050