# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
from __future__ import absolute_import

import types
import sys
import os
import time
import ast
import inspect
import importlib
import hashlib
import contextlib
from collections import OrderedDict, namedtuple
from functools import wraps
import numpy as np
import mindspore as ms
from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
from mindspore.common.jit_config import JitConfig
from mindspore.common.tensor import Tensor as PythonTensor
from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
    PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
    _ms_memory_recycle, _bind_device_ctx, jit_mode_pi_enable, jit_mode_pi_compile
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
    _is_in_auto_parallel_mode
from mindspore import _checkparam as Validator
from mindspore._checkparam import is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
    get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature

# Store ms_function class compiled pipeline cache.
ms_compile_cache = set()
# Store cell compiled pipeline cache.
cells_compile_cache = {}
# Store function compiled times information.
function_phases = dict()

BROADCAST_PHASE = "_broadcast_"
_PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"

ARG_SPECIFIED = "arg_specified_infos"
TOTAL_ARG_LEN = "total_arg_length"


def _check_recompile_args(compile_args, kwargs):
    """Warn when a tuple/list argument contains constant tensors, which may cause recompiling.

    Only the first offending argument is reported.

    Args:
        compile_args (Union[tuple, list]): Positional arguments used for compiling. Not modified.
        kwargs (dict): Keyword arguments used for compiling; their values are checked as well.
    """

    def _check_constant_tensor_arg(arg):
        # An argument marked by `mutable` is a variable, so it is never reported as constant.
        if hasattr(arg, "__ms_mutable__"):
            return False
        if isinstance(arg, (list, tuple)):
            return any(_check_constant_tensor_arg(x) for x in arg)
        return isinstance(arg, Tensor)

    # Build a new tuple instead of `compile_args += ...`, which would extend a caller's list in place.
    all_args = tuple(compile_args) + tuple(kwargs.values())
    for arg in all_args:
        if not isinstance(arg, (tuple, list)):
            continue
        if _check_constant_tensor_arg(arg):
            logger.warning(f"Constant value tensor are detected in tuple or list, which might cause recompiling "
                           f"when tensor value changes. You can use mutable(Tensor) or mutable(tuple(Tensor)) "
                           f"to set tensor's value as variable to to avoid recompiling. The tuple or list arg "
                           f"is: {arg} .")
            return


def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time, echo_function_name):
    """Warning when the function has been compiled.

    The first compilation of `full_function_name` is only recorded. For later compilations:

    - If `create_time` has not been seen before, the function or cell was created again. A recompiling
      message is logged (info for `Cell`, warning otherwise).
    - Otherwise the arguments are checked for constant tensors inside tuples/lists.

    Functions defined under `mindspore/ops` or `mindspore/nn` are ignored.

    Args:
        obj (Object): The owner of the compiled function, or None for a free function.
        compile_args (tuple): Positional arguments used for compiling.
        kwargs (dict): Keyword arguments used for compiling.
        full_function_name (str): Identifier of the compiled function, used as key of `function_phases`.
        create_time (str): Creation time of the compiled object or function.
        echo_function_name (str): Human-readable description of the function used in log messages.
    """
    ignore_dirs = ["mindspore/ops", "mindspore/nn"]
    # The name embeds `co_filename`; normalize Windows separators so the directory filter works there too.
    normalized_name = full_function_name.replace("\\", "/")
    if any(ignore_dir in normalized_name for ignore_dir in ignore_dirs):
        return

    if full_function_name in function_phases:
        warning_times = 1
        if len(function_phases[full_function_name]) >= warning_times \
                and create_time not in function_phases[full_function_name]:
            if isinstance(obj, ms.nn.Cell):
                tips = f"Please try to create {echo_function_name} instance only once to avoid recompiling. "
                logger.info(f"The {echo_function_name} has been compiled again. "
                            f"{tips} ")
            else:
                tips = "Try to decorate the function with @jit(hash_args=...) " \
                       "or @jit(compile_once=True) to reduce the compile time. " \
                       "For more details, get instructions about `jit` at " \
                       "https://www.mindspore.cn/search?inputValue=jit."
                logger.warning(f"The {echo_function_name} has been compiled again. "
                               f"{tips} ")
        else:
            _check_recompile_args(compile_args, kwargs)
    else:
        function_phases[full_function_name] = set()
    function_phases[full_function_name].add(create_time)


def _ms_adapter_tensor_as_parameter_output(data):
    """Check whether the data is an output from a parameter which is a ms_adapter tensor.
    Pylint: disable=unidiomatic-typecheck.
    """
    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")


def _convert_python_data(data):
    """
    Convert C++ data to python.

    C++ tensors are wrapped into their Python counterparts (or into the registered ms_adapter tensor).
    Containers are converted recursively:

    - Tuples, including namedtuples, are rebuilt. A namedtuple keeps its own class.
    - Lists and dicts are converted in place, so the original container objects are kept.

    Args:
        data : The data need be convert.

    Returns:
        data, a data convert C++ to python
    """
    if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
        return ms_adapter_registry.tensor(data)
    if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
        return data.tensor
    if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
        return PythonTensor(data, internal=True)
    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
        return PythonCSRTensor(csr_tensor=data)
    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
        return PythonCOOTensor(coo_tensor=data)
    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
        return PythonRowTensor(row_tensor=data)
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # Handle namedtuple since its type is tuple. Its class is a tuple *subclass*, so it never matches the
        # exact-tuple check below; convert the fields and rebuild an instance of the same namedtuple class.
        return type(data)._make(_convert_python_data(x) for x in data)
    if data.__class__ is tuple:
        return tuple(_convert_python_data(x) for x in data)
    if data.__class__ is list:
        # Keep list object not change for inplace operation.
        for index, item in enumerate(data):
            data[index] = _convert_python_data(item)
        return data
    if data.__class__ is dict:
        # Keep the dict object not change.
        keys = tuple(data.keys())
        for key in keys:
            data[_convert_python_data(key)] = _convert_python_data(data.pop(key))
        return data
    return data


def _wrap_func(fn):
    """
    Wrapper function, convert return data to tensor or tuple of tensor.

    Args:
        fn (Function): The function need be wrapped.

    Returns:
        Function, a new function with return suitable format data.
    """

    @wraps(fn)
    def wrapper(*arg, **kwargs):
        results = fn(*arg, **kwargs)
        return _convert_python_data(results)

    return wrapper


def _check_all_tensor(sequence):
    """Return whether every element of `sequence` is a tensor, a stub tensor, or a tuple of such (recursively)."""
    return all(isinstance(element, Tensor) or is_stub_tensor(element)
               or (isinstance(element, tuple) and _check_all_tensor(element))
               for element in sequence)


def _handle_func_args(func, *args, **kwargs):
    """Handle the *args and **kwargs inputs of the function.

    When keyword arguments are given, they are bound against the signature of `func` with defaults applied.
    As many of them as possible become positional. If `func` declares no variadic parameter, the number of
    positional arguments is then validated against its positional parameters (positional-only included).

    Args:
        func (Union[types.FunctionType, types.MethodType]): The function whose inputs are handled.
        args (tuple): The positional arguments.
        kwargs (dict): The keyword arguments.

    Returns:
        tuple, the handled positional arguments and the handled keyword arguments.

    Raises:
        RuntimeError: If `func` is not a function or method.
        TypeError: If the number of positional arguments does not match the signature of `func`.
    """
    if not isinstance(func, (types.FunctionType, types.MethodType)):
        raise RuntimeError('fn {} is not function or method'.format(func))
    signature = inspect.signature(func)
    if kwargs:
        bound_arguments = signature.bind(*args, **kwargs)
        bound_arguments.apply_defaults()
        args = bound_arguments.args
        kwargs = bound_arguments.kwargs

    # Positional-only parameters (declared before `/`) accept positional values too, so they are counted.
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    positional_args = 0
    default_args = 0
    has_var = False
    for value in signature.parameters.values():
        if value.kind in variadic_kinds:
            has_var = True
        if value.kind in positional_kinds:
            if value.default is inspect.Parameter.empty:
                positional_args += 1
            else:
                default_args += 1

    if has_var:
        return args, kwargs

    if len(args) < positional_args:
        raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
    if len(args) > positional_args + default_args:
        raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
                        f"default argument, total {positional_args + default_args}, but got {len(args)}.")
    return args, kwargs


sys_path = list(sys.path)
# Get the entry script path.
entry_script_path = None
if sys.argv and sys.argv[0] != '':
    entry_script_path = os.path.realpath(sys.argv[0])
    entry_script_path_dir = os.path.split(entry_script_path)[0]
    # The entry script's directory holds user code, so it must not be treated as an install location.
    if entry_script_path_dir in sys_path:
        sys_path.remove(entry_script_path_dir)


def _in_sys_path(file_path):
    """Return whether `file_path` is located under one of the recorded `sys_path` entries.

    Empty entries are skipped. As a string prefix, an empty entry would match every file path and hide
    every user dependency file.
    """
    return any(path and file_path.startswith(path) for path in sys_path)


