1# Copyright 2020-2024 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""cell""" 16from __future__ import absolute_import 17 18import gc 19import inspect 20import os 21import time 22from collections import OrderedDict 23import numpy 24 25from mindspore._checkparam import args_type_check, check_hook_fn 26from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic 27from mindspore import log as logger 28from mindspore.common.parameter import PARAMETER_NAME_DEFAULT 29from mindspore.common.hook_handle import HookHandle 30from mindspore.context import ParallelMode 31from mindspore import context 32from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType 33from mindspore import _checkparam as Validator 34from mindspore.common import dtype as mstype 35from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache 36from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict 37from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args 38from mindspore.common.parameter import Parameter, ParameterTuple 39from mindspore.common.tensor import Tensor 40from mindspore.ops.operations import Cast 41from mindspore.ops.primitive import Primitive 42from mindspore.ops.operations import _inner_ops as inner 43from mindspore.parallel.shard import Shard 44from mindspore._check_jit_forbidden_api import jit_forbidden_register 45from mindspore.common._decorator import deprecated 46from mindspore.common._register_for_recompute import recompute_registry 47 48 49class Cell(Cell_): 50 """ 51 The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this 52 base class. 53 54 Layers in `mindspore.nn` are also the subclass of Cell, such as :class:`mindspore.nn.Conv2d`, 55 and :class:`mindspore.nn.ReLU`, etc. Cell will be compiled into a calculation 56 graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in 57 PYNATIVE_MODE (dynamic graph mode). 58 59 .. note:: 60 Cell is the inference mode by default. For a class that inherits a Cell, 61 if the training and inference have different structures, the subclass performs the inference branch by default. 62 To set the training mode, refer to `mindspore.nn.Cell.set_train` . 63 64 .. warning:: 65 In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute 66 named 'phase' or 'cells', otherwise, an error will be raised. 67 68 Args: 69 auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also 70 affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be 71 automatically prefixed, otherwise not. In general, the backbone network should be set to 72 ``True`` , otherwise the duplicate name problem will appear. The cell to train the backbone 73 network, such as optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to 74 ``False`` , otherwise the parameter name in backbone will be changed by mistake. 75 Default: ``True`` . 76 flags (dict, optional): Network configuration information, currently it is used for the binding of network 77 and dataset. Users can also customize network attributes by this parameter. Default: ``None`` . 78 79 Supported Platforms: 80 ``Ascend`` ``GPU`` ``CPU`` 81 82 Examples: 83 >>> import mindspore.nn as nn 84 >>> from mindspore import ops 85 >>> class MyCell(nn.Cell): 86 ... def __init__(self, forward_net): 87 ... super(MyCell, self).__init__(auto_prefix=False) 88 ... self.net = forward_net 89 ... self.relu = ops.ReLU() 90 ... 91 ... def construct(self, x): 92 ... y = self.net(x) 93 ... return self.relu(y) 94 >>> 95 >>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') 96 >>> my_net = MyCell(inner_net) 97 >>> print(my_net.trainable_params()) 98 ... # If the 'auto_prefix' set to True or not set when call the '__init__' method of the parent class, 99 ... # the parameter's name will be 'net.weight'. 100 [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)] 101 """ 102 103 IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time', 104 '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', 105 '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook', 106 '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix', 107 '_attr_synced', 'pynative', 'requires_grad', 'cell_type'] 108 total_instance_count = 0 109 110 def __init__(self, auto_prefix=True, flags=None): 111 Cell_.__init__(self, self._cell_tag) 112 Cell.total_instance_count += 1 113 self.instance_count = Cell.total_instance_count 114 self._params = OrderedDict() 115 self._cells = OrderedDict() 116 self._params_list = OrderedDict() 117 self._primitives = OrderedDict() 118 self.training = False 119 self.requires_grad = False 120 self.pynative = False 121 self._attr_synced = False 122 self._param_prefix = '' 123 self._auto_prefix = auto_prefix 124 self._scope = None 125 self._phase = 'train' 126 self._parameter_layout_dict = {} 127 self._parallel_parameter_name_list = () 128 self._parallel_parameter_merge_net_dict = {} 129 self._create_time = int(time.time() * 1e9) 130 self.arguments_key = "" 131 self.compile_cache = set() 132 self.phase_cache = dict() 133 cells_compile_cache[id(self)] = self.compile_cache 134 self.parameter_broadcast_done = False 135 self._id = 1 136 self.exist_names = set("") 137 self.exist_objs = set() 138 self.recompute_cell = None 139 self.sig = inspect.signature(self.construct) 140 init_pipeline() 141 142 # call gc to release GE session resources used by non-used cell objects 143 if os.getenv('GC_COLLECT_IN_CELL') == '1': 144 gc.collect() 145 146 if flags: 147 self.add_flags(**flags) 148 self._bprop_debug = False 149 self._forward_pre_hook = OrderedDict() 150 self._forward_hook = OrderedDict() 151 self._enable_forward_pre_hook = False 152 self._enable_forward_hook = False 153 self._enable_backward_hook = False 154 self._cell_backward_hook = None 155 self._is_recursion_hook = False 156 self.cell_type = None 157 self.cast = Cast() 158 self._has_config_recompute = False 159 self._user_parameters = [] 160 self._dynamic_shape_inputs = None 161 self._compile_args = None 162 self.saved_dynamic_shape = None 163 self._jit_config_dict = dict() 164 self.grad_ops_label = False 165 self.ge_sync_data = False 166 self._is_check_and_refresh = False 167 self._amp_level = "" 168 self._init_flag = False 169 170 def __getstate__(self): 171 base = Cell_.__getstate__(self) 172 return base, self.__dict__ 173 174 def __setstate__(self, state): 175 base, dict_ = state 176 Cell_.__setstate__(self, base) 177 self.__dict__ = dict_ 178 self._attr_synced = False 179 180 def __bool__(self): 181 return True 182 183 @property 184 def _cell_tag(self): 185 # `<class 'xxxxxxx'>` to `xxxxxxx` 186 return str(self.__class__)[8:-2] 187 188 @property 189 def create_time(self): 190 return self._create_time 191 192 @property 193 def cell_init_args(self): 194 return self._cell_init_args 195 196 @property 197 def param_prefix(self): 198 """ 199 Param prefix is the prefix of current cell's direct child parameter. 200 201 Examples: 202 >>> import mindspore as ms 203 >>> from mindspore import Tensor, nn 204 ... 205 >>> class Net(nn.Cell): 206 ... def __init__(self): 207 ... super(Net, self).__init__() 208 ... self.dense = nn.Dense(2, 2) 209 ... 210 ... def construct(self, x): 211 ... x = self.dense(x) 212 ... return x 213 >>> net = Net() 214 >>> net.update_cell_prefix() 215 >>> print(net.dense.param_prefix) 216 dense 217 """ 218 return self._param_prefix 219 220 @property 221 def bprop_debug(self): 222 """ 223 Get whether cell custom bprop debug is enabled. 224 225 Tutorial Examples: 226 - `Cell and Parameter - Custom Cell Reverse 227 <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_ 228 """ 229 return self._bprop_debug 230 231 @bprop_debug.setter 232 def bprop_debug(self, value): 233 """ 234 Set whether to enable cell custom bprop debug. 235 236 Note: 237 When bprop is defined in cell, the bprop function will be executed 238 in python interpreter when bprop debug is true, and will be parsed 239 and add to graph when bprop debug is false. 240 241 Args: 242 value (bool): Specifies whether to enable bprop debug. Default: ``False``. 243 """ 244 if not isinstance(value, bool): 245 raise TypeError(f"For 'Cell', the property 'bprop_debug' must be bool type, but got type {type(value)}.") 246 self._bprop_debug = value 247 248 def update_cell_prefix(self): 249 """ 250 Update the `param_prefix` of all child cells. 251 252 After being invoked, it can get all the cell's children's name prefix by '_param_prefix'. 253 """ 254 cells_name = self.cells_and_names() 255 256 for cell_name, cell in cells_name: 257 cell._param_prefix = cell_name 258 259 def update_cell_type(self, cell_type): 260 """ 261 The current cell type is updated when a quantization aware training network is encountered. 262 263 After being invoked, it can set the cell type to 'cell_type'. 264 265 Args: 266 cell_type(str): The type of cell to be updated, cell_type can be "quant" or "second-order". 267 """ 268 self.cell_type = cell_type 269 270 @cell_init_args.setter 271 def cell_init_args(self, value): 272 if not isinstance(value, str): 273 raise TypeError(f"For 'Cell', the property 'cell_init_args' must be string type, " 274 f"but got type {type(value)}.") 275 self._cell_init_args = value 276 277 @property 278 def phase(self): 279 return self._phase 280 281 @phase.setter 282 def phase(self, value): 283 if not isinstance(value, str): 284 raise TypeError(f"For 'Cell', the property 'phase' must be string type, but got type {type(value)}.") 285 self._phase = value 286 287 @property 288 def parameter_layout_dict(self): 289 """ 290 `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and 291 distributed operator information. 292 """ 293 return self._parameter_layout_dict 294 295 @property 296 def cls_name(self): 297 return self.__class__.__name__ 298 299 @parameter_layout_dict.setter 300 def parameter_layout_dict(self, value): 301 if not isinstance(value, dict): 302 raise TypeError(f"For 'Cell', the property 'parameter_layout_dict' must be dict type, " 303 f"but got type {type(value)}.") 304 self._parameter_layout_dict = value 305 306 @property 307 def parallel_parameter_name_list(self): 308 return self._parallel_parameter_name_list 309 310 @parallel_parameter_name_list.setter 311 def parallel_parameter_name_list(self, value): 312 if not isinstance(value, list): 313 raise TypeError(f"For 'Cell', the property 'parallel_parameter_name_list' must be list type, " 314 f"but got type {type(value)}.") 315 self._parallel_parameter_name_list = value 316 317 @property 318 def pipeline_stage(self): 319 """ 320 `pipeline_stage` represents the pipeline stage of current Cell. 321 """ 322 return self._pipeline_stage 323 324 @pipeline_stage.setter 325 def pipeline_stage(self, value): 326 """ 327 Set the `pipeline_stage` of a Cell. 328 329 Args: 330 value (int): The pipeline stage of a parameter. 331 332 Raises: 333 TypeError: If `value` is not int type or is a bool type. 334 ValueError: If `value` is not a positive integer. 335 """ 336 if not isinstance(value, int) or isinstance(value, bool): 337 raise TypeError("For 'Cell', the property 'pipeline_stage' " 338 "must be int type, but got type : {}".format(type(value))) 339 340 if value < 0: 341 raise ValueError("For 'Cell', the property 'pipeline_stage' " 342 "can not be less than 0, but got {}".format(value)) 343 self._pipeline_stage = value 344 for item in self.trainable_params(): 345 item.add_pipeline_stage(value) 346 347 @property 348 def pipeline_segment(self): 349 return self._pipeline_segment 350 351 @pipeline_segment.setter 352 def pipeline_segment(self, value): 353 if not isinstance(value, int) or isinstance(value, bool): 354 raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' " 355 "must be int type, but got type : {}".format(type(value))) 356 357 if value < 0: 358 raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' " 359 "can not be less than 0, but got {}".format(value)) 360 self._pipeline_segment = value 361 362 @property 363 def parallel_parameter_merge_net_dict(self): 364 return self._parallel_parameter_merge_net_dict 365 366 @parallel_parameter_merge_net_dict.setter 367 def parallel_parameter_merge_net_dict(self, value): 368 if not isinstance(value, dict): 369 raise TypeError(f"For 'Cell', the property 'parallel_parameter_merge_net_dict' must be dict type, " 370 f"but got type {type(value)}.") 371 self._parallel_parameter_merge_net_dict = value 372 373 @property 374 def jit_config_dict(self): 375 return self._jit_config_dict 376 377 def get_func_graph_proto(self): 378 """Return graph binary proto.""" 379 exec_id = ".".join([self.phase, str(self.create_time), str(id(self))]) 380 return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True) 381 382 def __getattr__(self, name): 383 if '_params' in self.__dict__: 384 params = self.__dict__['_params'] 385 if name in params: 386 return params[name] 387 if '_cells' in self.__dict__: 388 cells = self.__dict__['_cells'] 389 if name in cells: 390 return cells[name] 391 if '_params_list' in self.__dict__: 392 params_list = self.__dict__['_params_list'] 393 if name in params_list: 394 return params_list[name] 395 raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name)) 396 397 def __del__(self): 398 if isinstance(cells_compile_cache, dict): 399 # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache 400 # here using pop(id(self), None) to avoid KeyError exception 401 cells_compile_cache.pop(id(self), None) 402 if hasattr(self, "compile_cache") and self.compile_cache: 403 _cell_graph_executor.del_net_res(self, self.compile_cache) 404 if isinstance(self, GraphCell): 405 _cell_graph_executor.dec_graph_cell_count() 406 Cell.total_instance_count -= 1 407 408 def __delattr__(self, name): 409 if name in self._params: 410 del self._params[name] 411 elif name in self._cells: 412 del self._cells[name] 413 elif '_params_list' in self.__dict__ and name in self._params_list: 414 del self._params_list[name] 415 else: 416 object.__delattr__(self, name) 417 self._attr_synced = False 418 419 def _cast_mixed_precision_inputs(self, inputs, dst_type): 420 """Cast input for mixed precision""" 421 res = list() 422 for item in inputs: 423 if isinstance(item, tuple): 424 res.append(self._cast_mixed_precision_inputs(item, dst_type)) 425 elif isinstance(item, float): 426 res.append(self.cast(item, dst_type)) 427 elif hasattr(item, "dtype") and item.dtype in \ 428 {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type: 429 res.append(self.cast(item, dst_type)) 430 else: 431 res.append(item) 432 return tuple(res) 433 434 def cast_inputs(self, inputs, dst_type): 435 """ 436 Cast inputs to specified type. 437 438 Args: 439 inputs (tuple[Tensor]): The cell inputs. 440 dst_type (mindspore.dtype): The specified data type. 441 442 returns: 443 tuple[Tensor], the result with destination data type. 444 """ 445 res = list() 446 for item in inputs: 447 if isinstance(item, tuple): 448 res.append(self.cast_inputs(item, dst_type)) 449 else: 450 res.append(self.cast(item, dst_type)) 451 return tuple(res) 452 453 def _do_parameter_broadcast(self): 454 if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL: 455 if not self.parameter_broadcast_done: 456 _pynative_executor.parameter_broadcast(self, self.phase) 457 self.parameter_broadcast_done = True 458 459 def run_construct(self, cast_inputs, kwargs): 460 """ 461 Run the construct function. 462 463 Note: 464 This function will be removed in a future version. It is not recommended to call this function. 465 466 Args: 467 cast_inputs (tuple): The input objects of Cell. 468 kwargs (dict): Provide keyword arguments. 469 470 Returns: 471 output, the output object of Cell. 472 """ 473 logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. " 474 f"Calling this function is not recommended.") 475 output = self._run_construct(cast_inputs, kwargs) 476 return output 477 478 def _run_construct(self, cast_inputs, kwargs): 479 """Run the construct function""" 480 if self._enable_forward_pre_hook: 481 cast_inputs = self._run_forward_pre_hook(cast_inputs) 482 if self._enable_backward_hook: 483 output = self._backward_hook_construct(*cast_inputs, **kwargs) 484 elif hasattr(self, "_shard_fn"): 485 output = self._shard_fn(*cast_inputs, **kwargs) 486 else: 487 if self.recompute_cell is not None: 488 output = self.recompute_cell(*cast_inputs, **kwargs) 489 else: 490 output = self.construct(*cast_inputs, **kwargs) 491 if self._enable_forward_hook: 492 output = self._run_forward_hook(cast_inputs, output) 493 return output 494 495 def _check_construct_args(self, *args): 496 """Check the args needed by the function construct""" 497 positional_args = 0 498 default_args = 0 499 has_var = False 500 for value in inspect.signature(self.construct).parameters.values(): 501 if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD: 502 has_var = True 503 if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: 504 if value.default is inspect.Parameter.empty: 505 positional_args += 1 506 else: 507 default_args += 1 508 509 if has_var: 510 return 511 512 if len(args) < positional_args: 513 raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, " 514 f"but got {len(args)}. When using set_inputs, please make sure that all networks " 515 f"and loss functions are configured with set_inputs.") 516 517 if len(args) > positional_args + default_args: 518 construct_inputs_names = self.construct.__code__.co_varnames 519 if 'self' not in construct_inputs_names: 520 raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ") 521 522 raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and " 523 f"{default_args} default argument, total {positional_args + default_args}, " 524 f"but got {len(args)}.") 525 526 def _hook_fn_registered(self): 527 '''Hook function in graph mode''' 528 # Check super().__init__() in graph mode. 529 try: 530 if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook: 531 return True 532 except AttributeError as e: 533 raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. " 534 f"Please use 'super().__init__()'.") from e 535 if not self._is_recursion_hook: 536 self._is_recursion_hook = True 537 for cell in self.cells(): 538 if cell._hook_fn_registered(): 539 return True 540 return False 541 542 def _get_prims_recursively(self): 543 all_prims = list() 544 for _, value in self._primitives.items(): 545 if value: 546 all_prims.append(value) 547 548 for cell in self.cells(): 549 all_prims.extend(cell._get_prims_recursively()) 550 551 return all_prims 552 553 def set_data_parallel(self): 554 """ 555 For all primitive ops in this cell(including ops of cells that wrapped by this cell), 556 if parallel strategy is not specified, then instead of auto-searching, data parallel 557 strategy will be generated for those primitive ops. 558 559 Note: 560 Only effective while using auto_parallel_context = ParallelMode.AUTO_PARALLEL under graph mode. 561 562 Examples: 563 >>> import mindspore.nn as nn 564 >>> net = nn.Dense(3, 4) 565 >>> net.set_data_parallel() 566 """ 567 if context._get_mode() == context.PYNATIVE_MODE: 568 raise ValueError("set_data_parallel: does not support PyNative mode.") 569 570 all_prims = self._get_prims_recursively() 571 for prim in all_prims: 572 prim.add_prim_attr("strategy_gen_mode", "data_parallel") 573 574 def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0): 575 """ 576 Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be 577 generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed 578 execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell, 579 strategy for others will be set by sharding propagation. 580 in_strategy and out_strategy define the input and output layout respectively. 581 in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of 582 this input/output, and None represents data_parallel, 583 which can refer to the description of `mindspore.ops.Primitive.shard`. 584 The parallel strategies of remaining operators are derived from the strategy specified by the input and output. 585 586 Note: 587 If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to 588 "auto_parallel" and the search mode (search_mode) to "sharding_propagation". 589 If the input contain Parameter, its strategy should be set in `in_strategy`. 590 591 Args: 592 in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple 593 defines the layout of the corresponding input and None represents a data parallel strategy. 594 out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy. 595 It is not in use right now. Default: ``None`` . 596 parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict 597 defines the layout of the parameter like "param_name: layout". 598 The key is a parameter name of type 'str'. 599 The value is a 1-D integer tuple, indicating the corresponding layout. 600 If the parameter name is incorrect or the corresponding parameter 601 has been set, the parameter setting will be ignored. 602 Default: ``None`` . 603 device (string): Select a certain device target. It is not in use right now. 604 Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` . 605 level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation 606 over communication ratio, maximize speed performance, minimize memory usage etc. It is not in 607 use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` . 608 609 Returns: 610 Function, return the cell construct function that will be executed under auto parallel process. 611 612 Examples: 613 >>> import mindspore.nn as nn 614 >>> 615 >>> class Block(nn.Cell): 616 ... def __init__(self): 617 ... self.dense1 = nn.Dense(10, 10) 618 ... self.relu = nn.ReLU() 619 ... self.dense2 = nn.Dense2(10, 10) 620 ... def construct(self, x): 621 ... x = self.relu(self.dense2(self.relu(self.dense1(x)))) 622 ... return x 623 >>> 624 >>> class example(nn.Cell): 625 ... def __init__(self): 626 ... self.block1 = Block() 627 ... self.block2 = Block() 628 ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,), 629 ... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)}) 630 ... def construct(self, x): 631 ... x = self.block1(x) 632 ... x = self.block2_shard(x) 633 ... return x 634 """ 635 if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]: 636 raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel " 637 f"Please check the parallel mode in parallel context.") 638 639 shard_fn = Shard() 640 fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level) 641 object.__setattr__(self, "_shard_fn", fn) 642 return fn 643 644 def auto_cast_inputs(self, inputs): 645 """ 646 Auto cast inputs in mixed precision scenarios. 647 648 Args: 649 inputs (tuple): the inputs of construct. 650 651 Returns: 652 Tuple, the inputs after data type cast. 653 """ 654 msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version." 655 logger.warning(msg) 656 cast_inputs = inputs 657 mixed_type = self.get_mixed_precision_type() 658 if mixed_type == MixedPrecisionType.FP16: 659 cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16) 660 if mixed_type == MixedPrecisionType.FP32: 661 cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32) 662 663 return cast_inputs 664 665 def _init_check(self): 666 for param in self.get_parameters(expand=False): 667 if param.has_init: 668 param.init_data() 669 670 def _self_check(self): 671 if not self._is_check_and_refresh: 672 self.check_names_and_refresh_name() 673 self._is_check_and_refresh = True 674 675 def _predict(self, *args, **kwargs): 676 if not hasattr(self, "phase"): 677 return False, None 678 if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache: 679 new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args) 680 res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase]) 681 res = _convert_python_data(res) 682 return True, res 683 return False, None 684 685 def __call__(self, *args, **kwargs): 686 # Run in Graph mode. 687 if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE: 688 if kwargs: 689 bound_arguments = self.sig.bind(*args, **kwargs) 690 bound_arguments.apply_defaults() 691 args = bound_arguments.args 692 kwargs = bound_arguments.kwargs 693 694 predict_compiled, res = self._predict(*args, **kwargs) 695 if predict_compiled: 696 return res 697 self._check_construct_args(*args) 698 699 if self._hook_fn_registered(): 700 logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook " 701 f"function, please use context.set_context to set pynative mode.") 702 self._self_check() 703 out = self.compile_and_run(*args, **kwargs) 704 return out 705 706 # Run in PyNative mode. 707 self._self_check() 708 if not self._init_flag: 709 self._init_check() 710 self._init_flag = True 711 712 if self.requires_grad: 713 _pynative_executor.set_grad_flag(True) 714 715 try: 716 _pynative_executor.new_graph(self, *args, **kwargs) 717 output = self._run_construct(args, kwargs) 718 _pynative_executor.end_graph(self, output, *args, **kwargs) 719 except Exception as err: 720 _pynative_executor.clear_res() 721 raise err 722 723 return output 724 725 def _add_attr(self, name, value): 726 if name and name[:2] != '__' and name not in Cell.IGNORE_LIST: 727 super(Cell, self)._add_attr(name, value) 728 729 def _sync_attr_for_compile(self): 730 """Sync the attr to c++ object.""" 731 if self._attr_synced: 732 return 733 cells = self.__dict__.get('_cells') 734 for key in cells: 735 cell = cells[key] 736 cell._sync_attr_for_compile() 737 self._add_attr(key, cell) 738 params = self.__dict__.get('_params') 739 for key in params: 740 if '.' in key: 741 continue 742 param = params[key] 743 self._add_attr(key, param) 744 params_list = self.__dict__.get('_params_list') 745 for key in params_list: 746 params_list_item = params_list[key] 747 self._add_attr(key, params_list_item) 748 for key in self.__dict__: 749 value = self.__dict__[key] 750 self._add_attr(key, value) 751 self._attr_synced = True 752 753 def _set_attr_for_parameter(self, name, value): 754 """Set attr for parameter.""" 755 cells = self.__dict__.get('_cells') 756 params = self.__dict__.get('_params') 757 if params is None: 758 raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.") 759 if name in self.__dict__: 760 if self.__dict__[name] is not None: 761 raise TypeError(f"For 'Cell', the {name} should not be Parameter.") 762 del self.__dict__[name] 763 if cells and name in cells: 764 raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.") 765 self.insert_param_to_cell(name, value) 766 767 def _set_attr_for_parameter_tuple(self, name, value): 768 """Set attr for parameter in ParameterTuple.""" 769 params = self.__dict__.get('_params') 770 params_list = self.__dict__.get('_params_list') 771 if params is None: 772 raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.") 773 exist_names = set("") 774 exist_objs = set() 775 for item in value: 776 if item in exist_objs: 777 # If there are multiple identical objects, their names only check once. 778 continue 779 exist_objs.add(item) 780 if item.name == PARAMETER_NAME_DEFAULT: 781 logger.warning("For 'Cell', the parameter definition is deprecated.\n" 782 "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value)) 783 item.name = item.name + "$" + str(self._id) 784 self._id += 1 785 self.insert_param_to_cell(item.name, item, check_name_contain_dot=False) 786 if item.name in exist_names: 787 raise ValueError("The value {} , its name '{}' already exists. " 788 "Please set a unique name for the parameter.".format(value, item.name)) 789 exist_names.add(item.name) 790 791 if context._get_mode() == context.PYNATIVE_MODE: 792 if name in self.__dict__: 793 del self.__dict__[name] 794 if name in params: 795 del params[name] 796 params_list[name] = value 797 else: 798 object.__setattr__(self, name, value) 799 800 def _set_attr_for_parameter_in_list_or_tuple(self, name, value): 801 """Set attr for parameter in list or tuple.""" 802 for item in value: 803 if item in self.exist_objs: 804 # If there are multiple identical objects, their names only check once. 805 continue 806 self.exist_objs.add(item) 807 if item.name == PARAMETER_NAME_DEFAULT: 808 item.name = item.name + "$" + str(self._id) 809 self._id += 1 810 if item.name in self.exist_names: 811 raise ValueError("The value {} , its name '{}' already exists. " 812 "Please set a unique name for the parameter.".format(value, item.name)) 813 self.exist_names.add(item.name) 814 object.__setattr__(self, name, value) 815 816 def _set_attr_for_cell(self, name, value): 817 """Set attr for cell.""" 818 cells = self.__dict__.get('_cells') 819 params = self.__dict__.get('_params') 820 if cells is None: 821 raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.") 822 if name in self.__dict__: 823 del self.__dict__[name] 824 if params and name in params: 825 raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.") 826 if self._auto_prefix: 827 value.update_parameters_name(name + '.') 828 cells[name] = value 829 if hasattr(self, '_cell_init_args'): 830 self.cell_init_args += str({name: value}) 831 832 def _set_attr_for_params(self, name, value): 833 if isinstance(value, Tensor) and self._params[name] is not None: 834 self._params[name].set_data(value) 835 elif value is not None: 836 raise TypeError(f"For 'Cell', the type of {name} must be Parameter or ParameterTuple, " 837 f"but got {type(value).__name__}.") 838 else: 839 self.insert_param_to_cell(name, None) 840 841 def __setattr__(self, name, value): 842 cells = self.__dict__.get('_cells') 843 params = self.__dict__.get('_params') 844 if isinstance(value, Parameter): 845 self._set_attr_for_parameter(name, value) 846 elif isinstance(value, ParameterTuple): 847 self._set_attr_for_parameter_tuple(name, value) 848 elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value): 849 self._set_attr_for_parameter_in_list_or_tuple(name, value) 850 elif isinstance(value, Cell): 851 self._set_attr_for_cell(name, value) 852 elif params and name in params: 853 self._set_attr_for_params(name, value) 854 elif cells and name in cells: 855 if value is not None: 856 raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.") 857 self._cells[name] = None 858 else: 859 if isinstance(value, Primitive): 860 value.set_prim_instance_name(name) 861 self._primitives[name] = value 862 object.__setattr__(self, name, value) 863 if name not in Cell.IGNORE_LIST: 864 self._attr_synced = False 865 866 def extend_repr(self): 867 """ 868 Expand the description of Cell. 869 870 To print customized extended information, re-implement this method in your own cells. 871 """ 872 return '' 873 874 def __str__(self): 875 return self.__repr__() 876 877 def __repr__(self): 878 extra_str = self.extend_repr() 879 info_str = self.__class__.__name__ + '<' 880 if self._cells: 881 sub_str = '\n' 882 if extra_str: 883 sub_str += '{}\n'.format(self.extend_repr()) 884 for key, value in self._cells.items(): 885 sub_str += '({}): {}\n'.format(key, repr(value)) 886 sub_str = sub_str.replace('\n', '\n ') + '>' 887 info_str += sub_str 888 else: 889 info_str += extra_str + '>' 890 return info_str 891 892 def load_parameter_slice(self, params): 893 """ 894 Replace parameters with sliced tensors by parallel strategies. 895 896 Note: 897 This interface is deprecated. 898 """ 899 logger.warning("'load_parameter_slice' function is deprecated.") 900 901 def set_parallel_input_with_inputs(self, *inputs): 902 """ 903 Slice inputs tensors by parallel strategies. 904 905 Note: 906 This interface is deprecated. 907 """ 908 logger.warning("'set_parallel_input_with_inputs' function is deprecated.") 909 910 def set_inputs(self, *inputs, **kwargs): 911 """ 912 Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When 913 using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are 914 configured with set_inputs. The shape of input Tensor can be either dynamic or static. 915 916 .. note:: 917 There are two mode: 918 919 - Full mode: arguments will be used as all compile inputs for graph-compiling. 920 - Incremental mode: arguments will set to some of the Cell inputs, which will be substituted into the input 921 at the corresponding position for graph-compiling. 922 923 Only one of inputs or kwargs can be set. Inputs for full mode and kwargs for incremental mode. 924 925 Args: 926 inputs (tuple): Full mode arguments. 927 kwargs (dict): Incremental mode arguments. The acceptable key is the name of parameter defined 928 in `self.construct`. 929 930 .. warning:: 931 This is an experimental API that is subject to change or deletion. 932 933 Examples: 934 >>> import numpy as np 935 >>> import mindspore as ms 936 >>> from mindspore import nn, Tensor 937 >>> 938 >>> class ReluNet(nn.Cell): 939 ... def __init__(self): 940 ... super(ReluNet, self).__init__() 941 ... self.relu = nn.ReLU() 942 ... def construct(self, x): 943 ... return self.relu(x) 944 >>> 945 >>> net = ReluNet() 946 >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32) 947 >>> net.set_inputs(input_dyn) 948 >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32) 949 >>> output = net(input) 950 >>> 951 >>> net2 = ReluNet() 952 >>> net2.set_inputs(x=input_dyn) 953 >>> output = net2(input) 954 """ 955 if self.grad_ops_label: 956 logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is ' 957 f'generated.') 958 if kwargs and inputs: 959 raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!' 960 % (inputs, kwargs)) 961 962 if not kwargs: 963 self._dynamic_shape_inputs = inputs 964 self._check_construct_args(*inputs) 965 if context._get_mode() == context.PYNATIVE_MODE: 966 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs) 967 else: 968 self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs) 969 970 def get_inputs(self): 971 """ 972 Returns the dynamic_inputs of a cell object in one network. 973 974 Returns: 975 inputs (tuple), Inputs of the Cell object. 976 977 .. warning:: 978 This is an experimental API that is subject to change or deletion. 979 980 Examples: 981 >>> import numpy as np 982 >>> import mindspore as ms 983 >>> from mindspore import nn, Tensor 984 >>> 985 >>> class ReluNet(nn.Cell): 986 ... def __init__(self): 987 ... super(ReluNet, self).__init__() 988 ... self.relu = nn.ReLU() 989 ... def construct(self, x): 990 ... return self.relu(x) 991 >>> 992 >>> net = ReluNet() 993 >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32) 994 >>> net.set_inputs(input_dyn) 995 >>> get_inputs = net.get_inputs() 996 >>> print(get_inputs) 997 (Tensor(shape=[3, -1], dtype=Float32, value= ),) 998 999 """ 1000 1001 return self._dynamic_shape_inputs 1002 1003 def _check_parameter_consistency(self, set_inputs, net_inputs): 1004 """Check consistency for parameter.""" 1005 for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)): 1006 if isinstance(set_input, Tensor): 1007 if not isinstance(net_input, Tensor): 1008 raise TypeError( 1009 f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must " 1010 f"be Tensor, but got {type(net_input)}.") 1011 if isinstance(set_input, Parameter) != isinstance(net_input, Parameter): 1012 raise TypeError( 1013 f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same " 1014 f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.") 1015 elif isinstance(set_input, (tuple, list)): 1016 if not isinstance(net_input, (tuple, list)): 1017 raise TypeError( 1018 f"The {index + 1}th input type of 'set_inputs' or tuple(list) in " 1019 f"'set_inputs' must be tuple or list, but got {type(net_input)}.") 1020 self._check_parameter_consistency(set_input, net_input) 1021 1022 def _get_compile_args(self, args): 1023 """Get compile arguments.""" 1024 # this is used only for test 1025 set_by_auto_dynamic = False 1026 if is_auto_dynamic(): 1027 if self._dynamic_shape_inputs is None: 1028 set_by_auto_dynamic = True 1029 else: 1030 if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None: 1031 set_by_auto_dynamic = True 1032 if set_by_auto_dynamic: 1033 self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args) 1034 1035 if self._dynamic_shape_inputs is not None: 1036 logger.debug("Compiled Graph with dynamic shape") 1037 compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs) 1038 _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs") 1039 self._check_parameter_consistency(compile_args, args) 1040 Validator.check_symbolic_shape(compile_args, args) 1041 self.saved_dynamic_shape = compile_args 1042 return compile_args 1043 return args 1044 1045 def compile(self, *args, **kwargs): 1046 """ 1047 Compile Cell as a computation graph, the input must be consistent with the input defined in construct. 1048 1049 Args: 1050 args (tuple): Args of the Cell object. 1051 kwargs (dict): Kwargs of the Cell object. 1052 """ 1053 self._compile_args = self._get_compile_args(args) 1054 _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase, 1055 jit_config_dict=self._jit_config_dict, **kwargs) 1056 1057 def compile_and_run(self, *args, **kwargs): 1058 """ 1059 Compile and run Cell, the input must be consistent with the input defined in construct. 1060 1061 Note: 1062 It is not recommended to call directly. 1063 1064 Args: 1065 args (tuple): Args of the Cell object. 1066 kwargs (dict): Kwargs of the Cell object. 1067 1068 Returns: 1069 Object, the result of executing. 1070 """ 1071 self.compile(*args, **kwargs) 1072 self.add_flags(ge_sync_data=False) 1073 new_args = _get_args_for_run(self, args, kwargs, self._compile_args) 1074 return _cell_graph_executor(self, *new_args, phase=self.phase) 1075 1076 def auto_parallel_compile_and_run(self): 1077 """ 1078 Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode. 1079 1080 Note: 1081 This interface is deprecated. 1082 """ 1083 logger.warning("'auto_parallel_compile_and_run' function is deprecated.") 1084 1085 def exec_checkpoint_graph(self): 1086 """Executes GE saving checkpoint graph operation.""" 1087 logger.warning("'exec_checkpoint_graph' function is deprecated.") 1088 self.add_flags(ge_sync_data=True) 1089 _cell_graph_executor(self, phase='save') 1090 1091 def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True): 1092 """ 1093 Adds a parameter to the current cell. 1094 1095 Inserts a parameter with given name to the cell. The method is currently used in 1096 `mindspore.nn.Cell.__setattr__`. 1097 1098 Args: 1099 param_name (str): Name of the parameter. 1100 param (Parameter): Parameter to be inserted to the cell. 1101 check_name_contain_dot (bool): Determines whether the name input is compatible. Default: ``True`` . 1102 1103 Raises: 1104 KeyError: If the name of parameter is null or contains dot. 1105 TypeError: If the type of parameter is not Parameter. 1106 1107 Examples: 1108 >>> import mindspore as ms 1109 >>> from mindspore import Tensor, nn, Parameter 1110 ... 1111 >>> class Net(nn.Cell): 1112 ... def __init__(self): 1113 ... super(Net, self).__init__() 1114 ... self.relu = nn.ReLU() 1115 ... 1116 ... def construct(self, x): 1117 ... x = self.relu(x) 1118 ... return x 1119 >>> net = Net() 1120 >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3]))) 1121 >>> print(net.bias) 1122 Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True) 1123 """ 1124 if not param_name: 1125 raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.") 1126 if check_name_contain_dot and '.' in param_name: 1127 raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain'.' ") 1128 if '_params' not in self.__dict__: 1129 raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.") 1130 if hasattr(self, param_name) and param_name not in self._params: 1131 raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network." 1132 f"Cannot insert another parameter with the same name.") 1133 if not isinstance(param, Parameter) and param is not None: 1134 raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, " 1135 f"but got {type(param)}.") 1136 if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT: 1137 param.name = param_name 1138 self._params[param_name] = param 1139 1140 def cast_param(self, param): 1141 """ 1142 Cast parameter according to auto mix precision level in pynative mode. 1143 1144 This interface is currently used in the case of auto mix precision and usually needs not to be used explicitly. 1145 1146 Args: 1147 param (Parameter): Parameters, the type of which should be cast. 1148 1149 Returns: 1150 Parameter, the input parameter with type automatically cast. 1151 """ 1152 msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version." 1153 logger.warning(msg) 1154 mixed_type = self.get_mixed_precision_type() 1155 if mixed_type != MixedPrecisionType.NOTSET: 1156 if mixed_type == MixedPrecisionType.FP32: 1157 param.set_cast_dtype(mstype.float32) 1158 elif mixed_type == MixedPrecisionType.FP16: 1159 param.set_cast_dtype(mstype.float16) 1160 elif hasattr(param, "set_cast_dtype"): 1161 # retest dtype 1162 param.set_cast_dtype() 1163 return param 1164 1165 def insert_child_to_cell(self, child_name, child_cell): 1166 """ 1167 Adds a child cell to the current cell with a given name. 1168 1169 Args: 1170 child_name (str): Name of the child cell. 1171 child_cell (Cell): The child cell to be inserted. 1172 1173 Raises: 1174 KeyError: Child Cell's name is incorrect or duplicated with the other child name. 1175 TypeError: If type of `child_name` is not str. 1176 TypeError: Child Cell's type is incorrect. 1177 1178 Examples: 1179 >>> import mindspore as ms 1180 >>> from mindspore import Tensor, nn 1181 ... 1182 >>> net1 = nn.ReLU() 1183 >>> net2 = nn.Dense(2, 2) 1184 >>> net1.insert_child_to_cell("child", net2) 1185 >>> print(net1) 1186 ReLU< 1187 (child): Dense<input_channels=2, output_channels=2, has_bias=True> 1188 > 1189 """ 1190 if not isinstance(child_name, str): 1191 raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, " 1192 f"but got {type(child_name)}.") 1193 if not child_name or '.' in child_name: 1194 raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and " 1195 "can not contain '.' ") 1196 if hasattr(self, child_name) and child_name not in self._cells: 1197 raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network." 1198 f"Cannot insert another child cell with the same name.") 1199 if not isinstance(child_cell, Cell) and child_cell is not None: 1200 raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, " 1201 f"but got type {type(child_cell)}.") 1202 self._cells[child_name] = child_cell 1203 1204 def construct(self, *args, **kwargs): 1205 """ 1206 Defines the computation to be performed. This method must be overridden by all subclasses. 1207 1208 Note: 1209 It is not supported currently that inputs contain both tuple and non-tuple types at same time. 1210 1211 Args: 1212 args (tuple): Tuple of variable parameters. 1213 kwargs (dict): Dictionary of variable keyword parameters. 1214 1215 Returns: 1216 Tensor, returns the computed result. 1217 """ 1218 raise AttributeError("For 'Cell', the method 'construct' is not defined.") 1219 1220 def remove_redundant_parameters(self): 1221 """ 1222 Remove the redundant parameters. 1223 1224 This interface usually needs not to be used explicitly. 1225 """ 1226 cells = self.cells_and_names() 1227 for _, cell in cells: 1228 params = cell._params.items() 1229 for param_name, param in list(params): 1230 if param.name not in self.parallel_parameter_name_list: 1231 cell._params.pop(param_name) 1232 logger.info("remove the redundant parameter: %s", param.name) 1233 continue 1234 cell_dict = cell.__dict__ 1235 for key in cell_dict: 1236 if isinstance(cell_dict[key], ParameterTuple): 1237 param_tuple = cell_dict[key] 1238 new_param_tuple = [] 1239 for param in param_tuple: 1240 if param.name not in self.parallel_parameter_name_list: 1241 logger.info("remove the redundant parameter: %s in ParameterTuple", param.name) 1242 continue 1243 new_param_tuple.append(param) 1244 cell.__dict__[key] = ParameterTuple(new_param_tuple) 1245 1246 def init_parameters_data(self, auto_parallel_mode=False): 1247 """ 1248 Initialize all parameters and replace the original saved parameters in cell. 1249 1250 Note: 1251 trainable_params() and other similar interfaces may return different parameter instance after 1252 `init_parameters_data`, do not save these results. 1253 1254 Args: 1255 auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` . 1256 1257 Returns: 1258 Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter. 1259 1260 Examples: 1261 >>> import mindspore as ms 1262 >>> from mindspore import Tensor, nn 1263 ... 1264 >>> class Net(nn.Cell): 1265 ... def __init__(self): 1266 ... super(Net, self).__init__() 1267 ... self.dense = nn.Dense(2, 2) 1268 ... 1269 ... def construct(self, x): 1270 ... x = self.dense(x) 1271 ... return x 1272 >>> net = Net() 1273 >>> print(net.init_parameters_data()) 1274 {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True): 1275 Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True), 1276 Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True): 1277 Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)} 1278 """ 1279 replace = dict() 1280 1281 def _updata(param): 1282 if param in replace: 1283 return replace.get(param) 1284 new_p = param.init_data(None, set_sliced=False) 1285 replace[param] = new_p 1286 return new_p 1287 1288 # replace all original usage. 1289 cells = self.cells_and_names() 1290 for _, cell in cells: 1291 params = cell._params.items() 1292 for param_name, param in params: 1293 if not auto_parallel_mode: 1294 cell._params[param_name] = _updata(param) 1295 continue 1296 if param.name in self.parallel_parameter_name_list: 1297 cell._params[param_name] = _updata(param) 1298 cell_dict = cell.__dict__ 1299 for key in cell_dict: 1300 if isinstance(cell_dict[key], ParameterTuple): 1301 param_tuple = cell_dict[key] 1302 new_param_tuple = [] 1303 for param in param_tuple: 1304 if not auto_parallel_mode: 1305 new_param_tuple.append(_updata(param)) 1306 continue 1307 if param.name in self.parallel_parameter_name_list: 1308 new_param_tuple.append(_updata(param)) 1309 else: 1310 new_param_tuple.append(param) 1311 cell.__dict__[key] = ParameterTuple(new_param_tuple) 1312 return replace 1313 1314 def parameters_dict(self, recurse=True): 1315 """ 1316 Gets the parameters dictionary of this cell. 1317 1318 Args: 1319 recurse (bool): Whether contains the parameters of subcells. Default: ``True`` . 1320 1321 Returns: 1322 OrderedDict, return parameters dictionary. 1323 1324 Examples: 1325 >>> import mindspore as ms 1326 >>> from mindspore import Tensor, nn, Parameter 1327 ... 1328 >>> class Net(nn.Cell): 1329 ... def __init__(self): 1330 ... super(Net, self).__init__() 1331 ... self.dense = nn.Dense(2, 2) 1332 ... 1333 ... def construct(self, x): 1334 ... x = self.dense(x) 1335 ... return x 1336 >>> net = Net() 1337 >>> print(net.parameters_dict()) 1338 OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32, 1339 requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32, 1340 requires_grad=True))]) 1341 """ 1342 param_dict = OrderedDict() 1343 for param in self.get_parameters(expand=recurse): 1344 param_dict[param.name] = param 1345 return param_dict 1346 1347 def parameters_broadcast_dict(self, recurse=True): 1348 """ 1349 Gets the parameters broadcast dictionary of this cell. 1350 1351 Args: 1352 recurse (bool): Whether contains the parameters of subcells. Default: ``True`` . 1353 1354 Returns: 1355 OrderedDict, return parameters broadcast dictionary. 1356 """ 1357 param_dict = OrderedDict() 1358 for param in self.get_parameters(expand=recurse): 1359 if param.layerwise_parallel is False: 1360 param_dict[param.name] = param 1361 if not param_dict: 1362 return None 1363 return param_dict 1364 1365 def update_parameters_name(self, prefix='', recurse=True): 1366 """ 1367 Adds the `prefix` string to the names of parameters. 1368 1369 Args: 1370 prefix (str): The prefix string. Default: ``''`` . 1371 recurse (bool): Whether contains the parameters of subcells. Default: ``True`` . 1372 """ 1373 1374 Validator.check_str_and_none_by_regular(prefix) 1375 for name, param in self.parameters_and_names(expand=recurse): 1376 if prefix != '': 1377 param.is_init = False 1378 param.name = prefix + name 1379 1380 def _update_local_parameters_name(self, prefix='', recurse=True): 1381 """ 1382 Updates the names of local parameters with given prefix string. 1383 1384 Adds the given prefix to the names of local parameters. 1385 1386 Local parameters means the parameters without user input. 1387 1388 Args: 1389 prefix (str): The prefix string. Default: ''. 1390 recurse (bool): Whether contains the parameters of subcells. Default: ``True``. 1391 """ 1392 1393 Validator.check_str_by_regular(prefix) 1394 for name, param in self.parameters_and_names(expand=recurse): 1395 if name in self._user_parameters: 1396 continue 1397 if prefix != '': 1398 param.is_init = False 1399 param.name = prefix + name 1400 1401 @jit_forbidden_register 1402 def trainable_params(self, recurse=True): 1403 """ 1404 Returns all trainable parameters. 1405 1406 Returns a list of all trainable parameters. 1407 1408 Args: 1409 recurse (bool): Whether contains the trainable parameters of subcells. Default: ``True`` . 1410 1411 Returns: 1412 List, the list of trainable parameters. 1413 1414 Tutorial Examples: 1415 - `Model Training - Optimizer 1416 <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_ 1417 """ 1418 return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse))) 1419 1420 @jit_forbidden_register 1421 def untrainable_params(self, recurse=True): 1422 """ 1423 Returns all untrainable parameters. 1424 1425 Returns a list of all untrainable parameters. 1426 1427 Args: 1428 recurse (bool): Whether contains the untrainable parameters of subcells. Default: ``True`` . 1429 1430 Returns: 1431 List, the list of untrainable parameters. 1432 """ 1433 return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse))) 1434 1435 @jit_forbidden_register 1436 def get_parameters(self, expand=True): 1437 """ 1438 Returns an iterator over cell parameters. 1439 1440 Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells. 1441 For more details about subcells, please see the example below. 1442 1443 Args: 1444 expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield 1445 parameters that are direct members of this cell. Default: ``True`` . 1446 1447 Returns: 1448 Iteration, all parameters at the cell. 1449 1450 Examples: 1451 >>> import mindspore as ms 1452 >>> from mindspore import nn, ops, Tensor 1453 >>> import numpy as np 1454 >>> class TestNet(nn.Cell): 1455 ... def __init__(self): 1456 ... super().__init__() 1457 ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32)) 1458 ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32)) 1459 ... def construct(self, x): 1460 ... x += self.my_w1 1461 ... x = ops.reshape(x, (16,)) - self.my_w2 1462 ... return x 1463 >>> class TestNet2(nn.Cell): 1464 ... def __init__(self): 1465 ... super().__init__() 1466 ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32)) 1467 ... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will 1468 ... # also be gathered. 1469 ... self.subcell = TestNet() 1470 ... def construct(self, x): 1471 ... x += self.my_w1 1472 ... x = ops.reshape(x, (16,)) - self.my_w2 1473 ... return x 1474 >>> net = TestNet2() 1475 >>> print([p for p in net.get_parameters(expand=True)]) 1476 [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1, 1477 shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32, 1478 requires_grad=True)] 1479 """ 1480 for _, param in self.parameters_and_names(expand=expand): 1481 yield param 1482 1483 # pylint: disable=missing-docstring 1484 def check_names_and_refresh_name(self): 1485 if not hasattr(self, "_params"): 1486 return 1487 all_name = [i.name for i in dict(self.parameters_and_names()).values()] 1488 if len(set(all_name)) < len(all_name): 1489 self.update_parameters_name() 1490 self.check_names() 1491 1492 def check_names(self): 1493 """ 1494 Check the names of cell parameters. 1495 """ 1496 names = set("") 1497 for value, param in self.parameters_and_names(): 1498 if param.name in names: 1499 raise ValueError("The value of {} is {}, its name '{}' already exists. " 1500 "Please set a unique name for the parameter.".format(value, param, param.name)) 1501 names.add(param.name) 1502 1503 def parameters_and_names(self, name_prefix='', expand=True): 1504 """ 1505 Returns an iterator over cell parameters. 1506 1507 Includes the parameter's name and itself. 1508 1509 Args: 1510 name_prefix (str): Namespace. Default: ``''`` . 1511 expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters 1512 that are direct members of this cell. Default: ``True`` . 1513 1514 Returns: 1515 Iteration, all the names and corresponding parameters in the cell. 1516 1517 Examples: 1518 >>> from mindspore import nn 1519 >>> n = nn.Dense(3, 4) 1520 >>> names = [] 1521 >>> for m in n.parameters_and_names(): 1522 ... if m[0]: 1523 ... names.append(m[0]) 1524 1525 Tutorial Examples: 1526 - `Building a Network - Model Parameters 1527 <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_ 1528 """ 1529 cells = [] 1530 if expand: 1531 cells = self.cells_and_names(name_prefix=name_prefix) 1532 else: 1533 cells.append((name_prefix, self)) 1534 1535 params_set = set() 1536 for cell_name, cell in cells: 1537 params = cell._params.items() 1538 for par_name, par in params: 1539 if par is not None and par.inited_param is not None: 1540 par = par.inited_param 1541 if par is not None and id(par) not in params_set: 1542 params_set.add(id(par)) 1543 par_new_name = par_name 1544 if cell_name: 1545 par_new_name = cell_name + '.' + par_new_name 1546 1547 yield par_new_name, par 1548 1549 def cells_and_names(self, cells=None, name_prefix=''): 1550 """ 1551 Returns an iterator over all cells in the network, including the cell's name and itself. 1552 1553 Args: 1554 cells (str): Cells to iterate over. Default: ``None`` . 1555 name_prefix (str): Namespace. Default: ``''`` . 1556 1557 Returns: 1558 Iteration, all the child cells and corresponding names in the cell. 1559 1560 Examples: 1561 >>> from mindspore import nn 1562 >>> class Net(nn.Cell): 1563 ... def __init__(self): 1564 ... super(Net, self).__init__() 1565 ... self.conv = nn.Conv2d(3, 64, 3) 1566 ... def construct(self, x): 1567 ... out = self.conv(x) 1568 ... return out 1569 >>> names = [] 1570 >>> n = Net() 1571 >>> for m in n.cells_and_names(): 1572 ... if m[0]: 1573 ... names.append(m[0]) 1574 """ 1575 t_cells = cells if cells else set() 1576 if self in t_cells: 1577 return 1578 1579 t_cells.add(self) 1580 yield name_prefix, self 1581 1582 for name, cell in self._cells.items(): 1583 if cell: 1584 cells_name_prefix = name 1585 if name_prefix: 1586 cells_name_prefix = name_prefix + '.' + cells_name_prefix 1587 for ele in cell.cells_and_names(t_cells, cells_name_prefix): 1588 yield ele 1589 1590 def cells(self): 1591 """ 1592 Returns an iterator over immediate cells. 1593 1594 Returns: 1595 Iteration, the immediate cells in the cell. 1596 1597 Examples: 1598 >>> import mindspore as ms 1599 >>> from mindspore import Tensor, nn 1600 ... 1601 >>> class Net(nn.Cell): 1602 ... def __init__(self): 1603 ... super(Net, self).__init__() 1604 ... self.dense = nn.Dense(2, 2) 1605 ... 1606 ... def construct(self, x): 1607 ... x = self.dense(x) 1608 ... return x 1609 >>> net = Net() 1610 >>> print(net.cells()) 1611 odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>]) 1612 """ 1613 return self.name_cells().values() 1614 1615 def _set_scope(self, name): 1616 """Sets the name on the first time.""" 1617 if self._scope is None: 1618 self._scope = name 1619 elif self._scope == 'recompute_': 1620 self._scope = self._scope + name 1621 1622 def _children_scope_recursive(self, parent_prefix='Default'): 1623 """Generates the scope of each layer of the network recursively.""" 1624 reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope") 1625 1626 for name, cell in self.name_cells().items(): 1627 class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else "" 1628 yield parent_prefix + "/" + name + class_name, cell 1629 1630 for name, cell in self.name_cells().items(): 1631 class_name = ("-" + cell.__class__.__name__) if reserve_class_name_in_scope else "" 1632 for key, value in cell._children_scope_recursive(parent_prefix + "/" + name + class_name): 1633 yield key, value 1634 1635 def get_scope(self): 1636 """ 1637 Returns the scope of a cell object in one network. 1638 1639 Returns: 1640 String, scope of the cell. 1641 """ 1642 return self._scope 1643 1644 def generate_scope(self): 1645 """Generate the scope for each cell object in the network.""" 1646 for name, cell in self._children_scope_recursive(): 1647 cell._set_scope(name) 1648 1649 def name_cells(self): 1650 """ 1651 Returns an iterator over all immediate cells in the network. 1652 1653 Include name of the cell and cell itself. 1654 1655 Returns: 1656 Dict, all the child cells and corresponding names in the cell. 1657 1658 Examples: 1659 >>> import mindspore as ms 1660 >>> from mindspore import Tensor, nn 1661 ... 1662 >>> class Net(nn.Cell): 1663 ... def __init__(self): 1664 ... super(Net, self).__init__() 1665 ... self.dense = nn.Dense(2, 2) 1666 ... 1667 ... def construct(self, x): 1668 ... x = self.dense(x) 1669 ... return x 1670 >>> net = Net() 1671 >>> print(net.name_cells()) 1672 OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)]) 1673 """ 1674 value_set = set() 1675 cells = OrderedDict() 1676 for name, cell in self._cells.items(): 1677 if cell is not None and cell not in value_set: 1678 value_set.add(cell) 1679 cells[name] = cell 1680 return cells 1681 1682 def _add_mixed_precision_flag(self, **flags): 1683 """Add mixed precision flag to current cell""" 1684 if "fp16" in flags and flags.get("fp16", False): 1685 Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16) 1686 if "fp32" in flags and flags.get("fp32", False): 1687 Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32) 1688 if "bf16" in flags and flags.get("bf16", False): 1689 Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16) 1690 1691 def apply(self, fn): 1692 """ 1693 Applies fn recursively to every subcell (as returned by .cells()) as well as self. 1694 Typical use includes initializing the parameters of a model. 1695 1696 Args: 1697 fn (function): function to be applied to each subcell. 1698 1699 Returns: 1700 Cell, self. 1701 1702 Examples: 1703 >>> import mindspore.nn as nn 1704 >>> from mindspore.common.initializer import initializer, One 1705 >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2)) 1706 >>> def func(cell): 1707 ... if isinstance(cell, nn.Dense): 1708 ... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype)) 1709 >>> net.apply(func) 1710 SequentialCell< 1711 (0): Dense<input_channels=2, output_channels=2, has_bias=True> 1712 (1): Dense<input_channels=2, output_channels=2, has_bias=True> 1713 > 1714 >>> print(net[0].weight.asnumpy()) 1715 [[1. 1.] 1716 [1. 1.]] 1717 """ 1718 for cell in self.cells(): 1719 cell.apply(fn) 1720 fn(self) 1721 return self 1722 1723 def add_flags(self, **flags): 1724 """ 1725 Add customized attributes for cell. 1726 1727 This method is also called when the cell class is instantiated and the class parameter 'flags' is set to True. 1728 1729 Args: 1730 flags (dict): Network configuration information, currently it is used for the binding of network and 1731 dataset. Users can also customize network attributes by this parameter. 1732 1733 Examples: 1734 >>> import mindspore as ms 1735 >>> from mindspore import Tensor, nn 1736 ... 1737 >>> class Net(nn.Cell): 1738 ... def __init__(self): 1739 ... super(Net, self).__init__() 1740 ... self.relu = nn.ReLU() 1741 ... 1742 ... def construct(self, x): 1743 ... x = self.relu(x) 1744 ... return x 1745 >>> net = Net() 1746 >>> net.add_flags(sink_mode=True) 1747 >>> print(net.sink_mode) 1748 True 1749 """ 1750 if not hasattr(self, "_func_graph_flags"): 1751 self._func_graph_flags = {} 1752 self._func_graph_flags.update({**flags}) 1753 if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"): 1754 raise TypeError("Recompute is not supported in PyNative mode currently, you can use " 1755 "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.") 1756 self.__dict__.update({**flags}) 1757 self._add_mixed_precision_flag(**flags) 1758 return self 1759 1760 def add_flags_recursive(self, **flags): 1761 """ 1762 If a cell contains child cells, this method can recursively customize attributes of all cells. 1763 1764 Args: 1765 flags (dict): Network configuration information, currently it is used for the binding of network and 1766 dataset. Users can also customize network attributes by this parameter. 1767 1768 Examples: 1769 >>> import mindspore as ms 1770 >>> from mindspore import Tensor, nn 1771 ... 1772 >>> class Net(nn.Cell): 1773 ... def __init__(self): 1774 ... super(Net, self).__init__() 1775 ... self.relu = nn.ReLU() 1776 ... 1777 ... def construct(self, x): 1778 ... x = self.relu(x) 1779 ... return x 1780 >>> net = Net() 1781 >>> net.add_flags_recursive(sink_mode=True) 1782 >>> print(net.sink_mode) 1783 True 1784 """ 1785 self.add_flags(**flags) 1786 for cell in self.cells(): 1787 cell.add_flags_recursive(**flags) 1788 return self 1789 1790 def _add_init_args(self, **args): 1791 if hasattr(self, '_cell_init_args'): 1792 self._cell_init_args += str({**args}) 1793 1794 def get_flags(self): 1795 """ 1796 Get the self_defined attributes of the cell, which can be added by `add_flags` method. 1797 1798 Examples: 1799 >>> import mindspore as ms 1800 >>> from mindspore import Tensor, nn 1801 ... 1802 >>> class Net(nn.Cell): 1803 ... def __init__(self): 1804 ... super(Net, self).__init__() 1805 ... self.relu = nn.ReLU() 1806 ... 1807 ... def construct(self, x): 1808 ... x = self.relu(x) 1809 ... return x 1810 >>> net = Net() 1811 >>> net.add_flags(sink_mode=True) 1812 >>> print(net.get_flags()) 1813 {'sink_mode':True} 1814 """ 1815 if not hasattr(self, "_func_graph_flags"): 1816 self._func_graph_flags = {} 1817 return self._func_graph_flags 1818 1819 def to_float(self, dst_type): 1820 """ 1821 Add cast on all inputs of cell and child cells to run with certain float type. 1822 1823 If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will 1824 be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`. 1825 1826 Note: 1827 Multiple calls will overwrite. 1828 1829 Args: 1830 dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type. 1831 dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`. 1832 1833 Returns: 1834 Cell, the cell itself. 1835 1836 Raises: 1837 ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`. 1838 1839 Supported Platforms: 1840 ``Ascend`` ``GPU`` ``CPU`` 1841 1842 Examples: 1843 >>> import mindspore.nn as nn 1844 >>> from mindspore import dtype as mstype 1845 >>> 1846 >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') 1847 >>> net.to_float(mstype.float16) 1848 Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, 1849 padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW> 1850 """ 1851 if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16): 1852 raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or " 1853 "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type)) 1854 flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32, 1855 'bf16': dst_type == mstype.bfloat16} 1856 self._add_init_args(**flags) 1857 self.add_flags_recursive(**flags) 1858 return self 1859 1860 def set_boost(self, boost_type): 1861 """ 1862 In order to improve the network performance, configure the network auto enable to 1863 accelerate the algorithm in the algorithm library. 1864 1865 If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through 1866 `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_. 1867 1868 Note: 1869 Some acceleration algorithms may affect the accuracy of the network, please choose carefully. 1870 1871 Args: 1872 boost_type (str): accelerate algorithm. 1873 1874 Returns: 1875 Cell, the cell itself. 1876 1877 Raises: 1878 ValueError: If boost_type is not in the algorithm library. 1879 """ 1880 if boost_type not in ("less_bn",): 1881 raise ValueError("For 'set_boost', the argument 'boost_type' must be 'less_bn', " 1882 "but got {}.".format(boost_type)) 1883 flags = {"less_bn": boost_type == "less_bn"} 1884 self.add_flags_recursive(**flags) 1885 return self 1886 1887 def set_grad(self, requires_grad=True): 1888 """ 1889 Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires 1890 gradients. If ``true`` , the backward network needed to compute the gradients will be generated when the forward 1891 network is executed. 1892 1893 Args: 1894 requires_grad (bool): Specifies if the net need to grad, if it is 1895 ``true`` , the cell will construct backward network in pynative mode. Default: ``True`` . 1896 1897 Returns: 1898 Cell, the cell itself. 1899 """ 1900 self.requires_grad = requires_grad 1901 return self 1902 1903 def set_train(self, mode=True): 1904 """ 1905 Sets the cell to training mode. 1906 1907 The cell itself and all children cells will be set to training mode. Layers that have different constructions 1908 for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If 1909 set to true, the training branch will be executed, otherwise another branch. 1910 1911 Note: 1912 When execute function Model.train(), framework will call Cell.set_train(True). 1913 When execute function Model.eval(), framework will call Cell.set_train(False). 1914 1915 Args: 1916 mode (bool): Specifies whether the model is training. Default: ``True`` . 1917 1918 Returns: 1919 Cell, the cell itself. 1920 1921 Tutorial Examples: 1922 - `Model Training - Implementing Training and Evaluation 1923 <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_ 1924 """ 1925 if mode: 1926 self._phase = 'train' 1927 else: 1928 self._phase = 'predict' 1929 self.add_flags_recursive(training=mode) 1930 return self 1931 1932 def set_broadcast_flag(self, mode=True): 1933 """ 1934 Set parameter broadcast mode for this cell. 1935 1936 Args: 1937 mode (bool): Specifies whether the mode is parameter broadcast. Default: ``True`` . 1938 """ 1939 self.add_flags_recursive(broadcast_flag=mode) 1940 return self 1941 1942 def set_auto_parallel(self): 1943 """ 1944 Set the cell to auto parallel mode. 1945 1946 Note: 1947 This interface is deprecated. 1948 """ 1949 logger.warning("'set_auto_parallel' function is deprecated.") 1950 1951 def set_jit_config(self, jit_config): 1952 """ 1953 Set jit config for cell. 1954 1955 Args: 1956 jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`. 1957 1958 Examples: 1959 >>> import mindspore as ms 1960 >>> from mindspore import Tensor, nn 1961 ... 1962 >>> class Net(nn.Cell): 1963 ... def __init__(self): 1964 ... super(Net, self).__init__() 1965 ... self.relu = nn.ReLU() 1966 ... 1967 ... def construct(self, x): 1968 ... x = self.relu(x) 1969 ... return x 1970 >>> net = Net() 1971 >>> jitconfig = ms.JitConfig() 1972 >>> net.set_jit_config(jitconfig) 1973 """ 1974 if self._jit_config_dict: 1975 logger.warning("For Cell, jit config can only be set once, ignore this setting.") 1976 else: 1977 self._jit_config_dict = jit_config.jit_config_dict 1978 1979 def flatten_weights(self, fusion_size=0): 1980 """ 1981 Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type. 1982 1983 Note: 1984 By default, parameters with same data type will using a single contiguous memory chunk. but for 1985 some models with huge number of parameters, splitting a large memory chunk into several smaller 1986 memory chunks has the potential for performance gains, if this is the case, we can use 'fusion_size' 1987 to limit the maximum memory chunk size. 1988 1989 Args: 1990 fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` . 1991 """ 1992 if fusion_size < 0: 1993 raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.") 1994 Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212 1995 1996 def register_forward_pre_hook(self, hook_fn): 1997 """ 1998 Register forward pre hook function for Cell object. 1999 2000 Note: 2001 - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'. 2002 - 'hook_fn' must be defined as the following code. 2003 `cell` is the object of registered Cell. `inputs` is the forward 2004 input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new 2005 forward input objects. 2006 - It should have the following signature: 2007 hook_fn(cell, inputs) -> new input objects or none. 2008 - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the 2009 `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is 2010 called in the `construct` function of the Cell object, a hook function will be added at each run time of 2011 Cell object. 2012 2013 Args: 2014 hook_fn (function): Python function. Forward pre hook function. 2015 2016 Returns: 2017 A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling 2018 `handle.remove()` . 2019 2020 Raises: 2021 TypeError: If the `hook_fn` is not a function of python. 2022 2023 Supported Platforms: 2024 ``Ascend`` ``GPU`` ``CPU`` 2025 2026 Examples: 2027 >>> import numpy as np 2028 >>> import mindspore as ms 2029 >>> from mindspore import Tensor, nn, ops 2030 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 2031 >>> def forward_pre_hook_fn(cell, inputs): 2032 ... print("forward inputs: ", inputs) 2033 ... 2034 >>> class Net(nn.Cell): 2035 ... def __init__(self): 2036 ... super(Net, self).__init__() 2037 ... self.mul = nn.MatMul() 2038 ... self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn) 2039 ... 2040 ... def construct(self, x, y): 2041 ... x = x + x 2042 ... x = self.mul(x, y) 2043 ... return x 2044 >>> grad = ops.GradOperation(get_all=True) 2045 >>> net = Net() 2046 >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) 2047 forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], 2048 dtype=Float32, value= [ 1.00000000e+00])) 2049 >>> print(output) 2050 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, 2051 value= [ 2.00000000e+00])) 2052 """ 2053 if not check_hook_fn("register_forward_pre_hook", hook_fn): 2054 return HookHandle() 2055 self._enable_forward_pre_hook = True 2056 _pynative_executor.set_hook_changed(self) 2057 if not hasattr(self, '_forward_pre_hook_key'): 2058 self._forward_pre_hook_key = -1 2059 self._forward_pre_hook_key += 1 2060 self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn 2061 handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook") 2062 return handle 2063 2064 def _run_forward_pre_hook(self, inputs): 2065 """ 2066 Running forward pre hook function registered on Cell object. 2067 2068 Args: 2069 inputs: The input objects of cell object. 2070 2071 Returns: 2072 - **outputs** - New input objects or none. 2073 2074 Supported Platforms: 2075 ``Ascend`` ``GPU`` ``CPU`` 2076 """ 2077 for fn in self._forward_pre_hook.values(): 2078 ret = fn(self, inputs) 2079 if ret is not None: 2080 if not isinstance(ret, tuple): 2081 inputs = (ret,) 2082 else: 2083 inputs = ret 2084 return inputs 2085 2086 def register_forward_hook(self, hook_fn): 2087 """ 2088 Set the Cell forward hook function. 2089 2090 Note: 2091 - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'. 2092 - 'hook_fn' must be defined as the following code. 2093 `cell` is the object of registered Cell. `inputs` is the forward 2094 input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can 2095 modify the forward output object by returning new forward output object. 2096 - It should have the following signature: 2097 hook_fn(cell, inputs, output) -> new output object or none. 2098 - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the 2099 `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is 2100 called in the `construct` function of the Cell object, a hook function will be added at each run time of 2101 Cell object. 2102 2103 Args: 2104 hook_fn (function): Python function. Forward hook function. 2105 2106 Returns: 2107 A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling 2108 `handle.remove()` . 2109 2110 Raises: 2111 TypeError: If the `hook_fn` is not a function of python. 2112 2113 Supported Platforms: 2114 ``Ascend`` ``GPU`` ``CPU`` 2115 2116 Examples: 2117 >>> import numpy as np 2118 >>> import mindspore as ms 2119 >>> from mindspore import Tensor, nn, ops 2120 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 2121 >>> def forward_hook_fn(cell, inputs, output): 2122 ... print("forward inputs: ", inputs) 2123 ... print("forward output: ", output) 2124 ... 2125 >>> class Net(nn.Cell): 2126 ... def __init__(self): 2127 ... super(Net, self).__init__() 2128 ... self.mul = nn.MatMul() 2129 ... self.handle = self.mul.register_forward_hook(forward_hook_fn) 2130 ... 2131 ... def construct(self, x, y): 2132 ... x = x + x 2133 ... x = self.mul(x, y) 2134 ... return x 2135 >>> grad = ops.GradOperation(get_all=True) 2136 >>> net = Net() 2137 >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) 2138 forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], 2139 dtype=Float32, value= [ 1.00000000e+00])) 2140 forward output: 2.0 2141 >>> print(output) 2142 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, 2143 value= [ 2.00000000e+00])) 2144 """ 2145 if not check_hook_fn("register_forward_hook", hook_fn): 2146 return HookHandle() 2147 self._enable_forward_hook = True 2148 _pynative_executor.set_hook_changed(self) 2149 if not hasattr(self, '_forward_hook_key'): 2150 self._forward_hook_key = -1 2151 self._forward_hook_key += 1 2152 self._forward_hook[self._forward_hook_key] = hook_fn 2153 handle = HookHandle(self, self._forward_hook_key, "_forward_hook") 2154 return handle 2155 2156 def _run_forward_hook(self, inputs, output): 2157 """ 2158 Running forward hook function registered on Cell object. 2159 2160 Args: 2161 inputs: The input objects of Cell object. 2162 output: The output object of Cell object. 2163 2164 Returns: 2165 - **output** - New output object or none. 2166 2167 Supported Platforms: 2168 ``Ascend`` ``GPU`` ``CPU`` 2169 """ 2170 for fn in self._forward_hook.values(): 2171 ret = fn(self, inputs, output) 2172 if ret is not None: 2173 output = ret 2174 return output 2175 2176 def register_backward_hook(self, hook_fn): 2177 """ 2178 Register the backward hook function. 2179 2180 Note: 2181 - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'. 2182 - The 'hook_fn' must be defined as the following code. 2183 `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the 2184 gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or 2185 primitive, which may be modified by returning a new output gradient. 2186 - The 'hook_fn' should have the following signature: 2187 hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none. 2188 - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to 2189 graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative 2190 mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object, 2191 a hook function will be added at each run time of Cell object. 2192 2193 Args: 2194 hook_fn (function): Python function. Backward hook function. 2195 2196 Returns: 2197 A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling 2198 `handle.remove()` . 2199 2200 Raises: 2201 TypeError: If the `hook_fn` is not a function of python. 2202 2203 Supported Platforms: 2204 ``Ascend`` ``GPU`` ``CPU`` 2205 2206 Examples: 2207 >>> import numpy as np 2208 >>> import mindspore as ms 2209 >>> from mindspore import Tensor, nn, ops 2210 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 2211 >>> def backward_hook_fn(cell_id, grad_input, grad_output): 2212 ... print("backward input: ", grad_input) 2213 ... print("backward output: ", grad_output) 2214 ... 2215 >>> class Net(nn.Cell): 2216 ... def __init__(self): 2217 ... super(Net, self).__init__() 2218 ... self.relu = nn.ReLU() 2219 ... self.handle = self.relu.register_backward_hook(backward_hook_fn) 2220 ... 2221 ... def construct(self, x): 2222 ... x = x + x 2223 ... x = self.relu(x) 2224 ... return x 2225 >>> grad = ops.GradOperation(get_all=True) 2226 >>> net = Net() 2227 >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32))) 2228 backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),) 2229 backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),) 2230 >>> print(output) 2231 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),) 2232 """ 2233 if not check_hook_fn("register_backward_hook", hook_fn): 2234 return HookHandle() 2235 if self._cell_backward_hook is None: 2236 self._enable_backward_hook = True 2237 self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")") 2238 backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn) 2239 handle = HookHandle(self, backward_hook_key, "_cell_backward_hook") 2240 else: 2241 backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn) 2242 handle = HookHandle(self, backward_hook_key, "_cell_backward_hook") 2243 return handle 2244 2245 def _backward_hook_construct(self, *inputs, **kwargs): 2246 """ 2247 Backward hook construct method to replace original construct method. 2248 2249 Args: 2250 inputs: The input objects of Cell object. 2251 kwargs (dict): Dictionary of variable keyword parameters. 2252 2253 Returns: 2254 - **outputs** - The output objects of Cell object. 2255 2256 Supported Platforms: 2257 ``Ascend`` ``GPU`` ``CPU`` 2258 """ 2259 if len(inputs) > 1: 2260 inputs = self._cell_backward_hook(inputs) 2261 else: 2262 inputs = self._cell_backward_hook(*inputs) 2263 inputs = (inputs,) 2264 if self.recompute_cell is not None: 2265 if isinstance(inputs, tuple): 2266 outputs = self.recompute_cell(*inputs, **kwargs) 2267 else: 2268 outputs = self.recompute_cell(inputs, **kwargs) 2269 else: 2270 if isinstance(inputs, tuple): 2271 outputs = self.construct(*inputs, **kwargs) 2272 else: 2273 outputs = self.construct(inputs, **kwargs) 2274 outputs = self._cell_backward_hook(outputs) 2275 return outputs 2276 2277 def set_param_ps(self, recurse=True, init_in_server=False): 2278 """ 2279 Set whether the trainable parameters are updated by parameter server and whether the 2280 trainable parameters are initialized on server. 2281 2282 Note: 2283 It only works when a running task is in the parameter server mode. 2284 It is only supported in graph mode. 2285 2286 Args: 2287 recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` . 2288 init_in_server (bool): Whether trainable parameters updated by parameter server are 2289 initialized on server. Default: ``False`` . 2290 """ 2291 params = self.trainable_params(recurse) 2292 for param in params: 2293 param.set_param_ps(init_in_server) 2294 2295 @deprecated("1.8", "set_param_fl") 2296 def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True): 2297 params = self.parameters_and_names() 2298 for param in params: 2299 param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr) 2300 2301 def set_comm_fusion(self, fusion_type, recurse=True): 2302 """ 2303 Set `comm_fusion` for all the parameters in this cell. Please refer to the description of 2304 :class:`mindspore.Parameter.comm_fusion`. 2305 2306 Note: 2307 The value of attribute will be overwritten when the function is called multiply. 2308 2309 Args: 2310 fusion_type (int): The value of `comm_fusion`. 2311 recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` . 2312 """ 2313 Validator.check_non_negative_int(fusion_type) 2314 for param in self.trainable_params(recurse): 2315 param.comm_fusion = fusion_type 2316 return self 2317 2318 def _set_recompute_scope(self, mode): 2319 prefix = 'recompute_' 2320 if mode: 2321 if self._scope is None: 2322 self._scope = prefix 2323 elif not self._scope.startswith(prefix): 2324 self._scope = prefix + self._scope 2325 elif self._scope is not None and self._scope.startswith(prefix): 2326 self._scope = self._scope[len(prefix):] 2327 2328 def _mp_comm_recompute(self, mp_comm_recompute=True): 2329 """ 2330 Set the model parallel communication in cell recomputed. 2331 """ 2332 for _, value in self._primitives.items(): 2333 if value: 2334 value.add_prim_attr("recompute_comm_op", mp_comm_recompute) 2335 for cell in self.cells(): 2336 cell._mp_comm_recompute(mp_comm_recompute) 2337 2338 def _parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute=False): 2339 """ 2340 Set the parallel optimizer communication in cell recomputed. 2341 """ 2342 for param in self.trainable_params(): 2343 param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute 2344 2345 def _recompute_slice_activation(self, slice_activation=False): 2346 """ 2347 Slice the cell output which would remains in memory. 2348 """ 2349 for _, value in self._primitives.items(): 2350 if value: 2351 value.add_prim_attr("slice_activation", slice_activation) 2352 for cell in self.cells(): 2353 cell._recompute_slice_activation(slice_activation) 2354 2355 def _recompute(self, mode=True, output_recompute=False): 2356 """ 2357 Set the cell recomputed. 2358 """ 2359 Validator.check_bool(mode) 2360 Validator.check_bool(output_recompute) 2361 if not self._has_config_recompute: 2362 self._has_config_recompute = True 2363 else: 2364 raise RuntimeError("The recompute interface can be configured only once." 2365 " When the parent cell is configured, the child cell should not be configured") 2366 self._set_recompute_scope(mode) 2367 if mode and not output_recompute: 2368 self.add_flags(output_no_recompute=True) 2369 for cell in self.cells(): 2370 cell._recompute(mode, True) 2371 2372 @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool) 2373 def recompute(self, **kwargs): 2374 """ 2375 Set the cell recomputed. All the primitive in the cell except the outputs will be set recomputed. 2376 If a primitive set recomputed feeds into some backward nodes for computing gradient, rather than 2377 storing the intermediate activation computed in forward pass, we will recompute it in backward pass. 2378 2379 Note: 2380 2381 - If the computation involves something like randomization or global variable, the equivalence 2382 is not guaranteed currently. 2383 - If the recompute api of a primitive in this cell is also called, the recompute mode of this 2384 primitive is subject to the recompute api of the primitive. 2385 - The interface can be configured only once. 2386 Therefore, when the parent cell is configured, the child cell should not be configured. 2387 - The outputs of cell are excluded from recomputation by default, which is based on our configuration 2388 experience to reduce memory footprint. If a cell has only one primitive and the primitive is wanted 2389 to be set recomputed, use the recompute api of the primtive. 2390 - When the memory remains after applying the recomputation, configuring 'mp_comm_recompute=False' 2391 to improve performance if necessary. 2392 - When the memory still not enough after applying the recompute, configuring 2393 'parallel_optimizer_comm_recompute=True' to save more memory if necessary. 2394 Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configures. 2395 2396 Args: 2397 mp_comm_recompute (bool): Specifies whether the model parallel communication operators 2398 in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True`` . 2399 parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers 2400 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. 2401 Default: ``False`` . 2402 """ 2403 if context.get_context("mode") == context.PYNATIVE_MODE: 2404 self.recompute_cell = recompute_registry.get()(self.construct) 2405 return 2406 self._recompute() 2407 if 'mp_comm_recompute' in kwargs.keys(): 2408 self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False)) 2409 if 'parallel_optimizer_comm_recompute' in kwargs.keys(): 2410 if (kwargs.get('parallel_optimizer_comm_recompute', False) and 2411 context.get_auto_parallel_context("pipeline_stages") > 1): 2412 logger.warning("Currently, the communication operator allgathers introduced by optimizer shard " 2413 "are not support recomputation in pipeline parallel.") 2414 elif context.get_auto_parallel_context("pipeline_stages") == 1: 2415 self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False)) 2416 if 'recompute_slice_activation' in kwargs: 2417 self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False)) 2418 2419 for key, _ in kwargs.items(): 2420 if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'): 2421 raise ValueError("For 'recompute', keyword '%s' is not recognized! " 2422 "the key kwargs must be 'mp_comm_recompute', " 2423 "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key) 2424 2425 @deprecated("2.3", "infer_param_pipeline_stage") 2426 def infer_param_pipeline_stage(self): 2427 """ 2428 Infer pipeline stages of all parameters in the cell. 2429 2430 Note: 2431 - The interface is deprecated from version 2.3 and will be removed in a future version. 2432 2433 Returns: 2434 The params belong to current stage in pipeline parallel. 2435 2436 Raises: 2437 RuntimeError: If there is a parameter does not belong to any stage. 2438 """ 2439 from mindspore.parallel._utils import _get_global_rank, _get_device_num 2440 logger.warning(f"This interface may be deleted in the future.") 2441 stage_num = context.get_auto_parallel_context("pipeline_stages") 2442 device_num = _get_device_num() 2443 rank_id = _get_global_rank() 2444 per_stage_devices = device_num // stage_num 2445 current_stage = rank_id // per_stage_devices 2446 params = [] 2447 for param in self.trainable_params(): 2448 if not param._pipeline_stage_list: # pylint: disable=W0212 2449 raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, " 2450 "please check whether the cell where the param locates has been set " 2451 "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' " 2452 "to add its stage information".format(param.name)) 2453 if current_stage in param._pipeline_stage_list: 2454 params.append(param) 2455 return params 2456 2457 def place(self, role, rank_id): 2458 """ 2459 Set the label for all operators in this cell. 2460 This label tells MindSpore compiler on which process this cell should be launched. 2461 And each process's identical label consists of input `role` and `rank_id`. 2462 So by setting different cells with different labels, which will be launched on different processes, 2463 users can launch a distributed training or predicting job. 2464 2465 Note: 2466 - This method is effective only after 2467 `mindspore.communication.init()` is called for dynamic cluster building. 2468 2469 Args: 2470 role (str): The role of the process on which this cell will be launched. 2471 Only 'MS_WORKER' is supported for now. 2472 rank_id (int): The rank id of the process on which this cell will be launched. 2473 The rank is unique in processes with the same role. 2474 2475 Examples: 2476 >>> from mindspore import context 2477 >>> import mindspore.nn as nn 2478 >>> context.set_context(mode=context.GRAPH_MODE) 2479 >>> fc = nn.Dense(2, 3) 2480 >>> fc.place('MS_WORKER', 0) 2481 """ 2482 all_ops = self._get_prims_recursively() 2483 for op in all_ops: 2484 op.place(role, rank_id) 2485 2486 def _mixed_precision_cast(self, inputs): 2487 mixed_type = self.get_mixed_precision_type() 2488 if mixed_type == MixedPrecisionType.NOTSET: 2489 return inputs 2490 if mixed_type == MixedPrecisionType.FP16: 2491 cast_type = mstype.float16 2492 elif mixed_type == MixedPrecisionType.BF16: 2493 cast_type = mstype.bfloat16 2494 else: 2495 cast_type = mstype.float32 2496 cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type) 2497 return cast_inputs 2498 2499 def _get_attr_from_cell(self, network): 2500 if not isinstance(network, Cell): 2501 return 2502 if hasattr(network, "jit_config_dict"): 2503 self._jit_config_dict = network.jit_config_dict 2504 if hasattr(network, "_amp_level"): 2505 self._amp_level = getattr(network, "_amp_level") 2506 2507 2508class GraphCell(Cell): 2509 """ 2510 Base class for running the graph loaded from a MindIR. 2511 2512 This feature is still under development. Currently `GraphCell` do not support modifying the structure of the 2513 diagram, and can only use data that shape and type are the same as the input when exporting the MindIR. 2514 2515 Args: 2516 graph (FuncGraph): A compiled graph loaded from MindIR. 2517 params_init (dict): Parameters need to be inited in the graph. 2518 The key is the parameter name whose type is str, and the value is a Tensor or Parameter. 2519 If the parameter exists in the graph according to the name, update it's value. 2520 If the parameter does not exist, ignore it. Default: ``None`` . 2521 obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is 2522 used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is 2523 a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be 2524 provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` . 2525 2526 Raises: 2527 TypeError: If the `graph` is not a FuncGraph. 2528 TypeError: If the `params_init` is not a dict. 2529 TypeError: If the key of the `params_init` is not a str. 2530 TypeError: If the value of the `params_init` is neither a Tensor nor a Parameter. 2531 2532 Supported Platforms: 2533 ``Ascend`` ``GPU`` ``CPU`` 2534 2535 Examples: 2536 >>> import numpy as np 2537 >>> import mindspore as ms 2538 >>> import mindspore.nn as nn 2539 >>> from mindspore import Tensor 2540 >>> from mindspore import context 2541 >>> context.set_context(mode=context.GRAPH_MODE) 2542 >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones") 2543 >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 2544 >>> ms.export(net, input, file_name="net", file_format="MINDIR") 2545 >>> graph = ms.load("net.mindir") 2546 >>> net = nn.GraphCell(graph) 2547 >>> output = net(input) 2548 >>> print(output) 2549 [[[[4. 6. 4.] 2550 [6. 9. 6.] 2551 [4. 6. 4.]]]] 2552 """ 2553 2554 def __init__(self, graph, params_init=None, obf_random_seed=None): 2555 super(GraphCell, self).__init__(auto_prefix=True) 2556 if not isinstance(graph, FuncGraph): 2557 raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, " 2558 f"but got type {type(graph)}.") 2559 self.graph = graph 2560 self.obf_random_seed = obf_random_seed 2561 if obf_random_seed is not None: 2562 if not isinstance(obf_random_seed, int): 2563 raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed))) 2564 int_64_max = 9223372036854775807 2565 if obf_random_seed <= 0 or obf_random_seed > int_64_max: 2566 raise ValueError( 2567 "'obf_random_seed' must be larger than 0, and less or equal than int64 ({})," 2568 "but got {}.".format(int_64_max, obf_random_seed)) 2569 self._branch_control_input = _generate_branch_control_input(self.obf_random_seed) 2570 params_init = {} if params_init is None else params_init 2571 if not isinstance(params_init, dict): 2572 raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.") 2573 for name, value in params_init.items(): 2574 if not isinstance(name, str) or not isinstance(value, Tensor): 2575 raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, " 2576 "and the value must be Tensor or Parameter, " 2577 f"but got the key type: {type(name)}, and the value type: {type(value)}") 2578 2579 params_dict = update_func_graph_hyper_params(self.graph, params_init) 2580 for name, param in params_dict.items(): 2581 self._params[name] = param 2582 _cell_graph_executor.inc_graph_cell_count() 2583 2584 def construct(self, *inputs): 2585 return self.graph(*inputs) 2586 2587 def __call__(self, *args, **kwargs): 2588 self.phase = "graph_load_from_mindir" 2589 self._add_attr("graph_load_from_mindir", self.graph) 2590 if not self.obf_random_seed: 2591 return self.compile_and_run(*args, **kwargs) 2592 append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32)) 2593 return self.compile_and_run(*args, append_input, **kwargs) 2594 2595 2596def _check_param_list_tuple(value): 2597 """ 2598 Check the type of input in list or tuple is Parameter. 2599 :param value: list or tuple. 2600 :return: The types of all inputs are parameter. 2601 """ 2602 for item in value: 2603 if not isinstance(item, Parameter): 2604 return False 2605 return True 2606