# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
import types
import sys
import os
import time
from collections import OrderedDict
from functools import wraps

from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as MsTensor
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
    _get_parameter_broadcast, _get_pipeline_stages

# store ms_function class compiled pipeline cache (phase name -> phase name)
ms_compile_cache = {}

BROADCAST_PHASE = "_broadcast_"


def _kernel_build_server_dir():
    """Return the directory of the `kernel_build_server` module, with a trailing path separator."""
    return os.path.split(kernel_build_server.__file__)[0] + os.sep


def _convert_function_arguments(fn, *args):
    """
    Process the fn default parameters.

    Positional arguments are stored under the keys ``arg0``, ``arg1``, ... in call order.

    Args:
        fn (Function): The function to be parsed.
        args (tuple): The parameters of the function.

    Returns:
        tuple, ``(converted, arguments_dict, parse_method)``. ``converted`` is True only when `fn` is a plain
        function or a bound method; ``arguments_dict`` is an OrderedDict of the positional arguments (empty
        when not converted); ``parse_method`` is ``fn.__name__`` (None when not converted).
    """
    arguments_dict = OrderedDict()
    parse_method = None
    if isinstance(fn, (types.FunctionType, types.MethodType)):
        parse_method = fn.__name__
        for index, value in enumerate(args):
            arguments_dict[f'arg{index}'] = value
        logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
        converted = True
    else:
        logger.warning("Find error: fn isn't function or method")
        converted = False
    return converted, arguments_dict, parse_method


def _wrap_func(fn):
    """
    Wrapper function, convert return data to tensor or tuple/list of tensor.

    Args:
        fn (Function): The function need be wrapped.

    Returns:
        Function, a new function with return suitable format data.
    """

    def _convert_data(data):
        # `Tensor` objects coming from `_c_expression` that are not already `MsTensor` are re-wrapped as
        # `MsTensor`; tuples and lists are converted element-wise, keeping their container type.
        if isinstance(data, Tensor) and not isinstance(data, MsTensor):
            return MsTensor(data)
        if isinstance(data, tuple):
            return tuple(_convert_data(x) for x in data)
        if isinstance(data, list):
            return [_convert_data(x) for x in data]
        return data

    @wraps(fn)
    def wrapper(*arg, **kwargs):
        return _convert_data(fn(*arg, **kwargs))

    return wrapper


def _exec_init_graph(obj, init_phase):
    """
    Execute the parameter initializer graph.

    Only parameters of `obj` whose ``is_init`` is False are collected; they are marked initialized before the
    init graph is run. Nothing is executed when every parameter is already initialized.
    """
    inst_executor = GraphExecutor_.get_instance()
    param_dict = OrderedDict()
    for name, param in obj.parameters_dict().items():
        if not param.is_init:
            param_dict[name] = param
            param.is_init = True
            param.data.init_flag = True

    if param_dict:
        inst_executor.run_init_graph(param_dict, init_phase)