def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
    """Get the dependency files of the network

    Parse `file_path` and resolve its top-level ``import`` / ``from ... import`` statements. The source
    files of imported modules are collected recursively, except for modules of the mindspore package and
    modules located under `sys_path` (installed modules).

    Args:
        file_path (str): Path of the Python source file to scan.
        compile_cache_dep_files (list[str]): Collected dependency file paths, appended in place.
        pkg (Union[str, None]): Package used to resolve relative imports found in `file_path`.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
    from importlib import util as importlib_util
    # Read bytes so that `ast.parse` honours the file's encoding declaration (UTF-8 by default),
    # rather than decoding with the locale-dependent default encoding.
    with open(file_path, "rb") as fh:
        root = ast.parse(fh.read(), file_path)
    for node in ast.iter_child_nodes(root):
        module_name = ""
        if isinstance(node, ast.ImportFrom):
            if node.module is not None:
                module_name = node.module
            module_name = "." * node.level + module_name
        elif not isinstance(node, ast.Import):
            continue
        # Do not care the files in mindspore package
        if module_name.startswith("mindspore"):
            continue

        for n in node.names:
            if n.name.startswith("mindspore"):
                continue
            if module_name == "":
                whole_module = n.name
            elif module_name.endswith("."):
                # `from . import x` / `from .. import x`: the leading dots already act as the separator;
                # adding another "." would point one package level too high.
                whole_module = module_name + n.name
            else:
                whole_module = module_name + "." + n.name
            try:
                module_spec = importlib_util.find_spec(whole_module, pkg)
            except (ModuleNotFoundError, ValueError):
                # The imported name may be an attribute of a module rather than a submodule: try the parent.
                whole_module = whole_module[0:whole_module.rfind('.')]
                module_spec = importlib_util.find_spec(whole_module, pkg)
            if module_spec is None:
                continue
            module = importlib_util.module_from_spec(module_spec)
            dep_file_path = getattr(module, '__file__', None)
            if not dep_file_path:
                continue
            # Exclude the installed modules.
            if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
                logger.debug(f"dependent file path: {dep_file_path}")
                compile_cache_dep_files.append(dep_file_path)
                # Only Python source files can be parsed for further imports (not e.g. compiled extensions).
                if dep_file_path.endswith(".py"):
                    __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)


def _get_compile_cache_dep_files():
    """Get the dependency files of the network

    Returns:
        list[str], the entry script followed by the user source files it (transitively) imports. Empty if
        the entry script path is unknown.
    """
    if entry_script_path is None:
        logger.warning("Can not get the entry script file path.")
        return []
    compile_cache_dep_files = []
    logger.debug(f"entry script file path: {entry_script_path}")
    compile_cache_dep_files.append(entry_script_path)
    __get_compile_cache_dep_files(entry_script_path, compile_cache_dep_files, None)
    return compile_cache_dep_files


def _restore_mutable_attr(args_list, compile_args):
    """Restore the mutable attr for every arg.

    Args:
        args_list (tuple): The actual input args; their `__ms_mutable__` / `__ms_dynamic_len__` markers are read.
        compile_args (tuple): The compile args, index-aligned with `args_list`.

    Returns:
        tuple, `compile_args` where every entry whose actual arg is mutable (and not a const arg) is re-wrapped
        with `mutable`.
    """
    new_compile_args = ()
    for idx, arg in enumerate(args_list):
        if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \
                not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")):
            if hasattr(arg, "__ms_dynamic_len__"):
                new_compile_args += (mutable(compile_args[idx], getattr(arg, "__ms_dynamic_len__")),)
            else:
                new_compile_args += (mutable(compile_args[idx], False),)
        else:
            new_compile_args += (compile_args[idx],)
    return new_compile_args


def _get_parameter_layout():
    """Merge the parameter layouts of every phase recorded in `ms_compile_cache` into one dict."""
    graph_executor = GraphExecutor_.get_instance()
    layout = dict()
    for phase in ms_compile_cache:
        layout.update(graph_executor.get_parameter_layout(phase))
    return layout


def _handle_arg(obj, arg, compile_arg):
    """Handle arg for runtime.

    Return `arg` if it must be passed to the compiled graph at runtime, otherwise return None. A
    `PythonTensor` with pending initialization has its data initialized as a side effect.
    """
    if isinstance(arg, PythonTensor):
        if arg.has_init:
            arg.init_data()
        if not arg.const_arg:
            return arg
    elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
        return arg
    elif compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and getattr(compile_arg, "__ms_mutable__"):
        # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
        if isinstance(arg, list) and not arg:
            return None
        return arg
    elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
        return arg
    elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
            _check_all_tensor(arg):
        return arg
    return None


def _handle_arg_predict(obj, arg, compile_arg):
    """Handle arg for runtime in predict mode.

    Return `arg` if it must be passed to the compiled graph at runtime, otherwise return None. None and
    scalar (int/float) args are dropped. Lists/tuples are kept only when mutable (and non-empty for lists)
    or when they are broadened tuples of tensors. Everything else is kept.
    """
    if arg is None:
        return None

    if isinstance(arg, (int, float)):
        return None

    if isinstance(arg, (list, tuple)):
        if compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
                getattr(compile_arg, "__ms_mutable__"):
            # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
            if isinstance(arg, list) and not arg:
                return None
            return arg
        if hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
                _check_all_tensor(arg):
            return arg
        return None
    return arg


def _get_args_for_run(obj, args, kwargs, compile_args):
    """Get the actual input args and kwargs for runtime.

    Positional args are filtered pairwise with `compile_args` through `_handle_arg`; kwargs values follow in
    insertion order.
    """
    new_args = []
    for arg, compile_arg in zip(args, compile_args):
        new_arg = _handle_arg(obj, arg, compile_arg)
        if new_arg is not None:
            new_args.append(new_arg)

    for value in kwargs.values():
        new_value = _handle_arg(obj, value, None)
        if new_value is not None:
            new_args.append(new_value)

    return new_args


def _get_args_for_run_predict(obj, args, kwargs, compile_args):
    """Get the actual input args and kwargs for runtime (predict mode, via `_handle_arg_predict`)."""
    new_args = []
    for arg, compile_arg in zip(args, compile_args):
        new_arg = _handle_arg_predict(obj, arg, compile_arg)
        if new_arg is not None:
            new_args.append(new_arg)

    for value in kwargs.values():
        new_value = _handle_arg_predict(obj, value, None)
        if new_value is not None:
            new_args.append(new_value)

    return new_args


def _is_args_fullmode(args, is_init=True):
    """Check whether the arguments is for incremental-mode.

    Args:
        args (Union[list, tuple, dict, Tensor]): Given arguments.
        is_init (bool): Is check in argument initialization phase.

    Raises:
        RuntimeError: loss necessary keys and values for incremental-mode.

    Returns:
        bool: Fullmode or not.
    """
    if not isinstance(args, dict):
        return True
    if not is_init and (args.get(ARG_SPECIFIED, None) is None or args.get(TOTAL_ARG_LEN, None) is None):
        raise RuntimeError(
            "The incremental inputs should be processed(with \"%s\" and \"%s\"), but got %s." %
            (ARG_SPECIFIED, TOTAL_ARG_LEN, str(args)))
    return False


def _process_dyn_args(fn, dyn_args):
    """Process the dynamic arguments, return the necessary data for latter processing.

    Args:
        fn (Function): The root function to compile.
        dyn_args (Union[dict, list, tuple, None]): Given arguments for dynamic compilation.
            None for nothing, list or tuple for fullmode setting, dict for incremental configuration.

    Returns:
        A dict which contains args for dynamic compilation. None for nothing dynamic.

    Raises:
        TypeError: If, in dict mode, `fn` has variadic/keyword-only parameters or a key is not a string.
        ValueError: If, in dict mode, a key is not a positional parameter name of `fn`.
    """
    if dyn_args is None:
        # nothing should be done for None.
        return dyn_args

    if isinstance(dyn_args, dict) and ARG_SPECIFIED in dyn_args:
        return dyn_args

    args_sig = inspect.signature(fn)
    if _is_args_fullmode(dyn_args):
        if not isinstance(dyn_args, (list, tuple)):
            temp_dyn_args = (dyn_args,)
        else:
            temp_dyn_args = dyn_args

        # If dyn_args is fullmode, it should be apply directly.
        args_sig_parameters = list(args_sig.parameters.values())
        if not args_sig_parameters:
            return ()

        # fn may be Cell's construct while the first input is 'self'.
        if args_sig_parameters[0].name == "self" and (len(temp_dyn_args) + 1) == len(args_sig_parameters):
            bound_args = args_sig.bind(None, *temp_dyn_args)
            bound_args.apply_defaults()
            return bound_args.args[1:]

        bound_args = args_sig.bind(*temp_dyn_args)
        bound_args.apply_defaults()
        return bound_args.args

    # The dyn_args is not fullmode, a real compilation arguments should be assembled by latter procession...
    arg_names = []
    args_sig_parameters = list(args_sig.parameters.values())
    for arg_p in args_sig_parameters:
        if arg_p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
            arg_names.append(arg_p.name)
        else:
            raise TypeError("Dynamic arguments is not accepted for VAR_POSITIONAL or VAR_KEYWORD parameters!")

    # For a Cell's `construct`, the leading `self` is not a compile input, so indexes shift by one.
    # The emptiness guard avoids an IndexError for functions without parameters.
    is_cell_construct = (fn.__name__ == 'construct' and bool(args_sig_parameters)
                         and args_sig_parameters[0].name == "self")
    offset = -1 if is_cell_construct else 0
    meet_index = set()

    def _check_index_valid(index):
        if index >= len(arg_names):
            raise ValueError("For dict mode, valid index is \"0\"-\"%d\", but got %s!" % (len(arg_names) - 1, index))
        if index in meet_index:
            raise ValueError("For dict mode, there are more than one same specified key for real index: %d!" % index)
        meet_index.add(index)

    arg_handler_infos = []
    for k, v in dyn_args.items():
        if not isinstance(k, str):
            # Wrap `k` in a tuple: a tuple key would otherwise be spread over the format placeholders.
            raise TypeError("For dict mode, only string key is accepted, but got %s!" % (k,))
        if k in arg_names:
            cur_id = arg_names.index(k)
            _check_index_valid(cur_id)
            arg_handler_infos.append([cur_id + offset, v])
        else:
            raise ValueError("For dict mode, valid key is %s, but got %s!" % (arg_names, k))
    return {ARG_SPECIFIED: arg_handler_infos, TOTAL_ARG_LEN: len(args_sig_parameters)}


def _generate_dyn_compile_args(compile_args, dyn_args):
    """Generate the dynamic compile arguments.

    Args:
        compile_args (Union[tuple, list]): The default compile arguments.
        dyn_args (Union[dict, list, tuple, Tensor, None]): Fullmode args that replace `compile_args`
            entirely, or an incremental dict (as produced by `_process_dyn_args`) that overrides some positions.

    Returns:
        tuple, the compile arguments to use (`compile_args` itself when `dyn_args` is empty).
    """
    if not dyn_args:
        return compile_args
    if _is_args_fullmode(dyn_args, False):
        if not isinstance(dyn_args, (list, tuple)):
            return (dyn_args,)
        return dyn_args
    arg_specified_infos = dyn_args.get(ARG_SPECIFIED, None)
    if arg_specified_infos is None:
        raise RuntimeError("For dict mode, a key with \"%s\" should exist, but got %s!" %
                           (ARG_SPECIFIED, str(dyn_args)))
    new_compile_args = list(compile_args)
    for index, arg in arg_specified_infos:
        new_compile_args[index] = arg
    return tuple(new_compile_args)


class _MindsporeFunctionExecutor:
    """
    Represents a function compiled by graph compiler.

    _MindsporeFunctionExecutor will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        input_signature (Function): User defines signature to verify input.
        ms_create_time(TimeStamp): Time the function was created
        obj (Object): If function is a method, obj is the owner of function,
             else, obj is none.

    Returns:
        The result of pipeline running in graph mode.
    """

    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
        init_pipeline()
        if not isinstance(fn, (types.FunctionType, types.MethodType)):
            raise RuntimeError('fn {} is not function or method'.format(fn))

        self.fn = fn
        self.input_signature = input_signature
        self.obj = None
        # Compare with None: an owner defining `__len__` (e.g. an empty container cell) is falsy but valid.
        if obj is not None and hasattr(obj, fn.__name__):
            self.obj = obj
        self.shard_parent_obj = obj
        self.enable_tuple_broaden = False
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = ms_create_time
        self._compile_args = None
        self.jit_config_dict = jit_config.jit_config_dict if jit_config else None

    @_wrap_func
    def __call__(self, *args, **kwargs):
        """Compile the function for the given inputs if needed, then run the compiled graph."""
        args_list = args
        if self.obj is not None:
            # The first positional argument is the owner object itself.
            args_list = args_list[1:]
        phase = ""
        try:
            if context.get_context("mode") == context.PYNATIVE_MODE:
                _pynative_executor.set_jit_compile_status(True, phase)
                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
                _pynative_executor.set_jit_compile_status(False, phase)
            else:
                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
        except Exception:
            _pynative_executor.clear_res()
            raise

        if context.get_context("precompile_only"):
            return None

        new_inputs = self._generate_run_args(args_list, kwargs)
        output = self._graph_executor(tuple(new_inputs), phase)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            output = _pynative_executor.grad_jit(output, *new_inputs)

        return output

    def compile(self, method_name, *args, **kwargs):
        """Returns pipeline for the given args.

        Args:
            method_name (str): Name of the method to parse when the function has an owner object.
            args (tuple): Actual positional inputs (without the owner object).
            kwargs (dict): Actual keyword inputs.

        Returns:
            str, the phase of the compiled graph.

        Raises:
            RuntimeError: If the graph executor fails to compile.
        """
        # Check whether hook function registered on Cell object.
        if self.obj is not None and hasattr(self.obj, "_hook_fn_registered"):
            if self.obj._hook_fn_registered():
                logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
                               f"If you want to use hook function, please use context.set_context to set "
                               f"pynative mode and remove 'jit' decorator.")
        # Chose dynamic shape tensors or actual input tensors as compile args.
        compile_args = self._generate_compile_args(args)
        key_id = self._get_key_id()
        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
                                                                              self.input_signature)

        # Restore the mutable attr for every arg.
        compile_args = _restore_mutable_attr(args, compile_args)
        self._compile_args = compile_args
        generate_name, echo_function_name = self._get_generate_name()
        # The full Function name
        full_function_name = generate_name
        create_time = ''

        # Add key with obj
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.info(
                    f'The module of `self.obj`: `{self.obj.__module__}` is not same with the module of `self.fn`: '
                    f'`{self.fn.__module__}`')
            self.obj.__parse_method__ = method_name
            if isinstance(self.obj, ms.nn.Cell):
                generate_name = generate_name + '.' + str(self.obj.create_time)
                create_time = str(self.obj.create_time)
            else:
                generate_name = generate_name + '.' + str(self._create_time)
                create_time = str(self._create_time)

            generate_name = generate_name + '.' + str(id(self.obj))
            full_function_name = generate_name
        else:
            # Different instance of same class may use same memory(means same obj_id) at diff times.
            # To avoid unexpected phase matched, add create_time to generate_name.
            generate_name = generate_name + '.' + str(self._create_time)
            create_time = str(self._create_time)

        self.enable_tuple_broaden = False
        if hasattr(self.obj, "enable_tuple_broaden"):
            self.enable_tuple_broaden = self.obj.enable_tuple_broaden

        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
        phase = generate_name + '.' + str(key)

        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)

        if phase in ms_compile_cache:
            # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
            # generated in generate_arguments_key.
            self._graph_executor.clear_compile_arguments_resource()
            return phase

        _check_recompile(self.obj, compile_args, kwargs, full_function_name, create_time, echo_function_name)

        # If enable compile cache, get the dependency files list and set to graph executor.
        self._set_compile_cache_dep_files()
        if self.jit_config_dict:
            self._graph_executor.set_jit_config(self.jit_config_dict)
        else:
            jit_config_dict = JitConfig().jit_config_dict
            self._graph_executor.set_jit_config(jit_config_dict)

        if self.obj is None:
            # Set an attribute to fn as an identifier.
            marked_fn = self.fn.__func__ if isinstance(self.fn, types.MethodType) else self.fn
            setattr(marked_fn, "__jit_function__", True)
            try:
                is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
            finally:
                # Remove the identifier even when compiling raises, so the function is not left marked.
                delattr(marked_fn, "__jit_function__")
        else:
            if isinstance(self.obj, ms.nn.Cell):
                self._graph_executor.set_weights_values(self.obj.parameters_dict())
            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)

        if not is_compile:
            raise RuntimeError("Executor compile failed.")
        ms_compile_cache.add(phase)

        return phase

    @staticmethod
    def _optimizer_state_init(opt_states):
        """set data for all optimizer states in case it is executed in graph mode"""
        prefix_list = ["moments", "accum", "moment1", "moment2", "lamb_m", "lamb_v", "mean_grad",
                       "mean_square", "prev"]
        for opt_param in opt_states:
            prefix = opt_param.name[:opt_param.name.find(".")]
            if opt_param.has_init and (prefix in prefix_list or opt_param.name == "global_step"):
                opt_param.init_data()

    def _get_key_id(self):
        """get key id: the owner's id plus a creation time, suffixed with ".grad" when grad flag is set."""
        if isinstance(self.obj, ms.nn.Cell):
            key_id = str(id(self.obj)) + str(self.obj.create_time)
        else:
            key_id = str(id(self.obj)) + str(self._create_time)

        if _pynative_executor.grad_flag():
            key_id = key_id + ".grad"
        return key_id

    def _get_generate_name(self):
        """get generate name.

        Returns:
            tuple, the phase name prefix (module, name, file and first line of `fn`) and a human-readable
            description of the function for log messages.
        """
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + str(
            self.fn.__code__.co_firstlineno)
        echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
                             + "\", line " + str(self.fn.__code__.co_firstlineno)
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        if _is_pynative_parallel():
            fn_id_pos = generate_name.rfind(str(id(self.fn)))
            if fn_id_pos != -1:
                generate_name = generate_name[:fn_id_pos] + str(id(self.shard_parent_obj))
            else:
                # The name does not embed id(fn); append the shard parent id instead of slicing with -1,
                # which would silently drop the last character of the name.
                generate_name = generate_name + "." + str(id(self.shard_parent_obj))
        return generate_name, echo_function_name

    def _set_compile_cache_dep_files(self):
        # If enable compile cache, get the dependency files list
        enable_compile_cache = context.get_context("enable_compile_cache")
        if enable_compile_cache is None:
            enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
        if enable_compile_cache is True or enable_compile_cache == "1":
            self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

    def _generate_compile_args(self, args_list):
        """Chose dynamic shape tensors or actual input tensors as compile args."""
        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
        compile_args = _pynative_executor.get_dynamic_input(args_list)
        # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
        if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
            compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
            if len(compile_args) != len(args_list):
                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
                                 f"dynamic shape tensors: {len(compile_args)}.")
            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
            Validator.check_symbolic_shape(compile_args, args_list)

        # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
        if self.input_signature is not None:
            compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
            dyn_shape = any(is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor))
            Validator.check_symbolic_shape(self.input_signature, args_list)
            if dyn_shape:
                # Checkout whether the `sens` has been added to args_list.
                if len(compile_args) == len(args_list) - 1:
                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
                                   f"of input_signature args '{len(compile_args)}'. The last actual args may "
                                   f"be 'sens' and added it to compile args.")
                    compile_args.append(args_list[-1])
                compile_args = tuple(compile_args)
                self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
                if self.obj is not None:
                    _pynative_executor.set_dynamic_input(self.obj, *compile_args)
                else:
                    _pynative_executor.set_dynamic_input(self.fn, *compile_args)
            else:
                if not verify_inputs_signature(compile_args, args_list):
                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
        return compile_args

    def _generate_run_args(self, args_list, kwargs):
        """
        Generate input args, which are required for running.

        Args:
            args_list (Tuple): Actual input args.
            kwargs (Dict): Actual input kwargs.

        Returns:
            new_inputs, new input args, which are required for running.
        """
        return _get_args_for_run(self, args_list, kwargs, self._compile_args)


