• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
17"""Providing interface methods."""
18import types
19import sys
20import os
21import time
22from collections import OrderedDict
23from functools import wraps
24
25from mindspore import context
26from mindspore import log as logger
27from mindspore._extends.remote import kernel_build_server
28from .tensor import Tensor as MsTensor
29from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, PynativeExecutor_
30from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
31from ..parallel._ps_context import _is_role_pserver
32from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
33    _get_parameter_broadcast, _get_pipeline_stages
34
# Process-wide record of ms_function pipelines that have already been compiled, keyed by phase string.
ms_compile_cache = dict()

# Marker substring: phases containing it are skipped when deciding whether to broadcast parameters.
BROADCAST_PHASE = "_broadcast_"
39
40
def _convert_function_arguments(fn, *args):
    """
    Map the positional arguments of `fn` to generated names.

    Each positional value is stored under a synthetic key ``arg0``, ``arg1``, ... in call order; the real
    parameter names of `fn` are not inspected.

    Args:
        fn (Function): The function or method whose call is being processed.
        args (tuple): The positional arguments passed to `fn`.

    Returns:
        tuple, ``(converted, arguments_dict, parse_method)``:
        ``converted`` is True only when `fn` is a plain function or method;
        ``arguments_dict`` is an OrderedDict of generated name -> value (empty when not converted);
        ``parse_method`` is ``fn.__name__`` (None when not converted).
    """
    if not isinstance(fn, (types.FunctionType, types.MethodType)):
        logger.warning("Find error: fn isn't function or method")
        return False, OrderedDict(), None

    arguments_dict = OrderedDict((f'arg{index}', value) for index, value in enumerate(args))
    logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
    return True, arguments_dict, fn.__name__
63
64
def _wrap_func(fn):
    """
    Wrapper function, convert return data to tensor or tuple of tensor.

    Any returned ``Tensor`` (from ``_c_expression``) that is not already a Python-level ``MsTensor`` is wrapped
    in ``MsTensor``; tuples and lists are converted element-wise (recursively) and rebuilt as tuple/list, and
    every other value is returned unchanged.

    Args:
        fn (Function): The function need be wrapped.

    Returns:
        Function, a new function with return suitable format data.
    """

    # Built once per decorated function rather than on every call of the wrapper.
    def _convert_data(data):
        if isinstance(data, Tensor) and not isinstance(data, MsTensor):
            return MsTensor(data)
        if isinstance(data, tuple):
            return tuple(_convert_data(x) for x in data)
        if isinstance(data, list):
            return [_convert_data(x) for x in data]
        return data

    @wraps(fn)
    def wrapper(*arg, **kwargs):
        return _convert_data(fn(*arg, **kwargs))

    return wrapper
92
93
def _exec_init_graph(obj, init_phase):
    """
    Run the parameter initializer graph for every parameter of `obj` that is not yet initialized.

    Each such parameter is marked initialized (``is_init`` and ``data.init_flag`` set to True) before the
    graph executor's ``run_init_graph`` is called once with all of them; nothing runs when none are pending.

    Args:
        obj: Object exposing ``parameters_dict()`` (name -> parameter).
        init_phase (str): Phase name passed to ``run_init_graph``.
    """
    executor = GraphExecutor_.get_instance()
    pending = OrderedDict()
    for param_name, parameter in obj.parameters_dict().items():
        if parameter.is_init:
            continue
        pending[param_name] = parameter
        parameter.is_init = True
        parameter.data.init_flag = True

    if pending:
        executor.run_init_graph(pending, init_phase)
