# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell"""
import gc
import inspect
import os
import time
from collections import OrderedDict

import numpy

from mindspore._checkparam import args_type_check
from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common._decorator import deprecated
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_, FuncGraph
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor
from ..common.parameter import Parameter, ParameterTuple
from ..common.tensor import Tensor
from ..ops.operations import HookBackward, Cast
from ..ops.primitive import Primitive
from ..parallel._tensor import _load_tensor_by_layout


class Cell(Cell_):
    """
    Base class for all neural networks.

    A `Cell` can be a single neural network cell, such as conv2d, relu, batch_norm, etc., or a composition of
    cells that constructs a network.

    Note:
        In general, the autograd algorithm will automatically generate the implementation of the gradient function,
        but if the back-propagation (bprop) method is implemented, the gradient function will be replaced by bprop.
        The bprop implementation will receive a tensor `dout` containing the gradient of the loss w.r.t.
        the output, and a tensor `out` containing the forward result. The bprop needs to compute the
        gradient of the loss w.r.t. the inputs; gradients of the loss w.r.t. Parameter variables are not supported
        currently. The bprop method must contain the self parameter.

    Args:
        auto_prefix (bool): Recursively generate namespaces. Default: True.
        flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
            Users can also customize network attributes by this parameter. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> class MyCell(nn.Cell):
        ...     def __init__(self):
        ...         super(MyCell, self).__init__()
        ...         self.relu = ops.ReLU()
        ...
        ...     def construct(self, x):
        ...         return self.relu(x)
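        >>> # A minimal usage sketch; `x` is assumed to be an existing float32 Tensor.
        >>> net = MyCell()
        >>> output = net(x)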
    """
    IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
                   '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
                   '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
                   '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
                   '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
                   '_auto_parallel_compile_and_run', 'cell_type']

    def __init__(self, auto_prefix=True, flags=None):
        Cell_.__init__(self, self._cell_tag)
        self._params = OrderedDict()
        self._cells = OrderedDict()
        self._params_list = OrderedDict()
        self._tensor_list = OrderedDict()
        self._primitives = OrderedDict()
        self.training = False
        self.requires_grad = False
        self.pynative = False
        self._attr_synced = False
        self._param_prefix = ''
        self._auto_prefix = auto_prefix
        self._scope = None
        self._phase = 'train'
        self._parameter_layout_dict = {}
        self._parallel_parameter_name_list = ()
        self._parallel_parameter_merge_net_dict = {}
        self._create_time = int(time.time() * 1e9)
        self.arguments_key = ""
        self.parameter_broadcast_done = False
        init_pipeline()

        # call gc to release GE session resources used by non-used cell objects
        if os.getenv('GC_COLLECT_IN_CELL') == '1':
            gc.collect()

        self._construct_inputs_num = 0
        self._construct_inputs_names = []
        self._auto_parallel_mode = False
        self._parallel_inputs_run = None
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self.enable_hook = False
        self._bprop_debug = False
        self.cell_type = None
        self._auto_parallel_compile_and_run = False
        self.cast = Cast()
        self._has_config_recompute = False

    def __getstate__(self):
        base = Cell_.__getstate__(self)
        return base, self.__dict__

    def __setstate__(self, state):
        base, dict_ = state
        Cell_.__setstate__(self, base)
        self.__dict__ = dict_
        self._attr_synced = False

    @property
    def _cell_tag(self):
        # `<class 'xxxxxxx'>` to `xxxxxxx`
        return str(self.__class__)[8:-2]

    @property
    def create_time(self):
        return self._create_time

    @property
    def cell_init_args(self):
        return self._cell_init_args

    @property
    def param_prefix(self):
        """
        Param prefix is the prefix of current cell's direct child parameter.
        """
        return self._param_prefix

    @property
    def bprop_debug(self):
        """
        Get whether cell custom bprop debug is enabled.
        """
        return self._bprop_debug

    @bprop_debug.setter
    def bprop_debug(self, value):
        """
        Set whether to enable cell custom bprop debug.

        Note:
            When bprop is defined in the cell, the bprop function will be executed
            by the Python interpreter when bprop debug is true, and will be parsed
            and added to the graph when bprop debug is false.

        Args:
            value (bool): Specifies whether to enable bprop debug. Default: False.
        """
        if not isinstance(value, bool):
            raise TypeError("The 'bprop_debug' value must be a bool type.")
        self._bprop_debug = value

    def update_cell_prefix(self):
        """
        Update the param_prefix of all child cells.

        After being invoked, each child cell's name prefix can be obtained by '_param_prefix'.
        """
        cells_name = self.cells_and_names()

        for cell_name, cell in cells_name:
            cell._param_prefix = cell_name

    def update_cell_type(self, cell_type):
        """
        The current cell type is updated when a quantization aware training network is encountered.

        After being invoked, it can set the cell type to 'cell_type'.
        """
        self.cell_type = cell_type

    @cell_init_args.setter
    def cell_init_args(self, value):
        if not isinstance(value, str):
            raise TypeError("The 'cell_init_args' must be a string type.")
        self._cell_init_args = value

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, value):
        if not isinstance(value, str):
            raise TypeError("The 'phase' must be a string type.")
        self._phase = value

    @property
    def parameter_layout_dict(self):
        """
        `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
        distributed operator information.
        """
        return self._parameter_layout_dict

    @property
    def cls_name(self):
        return self.__class__.__name__

    @parameter_layout_dict.setter
    def parameter_layout_dict(self, value):
        if not isinstance(value, dict):
            raise TypeError("The 'parameter_layout_dict' must be a dict type.")
        self._parameter_layout_dict = value

    @property
    def parallel_parameter_name_list(self):
        return self._parallel_parameter_name_list

    @parallel_parameter_name_list.setter
    def parallel_parameter_name_list(self, value):
        if not isinstance(value, list):
            raise TypeError("The 'parallel_parameter_name_list' must be a list type.")
        self._parallel_parameter_name_list = value

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        if isinstance(value, bool):
            raise TypeError("'pipeline_stage' must be int type, but got bool.")
        if not isinstance(value, int):
            raise TypeError("'pipeline_stage' must be int type.")
        if value < 0:
            raise TypeError("'pipeline_stage' can not be less than 0.")
        self._pipeline_stage = value
        for item in self.trainable_params():
            item.add_pipeline_stage(value)

    @property
    def parallel_parameter_merge_net_dict(self):
        return self._parallel_parameter_merge_net_dict

    @parallel_parameter_merge_net_dict.setter
    def parallel_parameter_merge_net_dict(self, value):
        if not isinstance(value, dict):
            raise TypeError("The 'parallel_parameter_merge_net_dict' must be a dict type.")
        self._parallel_parameter_merge_net_dict = value

    def get_func_graph_proto(self):
        """Return graph binary proto."""
        exec_id = self.phase + "." + str(self.create_time) + '.' + str(id(self))
        return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)

    def __getattr__(self, name):
        if '_params' in self.__dict__:
            params = self.__dict__['_params']
            if name in params:
                if context.get_context("mode") == context.PYNATIVE_MODE:
                    return self.cast_param(params[name])
                return params[name]
        if '_cells' in self.__dict__:
            cells = self.__dict__['_cells']
            if name in cells:
                return cells[name]
        if '_tensor_list' in self.__dict__:
            tensor_list = self.__dict__['_tensor_list']
            if name in tensor_list:
                return self.cast_param(tensor_list[name])
        if '_params_list' in self.__dict__:
            params_list = self.__dict__['_params_list']
            if name in params_list:
                para_list = params_list[name]
                cast_list = list()
                for para in para_list:
                    cast_list.append(self.cast_param(para))
                para_list = ParameterTuple(cast_list)
                return para_list
        raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))

    def __del__(self):
        if context.get_context is not None and context.get_context("mode") == context.PYNATIVE_MODE:
            _pynative_executor.del_cell(str(id(self)))
        if hasattr(self, "_create_time"):
            _cell_graph_executor.del_net_res(str(self._create_time))

    def __delattr__(self, name):
        if name in self._params:
            del self._params[name]
        elif name in self._cells:
            del self._cells[name]
        else:
            if '_params_list' in self.__dict__ and name in self._params_list:
                del self._params_list[name]
            elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
                del self._tensor_list[name]
            object.__delattr__(self, name)
        self._attr_synced = False

    def _cast_mixed_precision_inputs(self, inputs, dst_type):
        """Cast input for mixed precision"""
        res = list()
        for item in inputs:
            if isinstance(item, tuple):
                res.append(self._cast_mixed_precision_inputs(item, dst_type))
            elif isinstance(item, float):
                res.append(self.cast(item, dst_type))
            elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
                res.append(self.cast(item, dst_type))
            else:
                res.append(item)
        return tuple(res)

    def cast_inputs(self, inputs, dst_type):
        """
        Cast inputs to specified type.
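
        Args:
            inputs (tuple): The cell inputs to be cast.
            dst_type (mindspore.dtype): The destination data type.

        Returns:
            tuple, the inputs after casting.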
        """
        res = list()
        for item in inputs:
            if isinstance(item, tuple):
                res.append(self.cast_inputs(item, dst_type))
            else:
                res.append(self.cast(item, dst_type))
        return tuple(res)

    def _do_parameter_broadcast(self):
        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_executor.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True

    def run_construct(self, cast_inputs, kwargs):
        if self.enable_hook:
            output = self._hook_construct(*cast_inputs)
        else:
            output = self.construct(*cast_inputs, **kwargs)
        return output

    def _check_construct_args(self, *inputs, **kwargs):
        """Check the args needed by the function construct"""
        if kwargs:
            raise ValueError("For 'graph' mode, the outermost network does not support passing "
                             "variable key-value pair parameters.")
        positional_args = 0
        default_args = 0
        for value in inspect.signature(self.construct).parameters.values():
            if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
                return
            if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                if value.default is inspect.Parameter.empty:
                    positional_args += 1
                else:
                    default_args += 1

        if len(inputs) < positional_args:
            raise TypeError(
                f"The function construct needs {positional_args} positional argument(s), but only {len(inputs)} "
                f"provided.")

        if len(inputs) > positional_args + default_args:
            raise TypeError(
                f"The function construct needs {positional_args} positional argument(s) and {default_args} default "
                f"argument(s), but {len(inputs)} provided.")

    class CellGuard:
        def __enter__(self):
            _pynative_executor.set_lazy_build(True)
            _pynative_executor.enter_cell()

        def __exit__(self, exc_type, exc_val, exc_tb):
            _pynative_executor.exit_cell()
            if _pynative_executor.is_top_cell():
                _pynative_executor.set_lazy_build(False)

    def __call__(self, *inputs, **kwargs):
        if self.__class__.construct is Cell.construct:
            logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
                           f"will call the super class(Cell) 'construct'.")
        if kwargs:
            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
            inputs = bound_args.args
            kwargs = bound_args.kwargs

        # Run in Graph mode.
        if context.get_context("mode") == context.GRAPH_MODE:
            self._check_construct_args(*inputs, **kwargs)
            if self.enable_hook:
                raise ValueError("The graph mode does not support hook function.")
            out = self.compile_and_run(*inputs)
            return out

        # Run in PyNative mode.
        if _pynative_executor.is_top_cell():
            _pynative_executor.set_lazy_build(True)
            # There are many Casts in parameter_broadcast. Enable lazy_build to build faster.
            self._do_parameter_broadcast()

        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("The cell inputs should not be numpy arrays.")
        if self.requires_grad is True:
            _pynative_executor.set_grad_flag(True)
        _pynative_executor.new_graph(self, *inputs, **kwargs)
        cast_inputs = list()
        if hasattr(self, "_mindspore_flags"):
            if self._mindspore_flags.get('fp16'):
                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
            if self._mindspore_flags.get('fp32'):
                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
        if not cast_inputs:
            cast_inputs = inputs

        with self.CellGuard():
            try:
                output = self.run_construct(cast_inputs, kwargs)
            except Exception as err:
                _pynative_executor.clear_res()
                raise err

        if _pynative_executor.is_top_cell():
            _pynative_executor.execute_all_task()

        if isinstance(output, Parameter):
            output = output.data
        _pynative_executor.end_graph(self, output, *inputs, **kwargs)
        return output

    def _add_attr(self, name, value):
        if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
            super(Cell, self)._add_attr(name, value)

    def _sync_attr_for_compile(self):
        """Sync the attr to c++ object."""
        if self._attr_synced:
            return
        cells = self.__dict__.get('_cells')
        for key in cells:
            cell = cells[key]
            cell._sync_attr_for_compile()
            self._add_attr(key, cell)
        params = self.__dict__.get('_params')
        for key in params:
            if '.' in key:
                continue
            param = params[key]
            self._add_attr(key, param)
        params_list = self.__dict__.get('_params_list')
        for key in params_list:
            params_list_item = params_list[key]
            self._add_attr(key, params_list_item)
        for key in self.__dict__:
            value = self.__dict__[key]
            self._add_attr(key, value)
        self._attr_synced = True

    def _set_attr_for_parameter(self, name, value):
        """Set attr for parameter."""
        cells = self.__dict__.get('_cells')
        params = self.__dict__.get('_params')
        if params is None:
            raise AttributeError("Can not assign params before Cell.__init__() call.")
        if name in self.__dict__:
            if self.__dict__[name] is not None:
                raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
            del self.__dict__[name]
        if cells and name in cells:
            raise TypeError("The type of value should be Cell, but got Parameter.")
        self.insert_param_to_cell(name, value)

    def _set_attr_for_parameter_tuple(self, name, value):
        """Set attr for parameter tuple."""
        params = self.__dict__.get('_params')
        params_list = self.__dict__.get('_params_list')
        if params is None:
            raise AttributeError("Can not assign params before Cell.__init__() call.")
        for item in value:
            self.insert_param_to_cell(item.name, item, check_name=False)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            if name in self.__dict__:
                del self.__dict__[name]
            if name in params:
                del params[name]
            params_list[name] = value
        else:
            object.__setattr__(self, name, value)

    def _set_attr_for_cell(self, name, value):
        """Set attr for cell."""
        cells = self.__dict__.get('_cells')
        params = self.__dict__.get('_params')
        if cells is None:
            raise AttributeError("Can not assign cells before Cell.__init__() call.")
        if name in self.__dict__:
            del self.__dict__[name]
        if params and name in params:
            raise TypeError("The type of value should be Parameter, but got Cell.")
        if self._auto_prefix:
            value.update_parameters_name(name + '.')
        cells[name] = value
        if hasattr(self, '_cell_init_args'):
            self.cell_init_args += str({name: value})

    def __setattr__(self, name, value):
        cells = self.__dict__.get('_cells')
        params = self.__dict__.get('_params')
        tensor_list = self.__dict__.get('_tensor_list')
        if isinstance(value, Parameter):
            self._set_attr_for_parameter(name, value)
        elif isinstance(value, ParameterTuple):
            self._set_attr_for_parameter_tuple(name, value)
        elif isinstance(value, Cell):
            self._set_attr_for_cell(name, value)
        elif params and name in params:
            if isinstance(value, Tensor) and self._params[name] is not None:
                self._params[name].set_data(value)
            elif value is not None:
                raise TypeError(f"The type of value should be Parameter or ParameterTuple, "
                                f"but got {type(value).__name__}.")
            else:
                self.insert_param_to_cell(name, None)
        elif cells and name in cells:
            if value is not None:
                raise TypeError(f"The type of value should be Cell, but got {type(value).__name__}.")
            self._cells[name] = None
        elif isinstance(value, Tensor):
            if context.get_context("mode") == context.PYNATIVE_MODE:
                if name in self.__dict__:
                    del self.__dict__[name]
                tensor_list[name] = value
            else:
                object.__setattr__(self, name, value)
        else:
            if isinstance(value, Primitive):
                value.set_prim_instance_name(name)
                self._primitives[name] = value
            object.__setattr__(self, name, value)
        if name not in Cell.IGNORE_LIST:
            self._attr_synced = False

    def extend_repr(self):
        """
        Sets the extended representation of the Cell.

        To print customized extended information, re-implement this method in your own cells.
        """
        return ''

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        extra_str = self.extend_repr()
        info_str = self.__class__.__name__ + '<'
        if self._cells:
            sub_str = '\n'
            if extra_str:
                sub_str += '{}\n'.format(extra_str)
            for key, value in self._cells.items():
                sub_str += '({}): {}\n'.format(key, repr(value))
            sub_str = sub_str.replace('\n', '\n  ') + '>'
            info_str += sub_str
        else:
            info_str += extra_str + '>'
        return info_str

    def load_parameter_slice(self, params):
        """
        Replace parameters with sliced tensors by parallel strategies.

        Please refer to the usage in source code of `mindspore.common._CellGraphExecutor.compile`.

        Args:
            params (dict): The parameters dictionary used for initializing the data graph.
        """
        if params is None:
            params = self.parameters_dict()
        if isinstance(params, OrderedDict):
            for key in params:
                tensor = params[key].data
                if key not in self.parameter_layout_dict:
                    logger.info("The layout dict does not contain the key %s.", key)
                    continue
                if params[key].sliced:
                    logger.debug("The param %s is already sliced.", key)
                    continue
                layout = self.parameter_layout_dict[key]
                new_tensor = _load_tensor_by_layout(tensor, layout)
                params[key].set_data(new_tensor, True)
        else:
            raise TypeError("Parameters need OrderedDict type, but got {}.".format(type(params)))

    def _load_inputs(self, *inputs):
        """
        Slice input tensors by parallel strategies.

        Args:
            inputs (Function or Cell): inputs of construct method.
        """
        parallel_inputs_run = []
        # judge if *args exists in input
        if self.argspec[1] is not None:
            prefix = self.argspec[1]
            for i in range(len(inputs)):
                key = prefix + str(i)
                self._construct_inputs_names = self._construct_inputs_names + (key,)
                self._construct_inputs_num = self._construct_inputs_num + 1
        for i, tensor in enumerate(inputs):
            key = self._construct_inputs_names[i]
            # if input is not used, self.parameter_layout_dict may not contain the key
            if key not in self.parameter_layout_dict:
                logger.warning("Layout dict does not contain the key %s.", key)
                parallel_inputs_run.append(tensor)
            else:
                layout = self.parameter_layout_dict[key]
                new_tensor = _load_tensor_by_layout(tensor, layout)
                parallel_inputs_run.append(new_tensor)
        return tuple(parallel_inputs_run)

    def set_parallel_input_with_inputs(self, *inputs):
        """
        Slice input tensors by parallel strategies, and set the sliced inputs to `_parallel_inputs_run`.

        Args:
            inputs (tuple): inputs of construct method.
        """
        self._parallel_inputs_run = self._load_inputs(*inputs)

    def _get_construct_inputs_number_and_name(self):
        """Compute self._construct_inputs_names and self._construct_inputs_num"""
        from mindspore._extends.parse.parser import get_parse_method_of_class

        fn = get_parse_method_of_class(self)
        self.argspec = inspect.getfullargspec(fn)
        self._construct_inputs_num = fn.__code__.co_argcount
        self._construct_inputs_names = fn.__code__.co_varnames

        if self._construct_inputs_num <= 0:
            raise ValueError(f"Num of inputs must be greater than 0, but got {self._construct_inputs_num}")
        if self._construct_inputs_names[0] != 'self':
            raise ValueError(f"First member of fn function must be self, but got {self._construct_inputs_names[0]}")
        if self._construct_inputs_num - 1 > len(self._construct_inputs_names):
            raise ValueError(f"Num of inputs must be greater than num of fn function members, num of inputs is "
                             f"{self._construct_inputs_num - 1}, num of fn function members is "
                             f"{len(self._construct_inputs_names)}")
        self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
        self._construct_inputs_num = self._construct_inputs_num - 1

    def compile(self, *inputs):
        """
        Compiles cell.

        Args:
            inputs (tuple): Inputs of the Cell object.
        """
        _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)

    def compile_and_run(self, *inputs):
        """
        Compiles and runs cell.

        Args:
            inputs (tuple): Inputs of the Cell object.

        Returns:
            Object, the result of executing.
        """
        self._auto_parallel_compile_and_run = True
        self.compile(*inputs)

        new_inputs = []
        for i in inputs:
            if isinstance(i, Tensor):
                new_inputs.append(i)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)

        if self._auto_parallel_mode:
            if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag:
                # get parallel inputs in sink mode, parallel inputs set in _cell_graph_executor.compile
                parallel_inputs_run = self._parallel_inputs_run
            else:
                parallel_inputs_run = new_inputs
            return _cell_graph_executor(self, *parallel_inputs_run, phase=self.phase)
        return _cell_graph_executor(self, *new_inputs, phase=self.phase)

    def auto_parallel_compile_and_run(self):
        """
        Whether to execute compile and run.

        Returns:
            bool, `_auto_parallel_compile_and_run` value.
        """
        return self._auto_parallel_compile_and_run

    def exec_checkpoint_graph(self):
        """Executes the saving checkpoint graph operation."""
        _cell_graph_executor(self, phase='save')

    def insert_param_to_cell(self, param_name, param, check_name=True):
        """
        Adds a parameter to the current cell.

        Inserts a parameter with the given name into the cell. Please refer to the usage in
        source code of `mindspore.nn.Cell.__setattr__`.

        Args:
            param_name (str): Name of the parameter.
            param (Parameter): Parameter to be inserted into the cell.
            check_name (bool): Determines whether the name input is compatible. Default: True.

        Raises:
            KeyError: If the name of the parameter is null or contains a dot.
            AttributeError: If the user did not call init() first.
            TypeError: If the type of the parameter is not Parameter.
        """
        if not param_name:
            raise KeyError("The name of the parameter should not be null.")
        if check_name and '.' in param_name:
            raise KeyError("The name of the parameter should not contain \".\"")
        if '_params' not in self.__dict__:
            raise AttributeError("You need to call init() first.")
        if hasattr(self, param_name) and param_name not in self._params:
            raise KeyError("Duplicated parameter name '{}'.".format(param_name))
        if not isinstance(param, Parameter) and param is not None:
            raise TypeError("The type of parameter should be 'Parameter' if not None.")
        if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
            param.name = param_name
        self._params[param_name] = param

    def cast_param(self, param):
        """
        Cast parameter according to auto mix precision level in pynative mode.

        This interface is currently used in the case of auto mix precision and usually does not need to be used
        explicitly.

        Args:
            param (Parameter): Parameter whose type should be cast.

        Returns:
            Parameter, the input parameter with type automatically cast.
        """
        if hasattr(self, "_mindspore_flags"):
            if self._mindspore_flags.get('fp32'):
                param.set_cast_dtype(mstype.float32)
            elif self._mindspore_flags.get('fp16'):
                param.set_cast_dtype(mstype.float16)
            elif hasattr(param, "set_cast_dtype"):
                # retest dtype
                param.set_cast_dtype()
        return param

    def insert_child_to_cell(self, child_name, child_cell):
        """
        Adds a child cell to the current cell with a given name.

        Args:
            child_name (str): Name of the child cell.
            child_cell (Cell): The child cell to be inserted.

        Raises:
            KeyError: Child cell's name is incorrect or duplicated with another child name.
            TypeError: Child cell's type is incorrect.
        """
        if not child_name or '.' in child_name:
            raise KeyError("Child cell name is incorrect.")
        if hasattr(self, child_name) and child_name not in self._cells:
            raise KeyError("Duplicate child name '{}'.".format(child_name))
        if not isinstance(child_cell, Cell) and child_cell is not None:
            raise TypeError("Child cell type is incorrect.")
        self._cells[child_name] = child_cell

    def construct(self, *inputs, **kwargs):
        """
        Defines the computation to be performed. This method must be overridden by all subclasses.

        Returns:
            Tensor, returns the computed result.
        """
        return None

    def remove_redundant_parameters(self):
        """
        Remove the redundant parameters.

        This interface usually does not need to be used explicitly.
        """
        cells = self.cells_and_names()
        for _, cell in cells:
            params = cell._params.items()
            for param_name, param in list(params):
                if param.name not in self.parallel_parameter_name_list:
                    cell._params.pop(param_name)
                    logger.info("remove the redundant parameter: %s", param.name)
                    continue
            cell_dict = cell.__dict__
            for key in cell_dict:
                if isinstance(cell_dict[key], ParameterTuple):
                    param_tuple = cell_dict[key]
                    new_param_tuple = []
                    for param in param_tuple:
                        if param.name not in self.parallel_parameter_name_list:
                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
                            continue
                        new_param_tuple.append(param)
                    cell.__dict__[key] = ParameterTuple(new_param_tuple)

    def init_parameters_data(self, auto_parallel_mode=False):
        """
        Initialize all parameters and replace the original saved parameters in cell.

        Note:
            trainable_params() and other similar interfaces may return different parameter instances after
            `init_parameters_data`; do not save these results.

        Args:
            auto_parallel_mode (bool): If running in auto_parallel_mode.

        Returns:
            Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
        """
        replace = dict()

        def _update(param):
            if param in replace:
                return replace[param]
            layout = None
            set_sliced = False
            if auto_parallel_mode:
                set_sliced = True
                if param.name not in self.parameter_layout_dict:
                    logger.debug("Layout dict does not contain the key %s.", param.name)
                else:
                    layout = self.parameter_layout_dict[param.name]
            new_p = param.init_data(layout, set_sliced=set_sliced)
            replace[param] = new_p
            return new_p

        # replace all original usage.
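        # Walk every child cell: initialize each Parameter's data (slicing it by the tensor layout in
        # auto-parallel mode) and swap the initialized instance into `_params` and any ParameterTuple attribute.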
        cells = self.cells_and_names()
        for _, cell in cells:
            params = cell._params.items()
            for param_name, param in params:
                if not auto_parallel_mode:
                    cell._params[param_name] = _update(param)
                    continue
                if param.name in self.parallel_parameter_name_list:
                    cell._params[param_name] = _update(param)
            cell_dict = cell.__dict__
            for key in cell_dict:
                if isinstance(cell_dict[key], ParameterTuple):
                    param_tuple = cell_dict[key]
                    new_param_tuple = []
                    for param in param_tuple:
                        if not auto_parallel_mode:
                            new_param_tuple.append(_update(param))
                            continue
                        if param.name in self.parallel_parameter_name_list:
                            new_param_tuple.append(_update(param))
                        else:
                            new_param_tuple.append(param)
                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
        return replace

    def parameters_dict(self, recurse=True):
        """
        Gets the parameters dictionary.

        Gets the parameters dictionary of this cell.

        Args:
            recurse (bool): Whether to contain the parameters of subcells. Default: True.

        Returns:
            OrderedDict, returns the parameters dictionary.
        """
        param_dict = OrderedDict()
        for param in self.get_parameters(expand=recurse):
            param_dict[param.name] = param
        return param_dict

    def parameters_broadcast_dict(self, recurse=True):
        """
        Gets the parameters broadcast dictionary of this cell.

        Args:
            recurse (bool): Whether to contain the parameters of subcells. Default: True.

        Returns:
            OrderedDict, returns the parameters broadcast dictionary.
        """
        param_dict = OrderedDict()
        for param in self.get_parameters(expand=recurse):
            if param.layerwise_parallel is False:
                param_dict[param.name] = param
        if not param_dict:
            return None
        return param_dict

    def update_parameters_name(self, prefix='', recurse=True):
        """
        Updates the names of parameters with the given prefix string.

        Adds the given prefix to the names of parameters.

        Args:
            prefix (str): The prefix string. Default: ''.
            recurse (bool): Whether to contain the parameters of subcells. Default: True.
        """

        Validator.check_str_by_regular(prefix)
        for name, param in self.parameters_and_names(expand=recurse):
            if prefix != '':
                param.is_init = False
            param.name = prefix + name

    def trainable_params(self, recurse=True):
        """
        Returns all trainable parameters.

        Returns a list of all trainable parameters.

        Args:
            recurse (bool): Whether to contain the trainable parameters of subcells. Default: True.

        Returns:
            List, the list of trainable parameters.
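
        Examples:
            A minimal sketch; `Net` stands for any user-defined cell with parameters.

            >>> net = Net()
            >>> params = net.trainable_params()
            >>> print(type(params))
            <class 'list'>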
        """
        return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

    def untrainable_params(self, recurse=True):
        """
        Returns all untrainable parameters.

        Returns a list of all untrainable parameters.

        Args:
            recurse (bool): Whether to contain the untrainable parameters of subcells. Default: True.

        Returns:
            List, the list of untrainable parameters.
        """
        return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))

    def get_parameters(self, expand=True):
        """
        Returns an iterator over cell parameters.

        Yields parameters of this cell. If `expand` is true, yields parameters of this cell and all subcells.

        Args:
            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yields
                parameters that are direct members of this cell. Default: True.

        Returns:
            Iteration, all parameters of the cell.

        Examples:
            >>> net = Net()
            >>> parameters = []
            >>> for item in net.get_parameters():
            ...     parameters.append(item)
        """
        for _, param in self.parameters_and_names(expand=expand):
            yield param

    def check_names(self):
        """
        Check the names of cell parameters.
        """
        names = set("")
        for value, param in self.parameters_and_names():
            if param.name in names:
                raise ValueError("The value of {} is {}, its name '{}' already exists.".
                                 format(value, param, param.name))
            names.add(param.name)

    def parameters_and_names(self, name_prefix='', expand=True):
        """
        Returns an iterator over cell parameters.

        Includes the parameter's name and itself.

        Args:
            name_prefix (str): Namespace. Default: ''.
            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yields
                parameters that are direct members of this cell. Default: True.

        Returns:
            Iteration, all the names and corresponding parameters in the cell.

        Examples:
            >>> n = Net()
            >>> names = []
            >>> for m in n.parameters_and_names():
            ...     if m[0]:
            ...         names.append(m[0])
        """
        cells = []
        if expand:
            cells = self.cells_and_names(name_prefix=name_prefix)
        else:
            cells.append((name_prefix, self))

        params_set = set()
        for cell_name, cell in cells:
            params = cell._params.items()
            for par_name, par in params:
                if par.inited_param is not None:
                    par = par.inited_param
                if par is not None and id(par) not in params_set:
                    params_set.add(id(par))
                    par_new_name = par_name
                    if cell_name:
                        par_new_name = cell_name + '.' + par_new_name

                    yield par_new_name, par

    def cells_and_names(self, cells=None, name_prefix=''):
        """
        Returns an iterator over all cells in the network.

        Includes the cell's name and itself.

        Args:
            cells (set): Cells that have already been iterated over, used internally to avoid duplicates.
                Default: None.
            name_prefix (str): Namespace. Default: ''.

        Returns:
            Iteration, all the child cells and corresponding names in the cell.

        Examples:
            >>> n = Net()
            >>> names = []
            >>> for m in n.cells_and_names():
            ...     if m[0]:
            ...         names.append(m[0])
        """
        t_cells = cells if cells else set()
        if self in t_cells:
            return

        t_cells.add(self)
        yield name_prefix, self

        for name, cell in self._cells.items():
            if cell:
                cells_name_prefix = name
                if name_prefix:
                    cells_name_prefix = name_prefix + '.' + cells_name_prefix
                for ele in cell.cells_and_names(t_cells, cells_name_prefix):
                    yield ele

    def cells(self):
        """
        Returns an iterator over immediate cells.

        Returns:
            Iteration, all the child cells in the cell.
        """
        return self.name_cells().values()

    def _set_scope(self, name):
        """Sets the name on the first time."""
        if self._scope is None:
            self._scope = name
        elif self._scope == 'recompute_':
            self._scope = self._scope + name

    def _children_scope_recursive(self, parent_prefix='Default'):
        """Generates the scope of each layer of the network recursively."""
        reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope")

        for name, cell in self.name_cells().items():
            yield parent_prefix + "/" + name + (("-" + cell.__class__.__name__)
                                                if reserve_class_name_in_scope else ""), cell

        for name, cell in self.name_cells().items():
            for key, value in cell._children_scope_recursive(parent_prefix + "/" + name +
                                                             (("-" + cell.__class__.__name__)
                                                              if reserve_class_name_in_scope else "")):
                yield key, value

    def get_scope(self):
        """
        Returns the scope of a cell object in one network.

        Returns:
            String, scope of the cell.
        """
        return self._scope

    def generate_scope(self):
        """Generate the scope for each cell object in the network."""
        for name, cell in self._children_scope_recursive():
            cell._set_scope(name)

    def name_cells(self):
        """
        Returns an iterator over all cells in the network.

        Includes the name of the cell and the cell itself.

        Returns:
            Dict[String, Cell], all the child cells and corresponding names in the cell.
        """
        value_set = set()
        cells = OrderedDict()
        for name, cell in self._cells.items():
            if cell is not None and cell not in value_set:
                value_set.add(cell)
                cells[name] = cell
        return cells

    def add_flags(self, **flags):
        """
        Add customized attributes for the cell.

        This method is also called when the cell class is instantiated and the class parameter 'flags' is set.
        """
        if not hasattr(self, "_mindspore_flags"):
            self._mindspore_flags = {}
        self._mindspore_flags.update({**flags})
        self.__dict__.update({**flags})
        return self

    def add_flags_recursive(self, **flags):
        """
        If a cell contains child cells, this method can recursively customize attributes of all cells.
        """
        self.add_flags(**flags)
        for cell in self.cells():
            cell.add_flags_recursive(**flags)
        return self

    def _add_init_args(self, **args):
        if hasattr(self, '_cell_init_args'):
            self._cell_init_args += str({**args})

    def get_flags(self):
        """
        Get the attributes of the cell's flags.
        """
        if not hasattr(self, "_mindspore_flags"):
            self._mindspore_flags = {}
        return self._mindspore_flags

    def to_float(self, dst_type):
        """
        Add cast on all inputs of the cell and child cells to run with a certain float type.

        If `dst_type` is `mindspore.dtype.float16`, all the inputs of the Cell, including input, Parameter and
        constant Tensor, will be cast to float16. Please refer to the usage in source code of
        `mindspore.train.amp.build_train_network`.

        Note:
            Multiple calls will overwrite the previous setting.

        Args:
            dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
                dst_type can be `mindspore.dtype.float16` or `mindspore.dtype.float32`.

        Returns:
            Cell, the cell itself.

        Raises:
            ValueError: If dst_type is not float32 or float16.
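
        Examples:
            A minimal sketch; `nn.Dense(3, 4)` is only an illustrative child cell.

            >>> import mindspore.nn as nn
            >>> from mindspore import dtype as mstype
            >>> net = nn.Dense(3, 4)
            >>> net = net.to_float(mstype.float16)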
        """
        if dst_type not in (mstype.float16, mstype.float32):
            raise ValueError("The dst_type should be float16 or float32.")
        flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
        self.add_flags_recursive(**flags)
        self._add_init_args(**flags)
        return self

    def set_boost(self, boost_type):
        """
        To improve the network performance, the network can be configured to automatically enable the
        acceleration algorithms in the algorithm library.

        If `boost_type` is not in the algorithm library, please check the algorithms available in the
        algorithm library.

        Note:
            Some acceleration algorithms may affect the accuracy of the network, please choose carefully.

        Args:
            boost_type (str): The acceleration algorithm.

        Returns:
            Cell, the cell itself.

        Raises:
            ValueError: If boost_type is not in the algorithm library.
        """
        if boost_type not in ("less_bn",):
            raise ValueError("The boost_type is not in the algorithm library.")
        flags = {"less_bn": boost_type == "less_bn"}
        self.add_flags_recursive(**flags)
        return self

    def set_grad(self, requires_grad=True):
        """
        Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
        gradients. If true, the backward network needed to compute the gradients will be generated when the forward
        network is executed.

        Args:
            requires_grad (bool): Specifies whether the net needs gradients. If true,
                the cell will construct the backward network in pynative mode. Default: True.

        Returns:
            Cell, the cell itself.
        """
        self.requires_grad = requires_grad
        return self

    def set_train(self, mode=True):
        """
        Sets the cell to training mode.

        The cell itself and all children cells will be set to training mode. Layers that have different constructions
        for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
        set to true, the training branch will be executed, otherwise the other branch.

        Args:
            mode (bool): Specifies whether the model is training. Default: True.

        Returns:
            Cell, the cell itself.
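
        Examples:
            A minimal sketch; `Net` stands for any user-defined cell.

            >>> net = Net()
            >>> net = net.set_train()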
        """
        if mode is False:
            self._phase = 'predict'
        else:
            self._phase = 'train'
        self.add_flags_recursive(training=mode)
        return self

    def set_broadcast_flag(self, mode=True):
        """
        Set the cell to data_parallel mode.

        Args:
            mode (bool): Specifies whether the model is data_parallel. Default: True.
        """
        self.add_flags_recursive(broadcast_flag=mode)
        return self

    def set_auto_parallel(self):
        """
        Set the cell to auto parallel mode.

        Note:
            If a cell needs to use the auto parallel or semi auto parallel mode for training, evaluation or
            prediction, this interface needs to be called by the cell.
        """
        self._auto_parallel_mode = True
        self.add_flags(auto_parallel=True)
        self._get_construct_inputs_number_and_name()

    def _hook_construct(self, *inputs):
        """Hook construct method to replace original construct method when hook function enabled."""
        inputs = self._backward_hook(*inputs)
        inputs = self.construct(inputs)
        outputs = self._backward_hook(inputs)
        return outputs

    def register_backward_hook(self, fn):
        """
        Set the cell backward hook function. Note that this function is only supported in pynative mode.

        Note:
            fn must be defined as the following code. `cell_name` is the name of the registered cell.
            `grad_input` is the gradient passed to the cell. `grad_output` is the gradient computed and passed to the
            next cell or primitive, which may be modified and returned.
            hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.

        Args:
            fn (function): Specifies the hook function with grad as input.

        """
        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
        self.enable_hook = True

    def set_param_ps(self, recurse=True, init_in_server=False):
        """
        Set whether the trainable parameters are updated by parameter server and whether the
        trainable parameters are initialized on server.

        Note:
            It only works when a running task is in the parameter server mode.

        Args:
            recurse (bool): Whether to set the trainable parameters of subcells. Default: True.
            init_in_server (bool): Whether trainable parameters updated by parameter server are
                initialized on server. Default: False.
        """
        params = self.trainable_params(recurse)
        for param in params:
            param.set_param_ps(init_in_server)

    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
        """
        Set the way of parameter and server interaction.

        Args:
            push_to_server (bool): Whether the parameter should be pushed to the server. Default: False.
            pull_from_server (bool): Whether the parameter should be pulled from the server. Default: False.
            requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
        """
        params = self.parameters_and_names()
        for param in params:
            param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)

    def set_comm_fusion(self, fusion_type, recurse=True):
        """
        Set `comm_fusion` for all the parameters in the Net. Please refer to the description of
        `mindspore.common.parameter.comm_fusion`.

        Note:
            The value of the attribute will be overwritten when the function is called multiple times.

        Args:
            fusion_type (int): The value of `comm_fusion`.
            recurse (bool): Whether to set the trainable parameters of subcells. Default: True.
        """
        Validator.check_non_negative_int(fusion_type)
        for param in self.trainable_params(recurse):
            param.comm_fusion = fusion_type
        return self

    def _set_recompute_scope(self, mode):
        prefix = 'recompute_'
        if mode is True:
            if self._scope is None:
                self._scope = prefix
            elif not self._scope.startswith(prefix):
                self._scope = prefix + self._scope
        elif self._scope is not None and self._scope.startswith(prefix):
            self._scope = self._scope[len(prefix):]

    def _mp_comm_recompute(self, mp_comm_recompute=True):
        """
        Set whether the model parallel communication operators in the cell are recomputed.
        """
        for _, value in self._primitives.items():
            if value:
                value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
        for cell in self.cells():
            cell._mp_comm_recompute(mp_comm_recompute)

    def _parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute=False):
        """
        Set whether the parallel optimizer communication in the cell is recomputed.
        """
        for param in self.trainable_params():
            param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute

    def _recompute(self, mode=True, output_recompute=False):
        """
        Set the cell to be recomputed.
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in pynative mode currently.")
        Validator.check_bool(mode)
        Validator.check_bool(output_recompute)
        if not self._has_config_recompute:
            self._has_config_recompute = True
        else:
            raise RuntimeError("The recompute interface can be configured only once."
                               " When the parent cell is configured, the child cell should not be configured.")
        self._set_recompute_scope(mode)
        if mode and not output_recompute:
            self.add_flags(output_no_recompute=True)
        for cell in self.cells():
            cell._recompute(mode, True)

    @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
    def recompute(self, **kwargs):
        """
        Set the cell to be recomputed. All the primitives in the cell will be set to be recomputed. If a primitive
        set to be recomputed feeds into some backward nodes for computing gradients, rather than storing the
        intermediate activation computed in the forward pass, we will recompute it in the backward pass.

        Note:

            - If the computation involves something like randomization or global variables, the equivalence
              is not guaranteed currently.
            - If the recompute api of a primitive in this cell is also called, the recompute mode of this
              primitive is subject to the recompute api of the primitive.
            - The interface can be configured only once.
              Therefore, when the parent cell is configured, the child cell should not be configured.
            - When memory remains after applying the recompute, configure 'mp_comm_recompute=False'
              to improve performance if necessary.
            - When memory is still not enough after applying the recompute, configure
              'parallel_optimizer_comm_recompute=True' to save more memory if necessary.
              Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configuration.

        Args:
            mp_comm_recompute (bool): Specifies whether the model parallel communication operators
                in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
            parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
                introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
                Default: False.
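
        Examples:
            A minimal sketch (graph mode is assumed; `block` stands for a sub-cell of a larger network):

            >>> import mindspore.nn as nn
            >>> block = nn.Dense(3, 4)
            >>> block.recompute(mp_comm_recompute=True)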
        """
        self._recompute()
        if 'mp_comm_recompute' in kwargs.keys():
            self._mp_comm_recompute(kwargs['mp_comm_recompute'])
        if 'parallel_optimizer_comm_recompute' in kwargs.keys():
            if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
                raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
                                 "do not support recomputation in pipeline parallel.")
            self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])

        for key, _ in kwargs.items():
            if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
                raise ValueError("Recompute keyword %s is not recognized!" % key)

    def infer_param_pipeline_stage(self):
        """
        Infer pipeline stages of all parameters in the cell.

        Note:
            - If a parameter does not belong to any cell which has been set pipeline_stage,
              the parameter should use add_pipeline_stage to add its pipeline_stage information.
            - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
              to add its stage information before using infer_param_pipeline_stage.

        Returns:
            The params belonging to the current stage in pipeline parallel.

        Raises:
            RuntimeError: If there is a parameter that does not belong to any stage.
        """
        from mindspore.parallel._utils import _get_global_rank, _get_device_num
        stage_num = context.get_auto_parallel_context("pipeline_stages")
        device_num = _get_device_num()
        rank_id = _get_global_rank()
        per_stage_devices = device_num // stage_num
        current_stage = rank_id // per_stage_devices
        params = []
        for param in self.trainable_params():
            if not param._pipeline_stage_list:
                raise RuntimeError("The parameter {} does not belong to any stage, "
                                   "please check whether the cell where the param locates"
                                   " has been set pipeline_stage. "
                                   "Otherwise, the parameter should use add_pipeline_stage "
                                   "to add its stage information".format(param.name))
            if current_stage in param._pipeline_stage_list:
                params.append(param)
        return params


class GraphKernel(Cell):
    """
    Base class for GraphKernel.

    A `GraphKernel` is a composite of basic primitives and can be compiled into a fused kernel automatically when
    enable_graph_kernel in context is set to True.

    This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.

    GraphKernel no longer supports user-defined cells; `GraphKernel` objects will be treated as
    normal `Cell` objects.

    Args:
        auto_prefix (bool): Recursively generate namespaces. Default: True.
        flags (dict): Set graph flags. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Relu(nn.GraphKernel):
        ...     def __init__(self):
        ...         super(Relu, self).__init__()
        ...         self.max = P.Maximum()
        ...
        ...     def construct(self, x):
        ...         return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
    """

    @deprecated("1.3", "Cell", True)
    def __init__(self, auto_prefix=True, flags=None):
        super(GraphKernel, self).__init__(auto_prefix, flags)

    def construct(self):
        raise NotImplementedError


class GraphCell(Cell):
    """
    Base class for running the graph loaded from a MindIR.

    This feature is still under development. Currently `GraphCell` does not support modifying the structure of the
    graph, and can only use data whose shape and type are the same as the inputs used when exporting the MindIR.

    Args:
        graph (object): A compiled graph loaded from MindIR.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]
    """
    def __init__(self, graph):
        super(GraphCell, self).__init__(auto_prefix=True)
        if not isinstance(graph, FuncGraph):
            raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
        self.graph = graph

    def construct(self, *inputs):
        return self.graph(*inputs)

    def __call__(self, *inputs):
        self.phase = "graph_load_from_mindir"
        self._add_attr("graph_load_from_mindir", self.graph)
        return self.compile_and_run(*inputs)