class _MindsporeFunctionExecutor:
    """
    Represents a function compiled by graph compiler.

    _MindsporeFunctionExecutor will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).
    Compiled phases are recorded in the module-level ``ms_compile_cache`` so a phase is compiled only once.

    Args:
        fn (Function): The root function to compile.
        input_signature (Union[tuple, list]): User defines signature to verify input. Every entry must be a
            MetaTensor. Default: None (no verification).
        obj (Object): If function is a method, obj is the owner of function,
            else, obj is none. It is only kept when it has an attribute named ``fn.__name__``.
        ms_create_time (int): Timestamp in nanoseconds used to tell apart compile phases of free functions
            (`obj` is None), because a function ``id`` may be reused after garbage collection. Pass a value that
            is fixed per decorated function so repeated calls reuse the same phase. Default: None, meaning the
            time at which this executor is created.

    Returns:
        The result of pipeline running in graph mode.
    """

    def __init__(self, fn, input_signature=None, obj=None, ms_create_time=None):
        self.fn = fn
        self.input_signature = input_signature
        self.obj = None
        if hasattr(obj, fn.__name__):
            self.obj = obj
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = int(time.time() * 1e9) if ms_create_time is None else ms_create_time

    def build_data_init_graph(self, graph_name):
        """Build GE data graph and init graph for the given graph name."""
        if self.obj is None:
            logger.warning("Make sure parameter should not be used in function")
            para_dict = OrderedDict()
            self._graph_executor.build_data_graph(para_dict, graph_name)
            return
        self._graph_executor.build_data_graph(self.obj.parameters_dict(), graph_name,
                                              self.obj.parameters_broadcast_dict())
        # The init phase reuses everything from the first '.' of the graph name onwards.
        init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
        _exec_init_graph(self.obj, init_phase)

    def compile(self, args_list, arg_names, method_name):
        """
        Returns pipeline for the given args.

        Args:
            args_list (tuple): Positional argument values (without the owner object for methods).
            arg_names (tuple): Names matching `args_list`, used to build the arguments key.
            method_name (str): Stored as ``__parse_method__`` on the function (and on the owner object).

        Returns:
            str, the phase name identifying the compiled graph.

        Raises:
            TypeError: If an `input_signature` entry is not a MetaTensor.
            ValueError: If the inputs do not match `input_signature`.
            RuntimeError: If the graph executor fails to compile.
        """
        # Verify the signature for both function and method
        if self.input_signature is not None:
            signatures = []
            for sig_spec in self.input_signature:
                if not isinstance(sig_spec, MetaTensor):
                    raise TypeError("Input_signature is not MetaTensor")
                signatures.append(sig_spec)
            if not verify_inputs_signature(signatures, args_list):
                raise ValueError("Inputs is incompatible with input signature!")

        dic = dict(zip(arg_names, args_list))
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
            str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        self.fn.__parse_method__ = method_name

        # Add key with obj
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.error('`obj` module not equal to `fn` module: %s, %s', self.obj.__module__,
                             self.fn.__module__)
            self.obj.__parse_method__ = method_name
            generate_name = generate_name + '.' + str(self.obj.create_time) + '.' + str(id(self.obj))
        else:
            # Different instance of same class may use same memory(means same obj_id) at diff times.
            # To avoid unexpected phase matched, add create_time to generate_name.
            generate_name = generate_name + '.' + str(self._create_time)

        key = generate_arguments_key(dic)
        phase = generate_name + '.' + str(key)
        if phase not in ms_compile_cache:
            compile_target = self.fn if self.obj is None else self.obj
            if not self._graph_executor.compile(compile_target, args_list, phase, True, ""):
                raise RuntimeError("Executor compile failed.")
            if context.get_context("enable_ge"):
                self.build_data_init_graph(phase)
            ms_compile_cache[phase] = phase

        return phase

    @_wrap_func
    def __call__(self, *args):
        init_pipeline()
        converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
        if not converted:
            raise RuntimeError('Process function parameter is failure')

        args_list = tuple(arguments_dict.values())
        arg_names = tuple(arguments_dict.keys())
        if self.obj is not None:
            # For methods the first positional argument is the owner object itself; drop it.
            args_list = args_list[1:]
            arg_names = arg_names[1:]

        phase = self.compile(args_list, arg_names, parse_method)

        if context.get_context("precompile_only"):
            return None
        # Only Tensor inputs (and int/float scalars when "grad_for_scalar" is enabled) are fed to the graph.
        new_inputs = []
        for i in args_list:
            if isinstance(i, Tensor):
                new_inputs.append(i)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)
        output = self._graph_executor(tuple(new_inputs), phase)

        if context.get_context("mode") == context.PYNATIVE_MODE:
            _pynative_executor.set_graph_phase(phase)
            output = _pynative_executor.grad_ms_function(output, *new_inputs)

        return output


def ms_function(fn=None, obj=None, input_signature=None):
    """
    Create a callable MindSpore graph from a Python function.

    This allows the MindSpore runtime to apply optimizations based on graph.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: None.
        obj (Object): The Python object is used to distinguish the compiled function. Default: None.
        input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
            will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
            And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
            keep the same as input_signature. Otherwise, TypeError will be raised. Default: None.

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ms_function
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> # create a callable MindSpore graph by calling ms_function
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = ms_function(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function
        >>> @ms_function
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_sig(x, y)
    """

    def wrap_mindspore(func):
        # Taken once per decorated function: a fresh executor is built on every call, so a per-executor
        # timestamp would give every call of a free function a new phase and force a recompilation.
        ms_create_time = int(time.time() * 1e9)

        @wraps(func)
        def staging_specialize(*args):
            if obj is not None:
                logger.warning("Obj is no longer in use, and the function's own object has been used to "
                               "distinguish whether it has been compiled.")
            process_obj = None
            # Treat the call as a method call when the first argument owns an attribute named like `func`.
            if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
                process_obj = args[0]
            return _MindsporeFunctionExecutor(func, input_signature, process_obj, ms_create_time)(*args)

        return staging_specialize

    if fn is not None:
        return wrap_mindspore(fn)
    return wrap_mindspore