106
107
class _MindsporeFunctionExecutor:
    """
    Represents a function compiled by graph compiler.

    _MindsporeFunctionExecutor will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        input_signature (tuple): User-defined signature (a sequence of MetaTensor) used to verify the inputs.
            Default: None.
        obj (Object): If function is a method, obj is the owner of function,
             else, obj is none.
        create_time (int): Timestamp (nanoseconds) used in the compile phase of a plain function (when `obj`
            is None). Callers that build a new executor for every invocation of the same function must pass a
            stable value so the compiled graph can be found in ``ms_compile_cache``. When None, the current
            time is used. Default: None.

    Calling the instance returns the result of pipeline running in graph mode.
    """

    def __init__(self, fn, input_signature=None, obj=None, create_time=None):
        self.fn = fn
        self.input_signature = input_signature
        self.obj = None
        if hasattr(obj, fn.__name__):
            self.obj = obj
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = int(time.time() * 1e9) if create_time is None else create_time

    def build_data_init_graph(self, graph_name):
        """Build GE data graph and init graph for the given graph name."""
        if self.obj is None:
            logger.warning("Make sure parameter should not be used in function")
            para_dict = OrderedDict()
            self._graph_executor.build_data_graph(para_dict, graph_name)
            return
        self._graph_executor.build_data_graph(self.obj.parameters_dict(), graph_name,
                                              self.obj.parameters_broadcast_dict())
        init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
        _exec_init_graph(self.obj, init_phase)

    def compile(self, args_list, arg_names, method_name):
        """
        Compile the function for the given arguments unless a graph for them is already cached.

        Args:
            args_list (tuple): Argument values (without the owning object when `fn` is a method).
            arg_names (tuple): Generated names matching `args_list`.
            method_name (str): Stored as ``__parse_method__`` on `fn` (and on the owner object, if any).

        Returns:
            str, the phase identifying the compiled graph.

        Raises:
            TypeError: If an element of `input_signature` is not a MetaTensor.
            ValueError: If the inputs are rejected by ``verify_inputs_signature``.
            RuntimeError: If graph compilation fails.
        """
        # Verify the signature for both function and method
        if self.input_signature is not None:
            signatures = []
            for sig_spec in self.input_signature:
                if not isinstance(sig_spec, MetaTensor):
                    raise TypeError("Input_signature is not MetaTensor")
                signatures.append(sig_spec)
            if not verify_inputs_signature(signatures, args_list):
                raise ValueError("Inputs is incompatible with input signature!")

        dic = dict(zip(arg_names, args_list))
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
            str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        self.fn.__parse_method__ = method_name

        # Add key with obj
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.error("`obj` module not equal to `fn` module: %s, %s",
                             self.obj.__module__, self.fn.__module__)
            self.obj.__parse_method__ = method_name
            generate_name = generate_name + '.' + str(self.obj.create_time) + '.' + str(id(self.obj))
        else:
            # id(self.fn) may be reused by a different function object after the original is freed;
            # the creation timestamp keeps such phases from matching each other.
            generate_name = generate_name + '.' + str(self._create_time)

        key = generate_arguments_key(dic)
        phase = generate_name + '.' + str(key)
        if phase in ms_compile_cache:
            return phase

        compile_target = self.fn if self.obj is None else self.obj
        if not self._graph_executor.compile(compile_target, args_list, phase, True, ""):
            raise RuntimeError("Executor compile failed.")
        if context.get_context("enable_ge"):
            self.build_data_init_graph(phase)
        ms_compile_cache[phase] = phase
        return phase

    @_wrap_func
    def __call__(self, *args):
        init_pipeline()
        converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
        if not converted:
            raise RuntimeError('Process function parameter is failure')

        args_list = tuple(arguments_dict.values())
        arg_names = tuple(arguments_dict.keys())
        if self.obj is not None:
            # The owning object is passed as the first positional argument; it is not a graph input.
            args_list = args_list[1:]
            arg_names = arg_names[1:]

        phase = self.compile(args_list, arg_names, parse_method)

        if context.get_context("precompile_only"):
            return None
        new_inputs = []
        for i in args_list:
            if isinstance(i, Tensor):
                new_inputs.append(i)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)
        output = self._graph_executor(tuple(new_inputs), phase)

        if context.get_context("mode") == context.PYNATIVE_MODE:
            _pynative_executor.set_graph_phase(phase)
            output = _pynative_executor.grad_ms_function(output, *new_inputs)

        return output