# The attributes used to identify a given object.
762attr_op = {"__str__": lambda x: x.__str__(), 763 "__hash__": lambda x: str(x.__hash__()), 764 "__module__": lambda x: x.__module__, 765 "__name__": lambda x: x.__name__, 766 "__qualname__": lambda x: x.__qualname__, 767 "__len__": lambda x: str(x.__len__()), 768 "__code__": lambda x: x.__code__.co_filename + str(x.__code__.co_firstlineno) 769 } 770 771 772def _get_obj_id(input_obj): 773 """Get hash id of single object.""" 774 obj_id = ".".join( 775 (map(lambda x: attr_op.get(x)(input_obj) if hasattr(input_obj, x) and getattr(input_obj, x) else "", attr_op))) 776 return obj_id + str(id(input_obj)) 777 778 779def _get_jit_hash(hash_input): 780 """Get hash value of single object or list of objects.""" 781 if isinstance(list, tuple): 782 return ".".join(map(_get_obj_id, hash_input)) 783 return _get_obj_id(hash_input) 784 785 786def _update_graph_executor_config(jit_config): 787 """Update GraphExecutor jit_config""" 788 if isinstance(jit_config, JitConfig): 789 jit_config = jit_config.jit_config_dict 790 if not isinstance(jit_config, dict): 791 return 792 valid_config = dict() 793 for k, v in jit_config.items(): 794 valid_config[str(k)] = str(v) 795 GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict) 796 797 798def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False): 799 """ 800 Create a callable MindSpore graph from a Python function. 801 802 This allows the MindSpore runtime to apply optimizations based on graph. 803 804 Args: 805 fn (Function): The Python function that will be run as a graph. Default: ``None`` . 806 mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` . 807 808 - `PSJit <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html>`_ : 809 Parse python ast to build graph. 
810 - `PIJit <https://www.mindspore.cn/docs/en/master/design/dynamic_graph_and_static_graph.html>`_ : 811 Parse python bytecode to build graph at runtime. 812 813 input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The 814 shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the 815 input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the 816 same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`: 817 818 - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs 819 for graph-compiling. 820 - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be 821 substituted into the input at the corresponding position for graph-compiling. 822 823 Default: ``None`` . 824 825 hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`, 826 like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args` 827 will trigger recompilation. Default: ``None`` . 828 jit_config (JitConfig): Jit config for compile. Default: ``None`` . 829 compile_once(bool): ``True``: The function would be compiled once when it was created many times. 830 But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when 831 it was created again. 832 Default: ``False`` . 833 834 Note: 835 If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn` 836 will not accept `**kwargs`. 837 838 Returns: 839 Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is 840 None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is 841 equal to the case when `fn` is not None. 
842 843 Supported Platforms: 844 ``Ascend`` ``GPU`` ``CPU`` 845 846 Examples: 847 >>> import numpy as np 848 >>> from mindspore import Tensor 849 >>> from mindspore import ops 850 >>> from mindspore import jit 851 ... 852 >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 853 >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 854 ... 855 >>> # create a callable MindSpore graph by calling decorator @jit 856 >>> def tensor_add(x, y): 857 ... z = x + y 858 ... return z 859 ... 860 >>> tensor_add_graph = jit(fn=tensor_add) 861 >>> out = tensor_add_graph(x, y) 862 ... 863 >>> # create a callable MindSpore graph through decorator @jit 864 >>> @jit 865 ... def tensor_add_with_dec(x, y): 866 ... z = x + y 867 ... return z 868 ... 869 >>> out = tensor_add_with_dec(x, y) 870 ... 871 >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter 872 >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)), 873 ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)))) 874 ... def tensor_add_with_sig(x, y): 875 ... z = x + y 876 ... return z 877 ... 878 >>> out = tensor_add_with_sig(x, y) 879 ... 880 >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))}) 881 ... def tensor_add_with_sig_1(x, y): 882 ... z = x + y 883 ... return z 884 ... 885 >>> out1 = tensor_add_with_sig_1(x, y) 886 ... 887 ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused. 888 ... # While fn differs during calling again, recompilation will be triggered. 889 >>> def func(x): 890 ... return ops.exp(x) 891 ... 892 >>> def closure_fn(x, fn): 893 ... @jit(hash_args=fn) 894 ... def inner_fn(a): 895 ... return fn(a) 896 ... return inner_fn(x) 897 ... 898 >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32)) 899 >>> for i in range(10): 900 ... closure_fn(inputs, func) 901 ... 902 ... # Set compile_once = True, otherwise the train_step will be compiled again. 903 >>> def train(x): 904 ... 
@jit(compile_once = True) 905 ... def train_step(x): 906 ... return ops.exp(x) 907 ... for i in range(10): 908 ... train_step(x) 909 ... 910 >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32)) 911 >>> for i in range(10): 912 ... train(inputs) 913 """ 914 915 def wrap_mindspore(func): 916 if not isinstance(compile_once, bool): 917 logger.warning(f"The parameter `compile_once` of jit should be a bool, " 918 f"but got {type(compile_once)}.") 919 if hash_args: 920 hash_obj = _get_jit_hash(hash_args) 921 elif compile_once: 922 hash_obj = 0 923 else: 924 hash_obj = int(time.time() * 1e9) 925 926 dyn_args = _process_dyn_args(func, input_signature) 927 928 @wraps(func) 929 def staging_specialize(*args, **kwargs): 930 if os.getenv("MS_JIT") == '0': 931 return func(*args, **kwargs) 932 933 args, kwargs = _handle_func_args(func, *args, **kwargs) 934 935 process_obj = None 936 if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__): 937 process_obj = args[0] 938 # only the function or cell instance wrapped by shard will fall into this branch 939 if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME: 940 process_obj = hash_args 941 out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs) 942 return out 943 944 return staging_specialize 945 946 def pi_wrap_mindspore(decorated): 947 func = decorated 948 if isinstance(func, ms.nn.Cell): 949 func = func.construct 950 if isinstance(func, type) and issubclass(func, ms.nn.Cell): 951 func = func.construct 952 if isinstance(func, types.MethodType): 953 func = func.__func__ 954 if not isinstance(func, types.FunctionType): 955 logger.warning("only support function and mindspore.nn.Cell instance") 956 return decorated 957 958 # generator, coroutine, awaitable and a function that return them is unsupported 959 UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE | 960 inspect.CO_ASYNC_GENERATOR | 
inspect.CO_ITERABLE_COROUTINE) 961 if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE: 962 return decorated 963 964 _update_graph_executor_config(jit_config) 965 config = dict() 966 if isinstance(jit_config, JitConfig): 967 config.update(jit_config.jit_config_dict) 968 elif jit_config is not None: 969 config.update(jit_config) 970 jit_mode_pi_enable() 971 972 if jit_mode_pi_compile(func, config, input_signature) is False: 973 logger.warning('add fn {} to compile failed '.format(func)) 974 975 return decorated 976 977 wrap_func = wrap_mindspore 978 if mode == "PIJit": 979 wrap_func = pi_wrap_mindspore 980 981 if fn is not None: 982 return wrap_func(fn) 983 return wrap_func 984 985 986def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None): 987 """ 988 Create a callable MindSpore graph from a Python function. 989 990 This allows the MindSpore runtime to apply optimizations based on graph. 991 992 Note: 993 - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead. 994 - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn` 995 will not accept `**kwargs`. 996 997 Args: 998 fn (Function): The Python function that will be run as a graph. Default: ``None`` . 999 input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor 1000 will be supplied to this function. The shape and dtype of actual inputs of `fn` should 1001 keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` . 1002 hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`, 1003 like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args` 1004 will trigger recompilation. Default: ``None`` . 1005 jit_config (JitConfig): Jit config for compile. Default: ``None`` . 
1006 1007 Returns: 1008 Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is 1009 None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is 1010 equal to the case when `fn` is not None. 1011 1012 Supported Platforms: 1013 ``Ascend`` ``GPU`` ``CPU`` 1014 1015 Examples: 1016 >>> import numpy as np 1017 >>> from mindspore import Tensor 1018 >>> from mindspore import ops 1019 >>> from mindspore import ms_function 1020 ... 1021 >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 1022 >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 1023 ... 1024 >>> # create a callable MindSpore graph by calling ms_function 1025 >>> def tensor_add(x, y): 1026 ... z = x + y 1027 ... return z 1028 ... 1029 >>> tensor_add_graph = ms_function(fn=tensor_add) 1030 >>> out = tensor_add_graph(x, y) 1031 ... 1032 >>> # create a callable MindSpore graph through decorator @ms_function 1033 >>> @ms_function 1034 ... def tensor_add_with_dec(x, y): 1035 ... z = x + y 1036 ... return z 1037 ... 1038 >>> out = tensor_add_with_dec(x, y) 1039 ... 1040 >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter 1041 >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)), 1042 ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)))) 1043 ... def tensor_add_with_sig(x, y): 1044 ... z = x + y 1045 ... return z 1046 ... 1047 >>> out = tensor_add_with_sig(x, y) 1048 ... 1049 ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused. 1050 ... # While fn differs during calling again, recompilation will be triggered. 1051 >>> def func(x): 1052 ... return ops.exp(x) 1053 ... 1054 >>> def closure_fn(x, fn): 1055 ... @ms_function(hash_args=fn) 1056 ... def inner_fn(a): 1057 ... return fn(a) 1058 ... return inner_fn(x) 1059 ... 
1060 >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32)) 1061 >>> for i in range(10): 1062 ... closure_fn(inputs, func) 1063 """ 1064 1065 logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " 1066 "Please use 'mindspore.jit' instead.") 1067 return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config) 1068 1069 1070def _core(fn=None, **flags): 1071 """ 1072 A decorator that adds a flag to the function. 1073 1074 By default, the function is marked as True, enabling to use this decorator to 1075 set flag to a graph. 1076 1077 Args: 1078 fn (Function): Function to add flag. Default: ``None``. 1079 flags (dict): The following flags can be set core, which indicates that this is a core function or 1080 other flag. Default: ``None``. 1081 1082 Returns: 1083 Function, the function with core flag. 1084 1085 Supported Platforms: 1086 ``Ascend`` ``GPU`` ``CPU`` 1087 """ 1088 1089 # need set the attr and access on c++ 1090 def deco(fn): 1091 fn._func_graph_flags = { 1092 'core': True, 1093 **flags, 1094 } 1095 return fn 1096 1097 if fn is not None: 1098 ret = deco(fn) 1099 else: 1100 ret = deco 1101 return ret 1102 1103 1104def _add_flags(fn=None, **flags): 1105 """ 1106 A decorator that adds a flag to the function. 1107 1108 Note: 1109 Only supports bool value. 1110 1111 Args: 1112 fn (Function): Function or cell to add flag. Default: ``None``. 1113 flags (dict): Flags use kwargs. Default: ``None``. 1114 1115 Returns: 1116 Function, the function with added flags. 
1117 1118 Supported Platforms: 1119 ``Ascend`` ``GPU`` ``CPU`` 1120 """ 1121 1122 def deco(fn): 1123 # need set the attr and access on c++ 1124 if not hasattr(fn, "_func_graph_flags"): 1125 fn._func_graph_flags = {} 1126 1127 fn._func_graph_flags.update({**flags}) 1128 return fn 1129 1130 ret = deco 1131 if fn is not None: 1132 ret = deco(fn) 1133 return ret 1134 1135 1136def _no_recursive(callable_obj): 1137 """ 1138 Method or function decorator for ignoring recursive check. 1139 1140 This allows MindSpore to skip the procedure of checking function or method recursive. 1141 1142 Args: 1143 callable_obj (Union(method, function)): The function or method to call. 1144 1145 Returns: 1146 Function or method with no_recursive flag. 1147 1148 Raises: 1149 TypeError: If ms_class is used for non-class types or nn.Cell. 1150 AttributeError: If the private attributes or magic methods of the class decorated by ms_class is called. 1151 1152 Supported Platforms: 1153 ``Ascend`` ``GPU`` ``CPU`` 1154 """ 1155 is_cell_subclass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell) 1156 if not is_cell_subclass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj): 1157 raise TypeError(f"Decorator no_recursive is used for callable object, but got {callable_obj}.") 1158 _add_flags(callable_obj, no_recursive=True) 1159 return callable_obj 1160 1161 1162def ms_class(cls): 1163 """ 1164 Class decorator for user-defined classes. 1165 1166 This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods. 1167 1168 Note: 1169 `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead. 1170 1171 Args: 1172 cls (Class): User-defined class. 1173 1174 Returns: 1175 Class. 1176 1177 Raises: 1178 TypeError: If ms_class is used for non-class types or nn.Cell. 1179 AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called. 
1180 1181 Supported Platforms: 1182 ``Ascend`` ``GPU`` ``CPU`` 1183 1184 Examples: 1185 >>> import mindspore.nn as nn 1186 >>> from mindspore import ms_class 1187 ... 1188 >>> @ms_class 1189 ... class UserDefinedNet: 1190 ... def __init__(self): 1191 ... self.value = 10 1192 ... 1193 ... def func(self, x): 1194 ... return 2 * x 1195 ... 1196 >>> class Net(nn.Cell): 1197 ... def __init__(self): 1198 ... super(Net, self).__init__() 1199 ... self.net = UserDefinedNet() 1200 ... 1201 ... def construct(self, x): 1202 ... out = self.net.value + self.net.func(x) 1203 ... return out 1204 ... 1205 >>> net = Net() 1206 >>> out = net(5) 1207 >>> print(out) 1208 20 1209 """ 1210 1211 logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " 1212 "Please use 'mindspore.jit_class' instead.") 1213 1214 # Check if cls is of type class. 1215 if not inspect.isclass(cls): 1216 raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.') 1217 # Check if cls is nn.Cell. 1218 if issubclass(cls, ms.nn.Cell): 1219 raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") 1220 logger.info(f'Found ms_class: {cls}.') 1221 setattr(cls, '__ms_class__', True) 1222 return cls 1223 1224 1225def jit_class(cls): 1226 """ 1227 Class decorator for user-defined classes. 1228 1229 This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods. 1230 1231 Args: 1232 cls (Class): User-defined class. 1233 1234 Returns: 1235 Class. 1236 1237 Raises: 1238 TypeError: If `jit_class` is used for non-class types or nn.Cell. 1239 AttributeError: If the private attributes or magic methods of the class decorated with `jit_class` is called. 1240 1241 Supported Platforms: 1242 ``Ascend`` ``GPU`` ``CPU`` 1243 1244 Examples: 1245 >>> import mindspore.nn as nn 1246 >>> from mindspore import jit_class 1247 ... 1248 >>> @jit_class 1249 ... class UserDefinedNet: 1250 ... 
def __init__(self): 1251 ... self.value = 10 1252 ... 1253 ... def func(self, x): 1254 ... return 2 * x 1255 ... 1256 >>> class Net(nn.Cell): 1257 ... def __init__(self): 1258 ... super(Net, self).__init__() 1259 ... self.net = UserDefinedNet() 1260 ... 1261 ... def construct(self, x): 1262 ... out = self.net.value + self.net.func(x) 1263 ... return out 1264 ... 1265 >>> net = Net() 1266 >>> out = net(5) 1267 >>> print(out) 1268 20 1269 """ 1270 from mindspore import nn 1271 # Check if cls is of type class. 1272 if not inspect.isclass(cls): 1273 raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.') 1274 # Check if cls is nn.Cell. 1275 if issubclass(cls, nn.Cell): 1276 raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") 1277 setattr(cls, '__ms_class__', True) 1278 return cls 1279 1280 1281def set_adapter_config(config): 1282 """ 1283 Register configuration information for MSAdapter. 1284 1285 Args: 1286 config (dict): Configuration information. 
1287 """ 1288 if not isinstance(config, dict): 1289 raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.") 1290 for key, value in config.items(): 1291 if key == "Tensor": 1292 ms_adapter_registry.register_tensor(value) 1293 elif key == "Parameter": 1294 ms_adapter_registry.register_parameter(value) 1295 elif key == "convert_object_map": 1296 ms_adapter_registry.register_convert_map(value) 1297 elif key == "convert_adapter_tensor_map": 1298 ms_adapter_registry.register_convert_adapter_tensor_map(value) 1299 else: 1300 raise ValueError(f"Unsupported key in adapter config: {key}") 1301 1302 1303def _function_forbid_reuse(func): 1304 if not inspect.isfunction(func): 1305 raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.') 1306 setattr(func, '__function_forbid_reuse__', True) 1307 return func 1308 1309 1310def _build_broadcast_graph(broadcast_params_dict, broadcast_phase): 1311 """Build broadcast graph.""" 1312 from mindspore.nn.wrap.cell_wrapper import _BroadCastCell 1313 if not broadcast_params_dict: 1314 broadcast_params_dict = {} 1315 broadcast_params = [] 1316 for param in broadcast_params_dict.values(): 1317 broadcast_params.append(Tensor(param.asnumpy())) 1318 _broadcast_net = _BroadCastCell(broadcast_params) 1319 _broadcast_net.phase = broadcast_phase 1320 broadcasted_params = _broadcast_net() 1321 for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): 1322 broadcast_params_dict.get(param_name).set_data(param) 1323 1324 1325def _get_auto_split_param_names(parameter_layout_dict): 1326 auto_split_param_names = [] 1327 for key, value in parameter_layout_dict.items(): 1328 for dim in value[1]: 1329 if dim != -1: 1330 auto_split_param_names.append(key) 1331 break 1332 return auto_split_param_names 1333 1334 1335def _parameter_broadcast(obj): 1336 """ 1337 Parameter broadcast. 
1338 When the parallel mode is 'semi_auto_parallel' or 'auto_parallel', it will broadcast the parameters that have not 1339 split. 1340 """ 1341 auto_split_param_names = [] 1342 if _is_in_auto_parallel_mode(): 1343 auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict) 1344 1345 broadcast_params_dict = obj.parameters_broadcast_dict() 1346 if auto_split_param_names and broadcast_params_dict: 1347 broadcast_params_dict = OrderedDict() 1348 for param_name, param in obj.parameters_broadcast_dict().items(): 1349 if param_name not in auto_split_param_names: 1350 broadcast_params_dict[param_name] = param 1351 broadcast_phase = "_broadcast_subgraph" 1352 _build_broadcast_graph(broadcast_params_dict, broadcast_phase) 1353 1354 1355class _no_grad(contextlib.ContextDecorator): 1356 """ 1357 Context Manager to disable gradient calculation. When enter this context, we will disable calculate 1358 gradient. When exit this context, we will resume its prev state. 1359 Currently, it can only use in Pynative mode. It also can be used as decorator. 1360 """ 1361 1362 def __init__(self): 1363 self.prev_state = False 1364 1365 def __enter__(self): 1366 if context.get_context("mode") == context.GRAPH_MODE: 1367 raise RuntimeError("For no_grad feature, currently only support Pynative mode, but got Graph mode.") 1368 self.prev_state = _pynative_executor.enable_grad() 1369 _pynative_executor.set_enable_grad(False) 1370 1371 def __exit__(self, exc_type, exc_val, exc_tb): 1372 _pynative_executor.set_enable_grad(self.prev_state) 1373 return False 1374 1375 1376class _PyNativeExecutor: 1377 """ 1378 A pynative executor used to compile/manage/run single op. 1379 1380 The main functions include: construct op graph, compile op graph, auto grad and run op graph. 1381 1382 Args: 1383 obj (Object): The python network that will be run in pynative mode. 1384 args (Tuple(Tensor...)): The inputs of network in tuple form. 
1385 1386 Returns: 1387 gradients (Tuple(Tensor...)): The gradients of network parameters and inputs. 1388 1389 Supported Platforms: 1390 ``Ascend`` ``GPU`` ``CPU`` 1391 """ 1392 1393 def __init__(self): 1394 self._executor = PyNativeExecutor_.get_instance() 1395 self._executor.set_py_exe_path(sys.executable) 1396 self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep) 1397 1398 @staticmethod 1399 def parameter_broadcast(obj, phase): 1400 """ 1401 Run broadcast for parameter. 1402 1403 Args: 1404 obj (Cell): The cell instance. 1405 phase (str): The phase of cell instance. 1406 1407 Return: 1408 None. 1409 """ 1410 if BROADCAST_PHASE not in phase and _get_parameter_broadcast(): 1411 _parameter_broadcast(obj) 1412 1413 def real_run_op(self, *args): 1414 """ 1415 Run single op. 1416 1417 Args: 1418 args (tuple): Op prim and input arguments. 1419 1420 Return: 1421 Tensor, result of run op. 1422 """ 1423 return self._executor.real_run_op(*args) 1424 1425 def run_op_async(self, *args): 1426 """ 1427 Run single op async. 1428 1429 Args: 1430 args (tuple): Op prim and input arguments. 1431 1432 Return: 1433 StubNode, result of run op. 1434 """ 1435 return self._executor.run_op_async(*args) 1436 1437 def new_graph(self, obj, *args, **kwargs): 1438 """ 1439 Initialize resources for building forward and backward graph. 1440 1441 Args: 1442 obj (Function/Cell): The function or cell instance. 1443 args (tuple): Function or cell input arguments. 1444 kwargs (dict): keyword arguments. 1445 1446 Return: 1447 None. 1448 """ 1449 self._executor.new_graph(obj, *args, *(kwargs.values())) 1450 1451 def end_graph(self, obj, output, *args, **kwargs): 1452 """ 1453 Clean resources after building forward and backward graph. 1454 1455 Args: 1456 obj (Function/Cell): The function or cell instance. 1457 output (Tensor/tuple/list): Function or cell output object. 1458 args (tuple): Function or cell input arguments. 
1459 kwargs (dict): keyword arguments. 1460 1461 Return: 1462 None. 1463 """ 1464 self._executor.end_graph(obj, output, *args, *(kwargs.values())) 1465 1466 def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs): 1467 """ 1468 Whether the forward graph need to construct. 1469 1470 Args: 1471 grad (GradOperation): The gradoperation object. 1472 obj (Function/Cell): The function or cell instance. 1473 grad_hash_id (tuple): The id of objects which contribute to cache of compiled graph in pynative mode. 1474 args (tuple): Function or cell input arguments. 1475 kwargs (dict): keyword arguments. 1476 1477 Return: 1478 bool, specifies whether the forward graph need to construct. 1479 """ 1480 return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values())) 1481 1482 def grad(self, obj, grad, weights, grad_position, *args, **kwargs): 1483 """ 1484 Get grad graph. 1485 1486 Args: 1487 obj (Function/Cell): The function or cell instance. 1488 grad (GradOperation): The gradoperation object. 1489 weights (ParameterTuple): The weights of cell instance. 1490 grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input. 1491 If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0. 1492 args (tuple): Function or cell input arguments. 1493 kwargs (dict): keyword arguments. 1494 1495 Return: 1496 None. 1497 """ 1498 return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) 1499 1500 def clear_res(self): 1501 """ 1502 Clean resource for _PyNativeExecutor. 1503 1504 Return: 1505 None. 1506 """ 1507 return self._executor.clear_res() 1508 1509 def sync(self): 1510 """ 1511 SyncStream. 1512 1513 Return: 1514 None. 1515 """ 1516 self._executor.sync() 1517 1518 def grad_jit(self, output, *args): 1519 """ 1520 Building grad graph decorated by jit. 1521 1522 Args: 1523 output (tuple): The function or cell decorated by jit output object. 
1524 args (tuple): Function or cell decorated by jit input arguments. 1525 1526 Return: 1527 None. 1528 """ 1529 return self._executor.grad_jit(output, *args) 1530 1531 def grad_flag(self): 1532 """ 1533 The flag of building grad graph. 1534 1535 Return: 1536 bool, whether building grad graph. 1537 """ 1538 return self._executor.grad_flag() 1539 1540 def set_grad_flag(self, flag): 1541 """ 1542 Set the flag of building grad graph. 1543 1544 Args: 1545 flag (bool): Specifying whether building grad graph. 1546 1547 Return: 1548 None. 1549 """ 1550 self._executor.set_grad_flag(flag) 1551 1552 def set_async_for_graph(self, flag): 1553 """ 1554 Set the flag for graph async run. 1555 1556 Args: 1557 flag (bool): Specifying whether enable graph async run. 1558 1559 Return: 1560 None. 1561 """ 1562 self._executor.set_async_for_graph(flag) 1563 1564 def enable_grad(self): 1565 """ 1566 The global flag whether needing to calculate gradient. 1567 1568 Return: 1569 bool, whether needing to calculate gradient. 1570 """ 1571 return self._executor.enable_grad() 1572 1573 def set_enable_grad(self, flag): 1574 """ 1575 Set the flag of calculating gradient. 1576 1577 Args: 1578 flag (bool): Specifying whether calculating gradient. 1579 1580 Return: 1581 None. 1582 """ 1583 self._executor.set_enable_grad(flag) 1584 1585 def set_jit_compile_status(self, status, phase): 1586 """ 1587 Set jit is compiling 1588 1589 Args: 1590 status(bool): jit compile status 1591 phase (str): The phase of cell/function instance. 1592 Return: 1593 None. 1594 """ 1595 self._executor.set_jit_compile_status(status, phase) 1596 1597 def set_is_run_recompute(self, status): 1598 """ 1599 Set recompute grad is compiling 1600 1601 Args: 1602 status(bool): grad is in recompute status 1603 Return: 1604 None. 1605 """ 1606 self._executor.set_is_run_recompute(status) 1607 1608 def set_dynamic_input(self, obj, *args): 1609 """ 1610 Set dynamic shape tensor of input arguments. 
1611 1612 Args: 1613 obj (Function/Cell): The function or cell instance. 1614 args (tuple): Function or cell dynamic input arguments. 1615 1616 Return: 1617 None. 1618 """ 1619 self._executor.set_dynamic_input(obj, *args) 1620 1621 def get_dynamic_input(self, *actual_args): 1622 """ 1623 Get dynamic shape arguments according to actual input arguments. 1624 1625 Args: 1626 actual_args(tuple): Actual input arguments of Function or Cell. 1627 1628 Return: 1629 dynamic_shape_args(tuple): Dynamic shape arguments of Function or Cell. 1630 """ 1631 return self._executor.get_dynamic_input(*actual_args) 1632 1633 def is_first_cell(self): 1634 """ 1635 The flag of first cell instance. 1636 1637 Return: 1638 bool, specifies whether is the first cell. 1639 """ 1640 1641 return self._executor.is_first_cell() 1642 1643 def set_hook_changed(self, cell): 1644 """ 1645 The flag of registering or removing a hook function on Cell instance. 1646 1647 Args: 1648 cell (Cell): The cell instance. 1649 1650 Return: 1651 None. 1652 """ 1653 self._executor.set_hook_changed(cell) 1654 1655 def constant_folding(self, *args): 1656 """ 1657 Get value by infer value. 1658 1659 Args: 1660 args (tuple): Op prim and input arguments. 1661 1662 Return: 1663 Tensor, the value get by op infer. 1664 """ 1665 return self._executor.constant_folding(*args) 1666 1667 1668class _CellGraphExecutor: 1669 """ 1670 An executor used to compile/manage/run graph for a Cell. 1671 1672 Including data_graph, train_graph, eval_graph and predict graph. 1673 1674 Args: 1675 obj (Function/Cell): The function or cell instance need compile. 1676 args (tuple): Function or cell input arguments. 1677 1678 Returns: 1679 Graph, return the result of pipeline running. 
1680 """ 1681 1682 def __init__(self): 1683 # create needed graph by lazy mode 1684 self.is_init = False 1685 self.enable_tuple_broaden = False 1686 self.obfuscate_config = None # used for model's dynamic obfuscation 1687 self._graph_executor = GraphExecutor_.get_instance() 1688 self._graph_executor.set_py_exe_path(sys.executable) 1689 self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep) 1690 1691 def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes, 1692 input_indexs, phase='dataset', need_run=True): 1693 """ 1694 Initialization interface for calling data subgraph. 1695 1696 Args: 1697 queue_name (str): The name of tdt queue on the device. 1698 dataset_size (int): The size of dataset. 1699 batch_size (int): The size of batch. 1700 dataset_types (list): The output types of element in dataset. 1701 dataset_shapes (list): The output shapes of element in dataset. 1702 input_indexs (list): The index of data with net. 1703 phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'. 1704 1705 Returns: 1706 bool, specifies whether the data subgraph was initialized successfully. 1707 """ 1708 if not init_exec_dataset(queue_name=queue_name, 1709 size=dataset_size, 1710 batch_size=batch_size, 1711 types=dataset_types, 1712 shapes=dataset_shapes, 1713 input_indexs=input_indexs, 1714 phase=phase, 1715 need_run=need_run): 1716 raise RuntimeError("Failure to init and dataset subgraph!") 1717 self._graph_executor.set_queue_name(queue_name) 1718 return True 1719 1720 def set_queue_name(self, queue_name): 1721 """ 1722 while a mode use shared dataset with others, need set queue_name which saved in data_set 1723 :param queue_name: 1724 :return: 1725 """ 1726 self._graph_executor.set_queue_name(queue_name) 1727 1728 def get_queue_name(self, dataset_phase): 1729 """ 1730 Get cached queue name for the graph loaded from compile cache. 
1731 :return: cached queue name 1732 """ 1733 return self._graph_executor.get_queue_name(dataset_phase) 1734 1735 @staticmethod 1736 def _set_dataset_mode(obj): 1737 """set dataset mode.""" 1738 # decide whether to sink based on the sink_mode flag which is set in connect_network_with_dataset 1739 if 'sink_mode' in obj.get_flags().keys() and obj.get_flags()['sink_mode'] is True: 1740 _set_dataset_mode_config('sink') 1741 else: 1742 _set_dataset_mode_config('normal') 1743 1744 @staticmethod 1745 def _use_vm_mode(): 1746 enable_ge = context.get_context("enable_ge") 1747 enable_debug_runtime = context.get_context("enable_debug_runtime") 1748 exe_mode = context.get_context("mode") == context.PYNATIVE_MODE 1749 return not enable_ge or (enable_debug_runtime and exe_mode) 1750 1751 def _build_data_graph(self, obj, phase): 1752 self._graph_executor.build_data_graph(obj.parameters_dict(), phase) 1753 1754 def _set_compile_cache_dep_files(self, phase): 1755 # If enable compile cache, get the dependency files list 1756 enable_compile_cache = context.get_context("enable_compile_cache") 1757 if enable_compile_cache is None: 1758 enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE') 1759 if enable_compile_cache is True or enable_compile_cache == "1": 1760 self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files()) 1761 1762 def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs): 1763 """ 1764 Compiles graph. 1765 1766 Args: 1767 obj (Function/Cell): The function or cell instance need compile. 1768 phase (str): The name of compile phase. Default: 'predict'. 1769 do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph. 1770 jit_config_dict (dict): Jit config for compile. Default: ``None``. 1771 args (tuple): Args of the Cell object. 1772 kwargs (dict): Kwargs of the Cell object. 1773 1774 Return: 1775 Str, the full phase of the cell. 
1776 Bool, if the graph has been compiled before, return False, else return True. 1777 """ 1778 obj.__parse_method__ = 'construct' 1779 if not hasattr(obj, obj.__parse_method__): 1780 raise AttributeError( 1781 'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__)) 1782 key_id = str(id(obj)) + str(obj.create_time) 1783 args = get_auto_dynamic_shape_args(args, key_id) 1784 1785 self.enable_tuple_broaden = False 1786 if hasattr(obj, "enable_tuple_broaden"): 1787 self.enable_tuple_broaden = obj.enable_tuple_broaden 1788 logger.debug(f"Convert the network: {do_convert}.") 1789 self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden) 1790 key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden) 1791 obj.arguments_key = str(key) 1792 raw_phase = phase 1793 phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key 1794 obj.phase_cache[raw_phase] = phase 1795 update_auto_dynamic_shape_phase(args, key_id, phase) 1796 obj.current_phase = phase 1797 if phase in obj.compile_cache and self.has_compiled(phase): 1798 logger.debug("%r graph has existed.", phase) 1799 # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_ 1800 # generated in generate_arguments_key. 1801 self._graph_executor.clear_compile_arguments_resource() 1802 return phase, False 1803 1804 full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' 
+ str(id(type(obj))) 1805 echo_function_name = obj.__class__.__name__ 1806 _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name) 1807 1808 obj.check_names() 1809 _check_full_batch() 1810 self._set_dataset_mode(obj) 1811 self._set_compile_cache_dep_files(phase) 1812 1813 self._graph_executor.set_weights_values(obj.parameters_dict()) 1814 if jit_config_dict: 1815 self._graph_executor.set_jit_config(jit_config_dict) 1816 else: 1817 jit_config_dict = JitConfig().jit_config_dict 1818 self._graph_executor.set_jit_config(jit_config_dict) 1819 result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode()) 1820 obj.compile_cache.add(phase) 1821 if not result: 1822 raise RuntimeError("Executor compile failed.") 1823 graph = self._graph_executor.get_func_graph(phase) 1824 1825 if graph is None: 1826 raise RuntimeError("Compile graph failed for phase {}.".format(phase)) 1827 1828 auto_parallel_mode = _is_in_auto_parallel_mode() 1829 if not auto_parallel_mode: 1830 replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) 1831 self._update_param_node_default_input(phase, replace) 1832 elif 'skip_auto_parallel_compile' not in obj.get_flags().keys(): 1833 obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase) 1834 obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase) 1835 if "export.air" in phase: 1836 self._build_data_graph(obj, phase) 1837 elif BROADCAST_PHASE not in phase and _get_parameter_broadcast(): 1838 _parameter_broadcast(obj) 1839 return phase, True 1840 1841 def _update_param_node_default_input(self, phase, replace): 1842 new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} 1843 return self._graph_executor.updata_param_node_default_input(phase, new_param) 1844 1845 def _get_shard_strategy(self, obj): 1846 real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' 
+ obj.arguments_key 1847 return self._graph_executor.get_strategy(real_phase) 1848 1849 def _get_num_parallel_ops(self, obj): 1850 real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key 1851 return self._graph_executor.get_num_parallel_ops(real_phase) 1852 1853 def _get_allreduce_fusion(self, obj): 1854 real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key 1855 return self._graph_executor.get_allreduce_fusion(real_phase) 1856 1857 def __call__(self, obj, *args, phase='predict'): 1858 if context.get_context("precompile_only") or _is_role_sched(): 1859 return None 1860 return self.run(obj, *args, phase=phase) 1861 1862 def has_compiled(self, phase='predict'): 1863 """ 1864 Specify whether have been compiled. 1865 1866 Args: 1867 phase (str): The phase name. Default: 'predict'. 1868 1869 Returns: 1870 bool, specifies whether the specific graph has been compiled. 1871 """ 1872 return self._graph_executor.has_compiled(phase) 1873 1874 def flops_collection(self, phase='train'): 1875 """ 1876 Specify whether have been compiled. 1877 1878 Args: 1879 phase (str): The phase name. Default: 'predict'. 1880 1881 Returns: 1882 bool, specifies whether the specific graph has been compiled. 1883 """ 1884 return self._graph_executor.flops_collection(phase) 1885 1886 @_wrap_func 1887 def _exec_pip(self, obj, *args, phase=''): 1888 """Execute the generated pipeline.""" 1889 fn = obj.construct 1890 obj.__parse_method__ = fn.__name__ 1891 return self._graph_executor(args, phase) 1892 1893 def run(self, obj, *args, phase='predict'): 1894 """ 1895 Run the specific graph. 1896 1897 Args: 1898 obj (Cell): The cell object. 1899 args (tuple): Args of the Cell object. 1900 phase (str): The phase name. Default: 'predict'. 1901 1902 Returns: 1903 Tensor/Tuple, return execute result. 1904 """ 1905 if phase == 'save': 1906 exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' 
+ obj.arguments_key 1907 return self._graph_executor((), exe_phase) 1908 1909 phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key 1910 if self.has_compiled(phase_real): 1911 return self._exec_pip(obj, *args, phase=phase_real) 1912 raise KeyError('{} graph is not exist.'.format(phase_real)) 1913 1914 def del_net_res(self, obj, net_id): 1915 """Clear the memory resource of a network.""" 1916 self._graph_executor.del_net_res(obj, net_id) 1917 1918 def inc_graph_cell_count(self): 1919 """Increase the count of GraphCell instance.""" 1920 self._graph_executor.inc_graph_cell_count() 1921 1922 def dec_graph_cell_count(self): 1923 """Decrease the count of GraphCell instance.""" 1924 self._graph_executor.dec_graph_cell_count() 1925 1926 def _get_branch_control_input(self): 1927 if ('obf_ratio' not in self.obfuscate_config.keys()) or ( 1928 'obf_random_seed' not in self.obfuscate_config.keys()): 1929 raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.") 1930 obf_random_seed = self.obfuscate_config.get('obf_random_seed') 1931 if obf_random_seed == 0: 1932 branch_control_input = 0 1933 else: 1934 branch_control_input = _generate_branch_control_input(obf_random_seed) 1935 return branch_control_input 1936 1937 def _get_func_graph(self, obj, exec_id, use_prefix=False): 1938 """Get func graph from pipeline.""" 1939 if use_prefix: 1940 exec_id = exec_id + '.' + obj.arguments_key 1941 if self._graph_executor.has_compiled(exec_id) is False: 1942 return None 1943 if self.obfuscate_config is not None: 1944 raise ValueError('For get func graph, obfuscate_config is currently not supported now.') 1945 return self._graph_executor.get_func_graph(exec_id) 1946 1947 def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False): 1948 """Get graph proto from pipeline.""" 1949 if use_prefix: 1950 exec_id = exec_id + '.' 
+ obj.arguments_key 1951 if self._graph_executor.has_compiled(exec_id) is False: 1952 return None 1953 if self.obfuscate_config is not None: 1954 branch_control_input = self._get_branch_control_input() 1955 return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental, 1956 self.obfuscate_config['obf_ratio'], 1957 branch_control_input) 1958 return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental) 1959 1960 def get_optimize_graph_proto(self, obj): 1961 """Return optimize graph binary proto.""" 1962 exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key 1963 if self._graph_executor.has_compiled(exec_id) is False: 1964 return None 1965 graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id) 1966 if isinstance(graph_proto, str) and graph_proto == "": 1967 logger.warning("Can not get optimize graph proto. Instead, try to find function graph.") 1968 graph_proto = obj.get_func_graph_proto() 1969 return graph_proto 1970 1971 def export(self, file_name, graph_id, enc_key=None, encrypt_func=None): 1972 """ 1973 Export graph. 1974 1975 Args: 1976 file_name (str): File name of model to export 1977 graph_id (str): id of graph to be exported 1978 """ 1979 self._graph_executor.export_graph(file_name, graph_id, encrypt_func, enc_key) 1980 1981 1982def ms_memory_recycle(): 1983 """ 1984 Recycle memory used by MindSpore. 1985 When train multi Neural network models in one process, memory used by MindSpore is very large, 1986 this is because MindSpore cached runtime memory for every model. 1987 To recycle these cached memory, users can call this function after training of one model. 
1988 1989 Examples: 1990 >>> import mindspore as ms 1991 >>> ms.ms_memory_recycle() 1992 """ 1993 if ms_compile_cache: 1994 _cell_graph_executor.del_net_res(None, ms_compile_cache) 1995 ms_compile_cache.clear() 1996 for cell_cache in cells_compile_cache.values(): 1997 if cell_cache: 1998 _cell_graph_executor.del_net_res(None, cell_cache) 1999 cell_cache.clear() 2000 _ms_memory_recycle() 2001 2002 2003def _generate_branch_control_input(obf_random_seed): 2004 """Generate append network input for dynamic obfuscation in random seed mode.""" 2005 seed_max = 2 ** 32 - 1 2006 int_max = 2 ** 31 - 1 2007 np.random.seed(obf_random_seed % seed_max) 2008 # generate a string as hash function inputs 2009 word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789" 2010 repo_len = len(word_repo) 2011 sha_string = '' 2012 string_len = 1024 * 1024 2013 for _ in range(string_len): 2014 rand_index = np.random.randint(0, repo_len) 2015 sha_string += word_repo[rand_index] 2016 # get hash result 2017 sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest() # len is 64 2018 branch_control_input = 1 2019 hex_base = 16 2020 for item in sha_result: 2021 if int(item, hex_base) > 0: 2022 branch_control_input *= int(item, hex_base) 2023 branch_control_input %= int_max 2024 return branch_control_input 2025 2026 2027def _bind_device_context(): 2028 """Bind device context to current thread""" 2029 _bind_device_ctx() 2030 2031 2032def flops_collection(phase='train'): 2033 """ 2034 Recycle memory used by MindSpore. 2035 When train multi Neural network models in one process, memory used by MindSpore is very large, 2036 this is because MindSpore cached runtime memory for every model. 2037 To recycle these cached memory, users can call this function after training of one model. 
2038 2039 Examples: 2040 >>> import mindspore as ms 2041 >>> ms.ms_memory_recycle() 2042 """ 2043 return _cell_graph_executor.flops_collection(phase) 2044 2045 2046_cell_graph_executor = _CellGraphExecutor() 2047_pynative_executor = _PyNativeExecutor() 2048 2049__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection'] 2050