def _generate_pip_args(obj, *args, method="construct"):
    """
    Generate arguments for pipeline.

    Returns:
        tuple, ``(args_names, args_list)`` for the positional `args` of ``obj.<method>``.

    Raises:
        AttributeError: If `obj` has no attribute named `method`.
        RuntimeError: If that attribute is not a function or method.
    """
    if not hasattr(obj, method):
        raise AttributeError('The process method is not exist')
    fn = getattr(obj, method)
    converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
    if not converted:
        raise RuntimeError('Process method parameter is failure')
    args_list = tuple(arguments_dict.values())
    args_names = tuple(arguments_dict.keys())
    obj.__parse_method__ = parse_method
    return args_names, args_list


def _get_auto_split_param_names(parameter_layout_dict):
    """
    Return the names whose layout entry ``value[1]`` contains any value other than -1.

    NOTE(review): ``value[1]`` is presumably the tensor map, where -1 means "not split" — confirm.
    """
    return [key for key, value in parameter_layout_dict.items() if any(dim != -1 for dim in value[1])]


def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """Build broadcast graph, run it and write the broadcast values back into the parameters."""
    # Imported here rather than at module level (presumably to avoid a circular import with mindspore.nn).
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
    if not broadcast_params_dict:
        broadcast_params_dict = {}
    broadcast_params = [Tensor(param.asnumpy()) for param in broadcast_params_dict.values()]
    _broadcast_net = _BroadCastCell(broadcast_params)
    _broadcast_net.phase = broadcast_phase
    broadcasted_params = _broadcast_net()
    for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
        broadcast_params_dict[param_name].set_data(param)


def _parameter_broadcast(obj, auto_parallel_mode):
    """Parameter broadcast; in auto parallel mode, auto-split parameters are excluded from the broadcast."""
    auto_split_param_names = set()
    if auto_parallel_mode:
        auto_split_param_names = set(_get_auto_split_param_names(obj.parameter_layout_dict))

    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        broadcast_params_dict = OrderedDict(
            (param_name, param) for param_name, param in broadcast_params_dict.items()
            if param_name not in auto_split_param_names)
    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)