def ms_function(fn=None, obj=None, input_signature=None):
    """
    Create a callable MindSpore graph from a Python function.

    This allows the MindSpore runtime to apply optimizations based on graph.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: None.
        obj (Object): Deprecated and ignored; a warning is logged when it is not None. The function's own
            owner object (if any) is used to distinguish compiled functions instead. Default: None.
        input_signature (tuple): A sequence of Tensors which describes the input arguments. The shape and dtype
            of the Tensors will be supplied to this function. If input_signature is specified, each input to `fn`
            must be a `Tensor`. And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of
            actual inputs should keep the same as input_signature. Otherwise, an error will be raised.
            Default: None.

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ms_function
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> # create a callable MindSpore graph by calling ms_function
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = ms_function(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function
        >>> @ms_function
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_sig(x, y)
    """
    if obj is not None:
        logger.warning("Obj is no longer in use, and the function's own object has been used to "
                       "distinguish whether it has been compiled.")

    def wrap_mindspore(func):
        # Taken once per decorated function. A new executor is built on every call below, so a per-executor
        # timestamp would change the compile phase of plain functions on every call and defeat ms_compile_cache.
        ms_create_time = int(time.time() * 1e9)

        @wraps(func)
        def staging_specialize(*args):
            process_obj = None
            if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
                process_obj = args[0]
            return _MindsporeFunctionExecutor(func, input_signature, process_obj, ms_create_time)(*args)

        return staging_specialize

    if fn is not None:
        return wrap_mindspore(fn)
    return wrap_mindspore
298
299
def _generate_pip_args(obj, *args, method="construct"):
    """
    Generate arguments for pipeline.

    Resolves ``getattr(obj, method)``, maps `args` to generated names and records the method name on `obj`
    as ``__parse_method__``.

    Returns:
        tuple, ``(args_names, args_list)`` as two tuples in call order.

    Raises:
        AttributeError: If `obj` has no attribute named `method`.
        RuntimeError: If that attribute is not a plain function or method.
    """
    if not hasattr(obj, method):
        raise AttributeError('The process method is not exist')
    converted, arguments_dict, parse_method = _convert_function_arguments(getattr(obj, method), *args)
    if not converted:
        raise RuntimeError('Process method parameter is failure')
    obj.__parse_method__ = parse_method
    return tuple(arguments_dict.keys()), tuple(arguments_dict.values())
313
314
315def _get_auto_split_param_names(parameter_layout_dict):
316    auto_split_param_names = []
317    for key, value in parameter_layout_dict.items():
318        for dim in value[1]:
319            if dim != -1:
320                auto_split_param_names.append(key)
321                break
322    return auto_split_param_names
323
324
def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """
    Build broadcast graph.

    Copies every parameter into a new Tensor, runs a ``_BroadCastCell`` over them under `broadcast_phase`,
    and writes each returned value back into the corresponding parameter with ``set_data``.

    Args:
        broadcast_params_dict (dict): Parameter name -> parameter. None or empty is treated as no parameters.
        broadcast_phase (str): Phase name assigned to the broadcast cell.
    """
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
    params_by_name = broadcast_params_dict or {}
    tensors = [Tensor(param.asnumpy()) for param in params_by_name.values()]
    broadcast_net = _BroadCastCell(tensors)
    broadcast_net.phase = broadcast_phase
    results = broadcast_net()
    for name, new_value in zip(params_by_name, results):
        params_by_name[name].set_data(new_value)
338
339
def _parameter_broadcast(obj, auto_parallel_mode):
    """
    Parameter broadcast.

    Broadcasts ``obj.parameters_broadcast_dict()``. In auto-parallel mode, parameters whose layout is split
    across devices are excluded from the broadcast.

    Args:
        obj: Object exposing ``parameters_broadcast_dict()`` and, in auto-parallel mode,
            ``parameter_layout_dict``.
        auto_parallel_mode (bool): Whether auto-parallel split parameters must be filtered out.
    """
    # A set gives O(1) membership for the filter below (the helper returns a list).
    auto_split_param_names = set()
    if auto_parallel_mode:
        auto_split_param_names = set(_get_auto_split_param_names(obj.parameter_layout_dict))

    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        # Filter the dict already fetched instead of querying obj a second time.
        broadcast_params_dict = OrderedDict(
            (param_name, param) for param_name, param in broadcast_params_dict.items()
            if param_name not in auto_split_param_names)
    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
354
355
class _PynativeExecutor:
    """
    A pynative executor used to compile/manage/run single op.

    Python facade over the ``PynativeExecutor_`` singleton. Most methods forward directly to the backend;
    wherever ``**kwargs`` are accepted, their values are forwarded positionally after ``*args``.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self):
        self._executor = PynativeExecutor_.get_instance()
        self._executor.set_py_exe_path(sys.executable)
        server_dir = os.path.dirname(kernel_build_server.__file__)
        self._executor.set_kernel_build_server_dir(server_dir + os.sep)

    def new_graph(self, obj, *args, **kwargs):
        """Forward `obj` and all argument values to the backend ``new_graph``."""
        self._executor.new_graph(obj, *args, *kwargs.values())

    def end_graph(self, obj, output, *args, **kwargs):
        """Forward `obj`, `output` and all argument values to the backend ``end_graph``."""
        self._executor.end_graph(obj, output, *args, *kwargs.values())

    def check_graph(self, obj, *args, **kwargs):
        """Return the backend ``check_graph`` result for `obj` and the given argument values."""
        return self._executor.check_graph(obj, *args, *kwargs.values())

    def check_run(self, grad, obj, *args, **kwargs):
        """Return the backend ``check_run`` result for `grad`, `obj` and the given argument values."""
        return self._executor.check_run(grad, obj, *args, *kwargs.values())

    def grad(self, grad, obj, weights, *args, **kwargs):
        """Forward to the backend ``grad_net``."""
        self._executor.grad_net(grad, obj, weights, *args, *kwargs.values())

    def del_cell(self, cell_id=""):
        """Forward `cell_id` to the backend ``clear_cell``."""
        self._executor.clear_cell(cell_id)

    def clear_res(self):
        """Return the backend ``clear_res`` result."""
        return self._executor.clear_res()

    def clear_grad(self, obj, *args, **kwargs):
        """Forward `obj` and all argument values to the backend ``clear_grad``."""
        self._executor.clear_grad(obj, *args, *kwargs.values())

    def sync(self):
        """Forward to the backend ``sync``."""
        self._executor.sync()

    def set_lazy_build(self, enable):
        """Forward `enable` to the backend ``set_lazy_build``."""
        self._executor.set_lazy_build(enable)

    def execute_all_task(self):
        """Forward to the backend ``execute_all_task``."""
        self._executor.execute_all_task()

    def grad_ms_function(self, output, *args):
        """Return the backend ``grad_ms_function`` result for `output` and `args`."""
        return self._executor.grad_ms_function(output, *args)

    def set_graph_phase(self, phase):
        """Forward `phase` to the backend ``set_graph_phase``."""
        self._executor.set_graph_phase(phase)

    def grad_flag(self):
        """Return the backend ``grad_flag`` value."""
        return self._executor.grad_flag()

    def set_grad_flag(self, flag):
        """Forward `flag` to the backend ``set_grad_flag``."""
        self._executor.set_grad_flag(flag)

    def enter_construct(self, cell):
        """Forward `cell` to the backend ``enter_construct``."""
        self._executor.enter_construct(cell)

    def leave_construct(self, cell):
        """Forward `cell` to the backend ``leave_construct``."""
        self._executor.leave_construct(cell)

    def parameter_broadcast(self, obj, phase, auto_parallel_mode):
        """Broadcast parameters of `obj` unless `phase` is a broadcast phase or broadcasting is disabled."""
        if BROADCAST_PHASE in phase or not _get_parameter_broadcast():
            return
        _parameter_broadcast(obj, auto_parallel_mode)

    def enter_cell(self):
        """Forward to the backend ``enter_cell``."""
        self._executor.enter_cell()

    def exit_cell(self):
        """Forward to the backend ``exit_cell``."""
        self._executor.exit_cell()

    def is_top_cell(self):
        """Return the backend ``is_top_cell`` value."""
        return self._executor.is_top_cell()

    def __call__(self, obj, *args, **kwargs):
        # The backend receives all argument values packed into a single tuple.
        return self._executor(obj, (*args, *kwargs.values()))
445
446
class _CellGraphExecutor:
    """
    An executor used to compile/manage/run graph for a Cell.

    Including data_graph, train_graph, eval_graph and predict graph.

    Compiled graphs are identified by a "real phase" string:
    ``<phase>.<obj.create_time>.<id(obj)>.<obj.arguments_key>`` (the 'save' phase omits the arguments key).
    """

    def __init__(self):
        # create needed graph by lazy mode
        self.is_init = False
        self._graph_executor = GraphExecutor_.get_instance()
        # Phases successfully compiled through `compile` (phase -> phase).
        self.compile_cache = {}
        self._graph_executor.set_py_exe_path(sys.executable)
        self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
        self.queue_name = ""

    @staticmethod
    def _obj_phase_prefix(obj, phase):
        """Return ``<phase>.<obj.create_time>.<id(obj)>``."""
        return phase + '.' + str(obj.create_time) + '.' + str(id(obj))

    @staticmethod
    def _real_phase(obj, phase):
        """Return the full phase of `obj` for `phase`, including its current ``arguments_key``."""
        return _CellGraphExecutor._obj_phase_prefix(obj, phase) + '.' + obj.arguments_key

    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
        """
        Initialization interface for calling data subgraph.

        Args:
            queue_name (str): The name of tdt queue on the device.
            dataset_size (int): The size of dataset.
            batch_size (int): The size of batch.
            dataset_types (list): The output types of element in dataset.
            dataset_shapes (list): The output shapes of element in dataset.
            input_indexs (list): The index of data with net.
            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.

        Returns:
            bool, True when the data subgraph was initialized successfully.

        Raises:
            RuntimeError: If the data subgraph cannot be initialized.
        """
        if not init_exec_dataset(queue_name=queue_name,
                                 size=dataset_size,
                                 batch_size=batch_size,
                                 types=dataset_types,
                                 shapes=dataset_shapes,
                                 input_indexs=input_indexs,
                                 phase=phase):
            raise RuntimeError("Failure to init and dataset subgraph!")
        self.queue_name = queue_name
        return True

    def _build_data_graph(self, obj, phase):
        self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _set_dataset_mode(self, args_list):
        """set dataset mode."""
        # decide whether to sink based on whether the inputs is virtual or args_list is ()
        is_virtual_input = bool(args_list) and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag
        if is_virtual_input or args_list == ():
            _set_dataset_mode_config('sink')
        else:
            _set_dataset_mode_config('normal')

    def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
        """
        Compiles graph.

        Args:
            obj (Function/Cell): The function or cell instance need compile.
            args (tuple): Function or cell input arguments.
            phase (str): The name of compile phase. Default: 'predict'.
            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
            auto_parallel_mode (bool): When set to True, use auto parallel mode to compile graph.

        Returns:
            Str, the full phase of the cell.
            Bool, if the graph has been compiled before, return False, else return True.

        Raises:
            RuntimeError: If compilation fails or produces no graph.
        """
        args_names, args_list = _generate_pip_args(obj, *args)
        dic = dict(zip(args_names, args_list))
        key = generate_arguments_key(dic)
        obj.arguments_key = str(key)
        phase = self._real_phase(obj, phase)

        if phase in self.compile_cache:
            logger.debug("%r graph has existed.", phase)
            return phase, False

        obj.check_names()
        _check_full_batch()
        self._set_dataset_mode(args_list)

        is_sink_mode = bool(args) and isinstance(args[0], Tensor) and args[0].virtual_flag
        if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
            args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
            _, args_list = _generate_pip_args(obj, *args_full)

        enable_debug_runtime = context.get_context("enable_debug_runtime")
        enable_ge = context.get_context("enable_ge")
        use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
        result = self._graph_executor.compile(obj, args_list, phase, use_vm, self.queue_name)
        if not result:
            raise RuntimeError("Executor compile failed.")
        graph = self._graph_executor.get_func_graph(phase)

        if graph is None:
            raise RuntimeError("Compile graph failed for phase {}.".format(phase))
        # Record the phase only after compilation succeeded: caching it earlier made a failed compile look
        # "already compiled" (phase, False) on the next call instead of retrying/reporting the failure.
        self.compile_cache[phase] = phase

        self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args)

        if not do_convert:
            return phase, True

        # the following GE init process is not needed when use vm or ms backend
        if enable_ge:
            self._build_data_graph(obj, phase)
            if "export" not in phase:
                init_phase = self._obj_phase_prefix(obj, "init_subgraph")
                _exec_init_graph(obj, init_phase)
        elif "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)

        return phase, True

    def _auto_parallel_process(self, obj, phase, is_sink_mode, auto_parallel_mode, *args):
        """compile graph in auto parallel mode."""
        if not auto_parallel_mode:
            replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
            self._update_param_node_default_input(phase, replace)
            return

        obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
        obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=True)
        if _get_pipeline_stages() > 1 and not getattr(obj, "is_first_iteration", False):
            obj.remove_redundant_parameters()
        if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"):
            obj.load_parameter_slice(None)

        self._update_param_node_default_input(phase, replace)

        # set parallel inputs in sink mode
        if is_sink_mode:
            obj.set_parallel_input_with_inputs(*args)

    def _update_param_node_default_input(self, phase, replace):
        # Only parameters that were actually replaced by a different object need updating in the graph.
        new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
        # NOTE: `updata_...` is the name exposed by the backend binding; do not "fix" the spelling here.
        return self._graph_executor.updata_param_node_default_input(phase, new_param)

    def _get_shard_strategy(self, obj):
        return self._graph_executor.get_strategy(self._real_phase(obj, obj.phase))

    def _get_num_parallel_ops(self, obj):
        return self._graph_executor.get_num_parallel_ops(self._real_phase(obj, obj.phase))

    def _get_allreduce_fusion(self, obj):
        return self._graph_executor.get_allreduce_fusion(self._real_phase(obj, obj.phase))

    def has_compiled(self, phase='predict'):
        """
        Specify whether have been compiled.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            bool, specifies whether the specific graph has been compiled.
        """
        return self._graph_executor.has_compiled(phase)

    def __call__(self, obj, *args, phase='predict'):
        if context.get_context("precompile_only") or _is_role_pserver():
            return None
        return self.run(obj, *args, phase=phase)

    @_wrap_func
    def _exec_pip(self, obj, *args, phase=''):
        """Execute the generated pipeline."""
        fn = obj.construct
        converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
        if not converted:
            raise RuntimeError('Process method parameter is failure')
        args_list = tuple(arguments_dict.values())
        obj.__parse_method__ = parse_method
        return self._graph_executor(args_list, phase)

    def run(self, obj, *args, phase='predict'):
        """
        Run the specific graph.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            Tensor/Tuple, return execute result.

        Raises:
            KeyError: If no graph has been compiled for the resulting phase.
        """
        if phase == 'save':
            return self._graph_executor((), self._obj_phase_prefix(obj, phase))

        phase_real = self._real_phase(obj, phase)
        if self.has_compiled(phase_real):
            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph is not exist.'.format(phase_real))

    def del_net_res(self, net_id):
        self._graph_executor.del_net_res(net_id)

    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
        """Get graph proto from pipeline; returns None if `exec_id` has not been compiled."""
        if use_prefix:
            exec_id = exec_id + '.' + obj.arguments_key
        if not self._graph_executor.has_compiled(exec_id):
            return None
        return self._graph_executor.get_func_graph_proto(exec_id, ir_type)

    def export(self, file_name, graph_id):
        """
        Export graph.

        Args:
            file_name (str): File name of model to export
            graph_id (str): id of graph to be exported
        """
        from .._c_expression import export_graph
        export_graph(file_name, 'AIR', graph_id)

    def fetch_info_for_quant_export(self, exec_id):
        """Get quant export info from pipeline; returns None if `exec_id` has not been compiled."""
        if not self._graph_executor.has_compiled(exec_id):
            return None
        return self._graph_executor.fetch_info_for_quant_export(exec_id)
679
680    def fetch_info_for_quant_export(self, exec_id):
681        """Get graph proto from pipeline."""
682        if self._graph_executor.has_compiled(exec_id) is False:
683            return None
684        return self._graph_executor.fetch_info_for_quant_export(exec_id)
685
686
# Module-level executor instances (the pynative one is also used by _MindsporeFunctionExecutor).
_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PynativeExecutor()

# Public API of this module.
__all__ = ["ms_function"]
691