class _PynativeExecutor:
    """
    A pynative executor used to compile/manage/run single op.

    The main functions include: construct op graph, compile op graph, auto grad and run op graph.

    Args:
        obj (Object): The python network that will be run in pynative mode.
        args (Tuple(Tensor...)): The inputs of network in tuple form.

    Returns:
        gradients (Tuple(Tensor...)): The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    # NOTE: every method forwards to the `PynativeExecutor_` singleton. Keyword arguments are passed on
    # positionally (their values only, in insertion order), so keyword names never reach the backend.

    def __init__(self):
        self._executor = PynativeExecutor_.get_instance()
        self._executor.set_py_exe_path(sys.executable)
        self._executor.set_kernel_build_server_dir(_kernel_build_server_dir())

    def new_graph(self, obj, *args, **kwargs):
        self._executor.new_graph(obj, *args, *(kwargs.values()))

    def end_graph(self, obj, output, *args, **kwargs):
        self._executor.end_graph(obj, output, *args, *(kwargs.values()))

    def check_graph(self, obj, *args, **kwargs):
        return self._executor.check_graph(obj, *args, *(kwargs.values()))

    def check_run(self, grad, obj, *args, **kwargs):
        return self._executor.check_run(grad, obj, *args, *(kwargs.values()))

    def grad(self, grad, obj, weights, *args, **kwargs):
        self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))

    def del_cell(self, cell_id=""):
        self._executor.clear_cell(cell_id)

    def clear_res(self):
        return self._executor.clear_res()

    def clear_grad(self, obj, *args, **kwargs):
        self._executor.clear_grad(obj, *args, *(kwargs.values()))

    def sync(self):
        self._executor.sync()

    def set_lazy_build(self, enable):
        self._executor.set_lazy_build(enable)

    def execute_all_task(self):
        self._executor.execute_all_task()

    def grad_ms_function(self, output, *args):
        return self._executor.grad_ms_function(output, *args)

    def set_graph_phase(self, phase):
        self._executor.set_graph_phase(phase)

    def grad_flag(self):
        return self._executor.grad_flag()

    def set_grad_flag(self, flag):
        self._executor.set_grad_flag(flag)

    def enter_construct(self, cell):
        self._executor.enter_construct(cell)

    def leave_construct(self, cell):
        self._executor.leave_construct(cell)

    def parameter_broadcast(self, obj, phase, auto_parallel_mode):
        """Broadcast parameters unless `phase` is itself a broadcast phase or broadcasting is disabled."""
        if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)

    def enter_cell(self):
        self._executor.enter_cell()

    def exit_cell(self):
        self._executor.exit_cell()

    def is_top_cell(self):
        return self._executor.is_top_cell()

    def __call__(self, obj, *args, **kwargs):
        args = args + tuple(kwargs.values())
        return self._executor(obj, args)


class _CellGraphExecutor:
    """
    An executor used to compile/manage/run graph for a Cell.

    Including data_graph, train_graph, eval_graph and predict graph.

    Args:
        obj (Function/Cell): The function or cell instance need compile.
        args (tuple): Function or cell input arguments.

    Returns:
        Graph, return the result of pipeline running.
    """

    def __init__(self):
        # create needed graph by lazy mode
        self.is_init = False
        self._graph_executor = GraphExecutor_.get_instance()
        self.compile_cache = {}
        self._graph_executor.set_py_exe_path(sys.executable)
        self._graph_executor.set_kernel_build_server_dir(_kernel_build_server_dir())
        self.queue_name = ""

    @staticmethod
    def _full_phase(obj, phase):
        """Return the phase key `compile` registers for `obj` (requires ``obj.arguments_key`` to be set)."""
        return phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key

    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
        """
        Initialization interface for calling data subgraph.

        Args:
            queue_name (str): The name of tdt queue on the device.
            dataset_size (int): The size of dataset.
            batch_size (int): The size of batch.
            dataset_types (list): The output types of element in dataset.
            dataset_shapes (list): The output shapes of element in dataset.
            input_indexs (list): The index of data with net.
            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.

        Returns:
            bool, specifies whether the data subgraph was initialized successfully.

        Raises:
            RuntimeError: If the dataset subgraph could not be initialized.
        """
        if not init_exec_dataset(queue_name=queue_name,
                                 size=dataset_size,
                                 batch_size=batch_size,
                                 types=dataset_types,
                                 shapes=dataset_shapes,
                                 input_indexs=input_indexs,
                                 phase=phase):
            raise RuntimeError("Failure to init and dataset subgraph!")
        self.queue_name = queue_name
        return True

    def _build_data_graph(self, obj, phase):
        self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _set_dataset_mode(self, args_list):
        """set dataset mode."""
        # decide whether to sink based on whether the inputs is virtual or args_list is ()
        if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
                (args_list is not None and args_list == ()):
            _set_dataset_mode_config('sink')
        else:
            _set_dataset_mode_config('normal')

    def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
        """
        Compiles graph.

        Args:
            obj (Function/Cell): The function or cell instance need compile.
            args (tuple): Function or cell input arguments.
            phase (str): The name of compile phase. Default: 'predict'.
            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

        Return:
            Str, the full phase of the cell.
            Bool, if the graph has been compiled before, return False, else return True.

        Raises:
            RuntimeError: If compilation fails or no graph is produced for the phase.
        """

        args_names, args_list = _generate_pip_args(obj, *args)
        dic = dict(zip(args_names, args_list))
        key = generate_arguments_key(dic)
        obj.arguments_key = str(key)
        phase = self._full_phase(obj, phase)

        if phase in self.compile_cache:
            logger.debug("%r graph has existed.", phase)
            return phase, False

        obj.check_names()
        _check_full_batch()
        self._set_dataset_mode(args_list)

        is_sink_mode = bool(args and isinstance(args[0], Tensor) and args[0].virtual_flag)
        if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
            args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
            _, args_list = _generate_pip_args(obj, *args_full)

        enable_debug_runtime = context.get_context("enable_debug_runtime")
        enable_ge = context.get_context("enable_ge")
        use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
        result = self._graph_executor.compile(obj, args_list, phase, use_vm, self.queue_name)
        if not result:
            raise RuntimeError("Executor compile failed.")
        graph = self._graph_executor.get_func_graph(phase)

        if graph is None:
            raise RuntimeError("Compile graph failed for phase {}.".format(phase))
        # Record the phase only after a successful compile; otherwise a failed attempt would be reported as
        # "already compiled" on the next call.
        self.compile_cache[phase] = phase

        self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args)

        if not do_convert:
            return phase, True

        # the following GE init process is not needed when use vm or ms backend
        if enable_ge:
            self._build_data_graph(obj, phase)
            if "export" not in phase:
                init_phase = "init_subgraph." + str(obj.create_time) + "." + str(id(obj))
                _exec_init_graph(obj, init_phase)
        elif "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)

        return phase, True

    def _auto_parallel_process(self, obj, phase, is_sink_mode, auto_parallel_mode, *args):
        """compile graph in auto parallel mode."""
        if not auto_parallel_mode:
            replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
            self._update_param_node_default_input(phase, replace)
            return

        obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
        obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=True)
        if _get_pipeline_stages() > 1 and (not hasattr(obj, "is_first_iteration") or not obj.is_first_iteration):
            obj.remove_redundant_parameters()
        if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"):
            obj.load_parameter_slice(None)

        self._update_param_node_default_input(phase, replace)

        # set parallel inputs in sink mode
        if is_sink_mode:
            obj.set_parallel_input_with_inputs(*args)

    def _update_param_node_default_input(self, phase, replace):
        # Only parameters that were actually replaced by a different object are sent to the executor.
        new_param = {x.name: replace[x] for x in replace if x is not replace[x]}
        # `updata_param_node_default_input` is the backend binding's spelling.
        return self._graph_executor.updata_param_node_default_input(phase, new_param)

    def _get_shard_strategy(self, obj):
        return self._graph_executor.get_strategy(self._full_phase(obj, obj.phase))

    def _get_num_parallel_ops(self, obj):
        return self._graph_executor.get_num_parallel_ops(self._full_phase(obj, obj.phase))

    def _get_allreduce_fusion(self, obj):
        return self._graph_executor.get_allreduce_fusion(self._full_phase(obj, obj.phase))

    def has_compiled(self, phase='predict'):
        """
        Specify whether have been compiled.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            bool, specifies whether the specific graph has been compiled.
        """
        return self._graph_executor.has_compiled(phase)

    def __call__(self, obj, *args, phase='predict'):
        if context.get_context("precompile_only") or _is_role_pserver():
            return None
        return self.run(obj, *args, phase=phase)

    @_wrap_func
    def _exec_pip(self, obj, *args, phase=''):
        """Execute the generated pipeline."""
        fn = obj.construct
        converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
        if not converted:
            raise RuntimeError('Process method parameter is failure')
        args_list = tuple(arguments_dict.values())
        obj.__parse_method__ = parse_method
        return self._graph_executor(args_list, phase)

    def run(self, obj, *args, phase='predict'):
        """
        Run the specific graph.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            Tensor/Tuple, return execute result.

        Raises:
            KeyError: If no graph has been compiled for the resolved phase.
        """
        if phase == 'save':
            return self._graph_executor((), phase + '.' + str(obj.create_time) + '.' + str(id(obj)))

        phase_real = self._full_phase(obj, phase)
        if self.has_compiled(phase_real):
            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph is not exist.'.format(phase_real))

    def del_net_res(self, net_id):
        self._graph_executor.del_net_res(net_id)

    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
        """Get graph proto from pipeline; returns None when `exec_id` has not been compiled."""
        if use_prefix:
            exec_id = exec_id + '.' + obj.arguments_key
        if self._graph_executor.has_compiled(exec_id) is False:
            return None
        return self._graph_executor.get_func_graph_proto(exec_id, ir_type)

    def export(self, file_name, graph_id):
        """
        Export graph.

        Args:
            file_name (str): File name of model to export
            graph_id (str): id of graph to be exported
        """
        from .._c_expression import export_graph
        export_graph(file_name, 'AIR', graph_id)

    def fetch_info_for_quant_export(self, exec_id):
        """Get quant-export info from pipeline; returns None when `exec_id` has not been compiled."""
        if self._graph_executor.has_compiled(exec_id) is False:
            return None
        return self._graph_executor.fetch_info_for_quant_export(exec_id)


_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PynativeExecutor()

__all__ = ['ms_function']