# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer', 'opt_init_args_register']


def opt_init_args_register(fn):
    """Register optimizer init args."""
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        if 'params' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['params']}))
            arguments.pop('params')
        if 'optimizer' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
            arguments.pop('optimizer')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco


class Optimizer(Cell):
    """
    Base class for all optimizers.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.

        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight_decay is positive. For most optimizers, when parameters are not separated, the `weight_decay` in the
        API will be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
        but the gradient centralization can only be applied to the parameters of convolution layers.
        If grad_centralization is set to True for the parameters of non-convolution layers, an error will be reported.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used:
            the i-th step will take the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is also used: the i-th learning rate will be calculated
            during training according to the formula of the LearningRateSchedule. When the learning_rate is a float
            or a zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float
            learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
            updated, the element in `parameters` must be class `Parameter`. When the `parameters` is a list of `dict`,
            the keys "params", "lr", "weight_decay", "grad_centralization" and "order_params" can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
              This parameter only works on convolution layers.

        weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
            It must be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
            when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` is less than 0.
        ValueError: If `learning_rate` is a Tensor, but the dimension of the tensor is greater than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        parameters = self._parameters_base_check(parameters, "parameters")
        if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
            raise TypeError("All elements of the optimizer parameters must be of type `Parameter` or `dict`.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)
        self.grad_centralization = False

        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self.group_grad_centralization = []
            self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
                else ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
            self.grad_centralization_flags = tuple(self.group_grad_centralization)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0
        # When a parameter has already been made unique, there is no need to do another unique in the optimizer.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        self.map_reverse = C.Map(None, True)
        self.hyper_map = C.HyperMap()
        self.hyper_map_reverse = C.HyperMap(None, True)
        self._use_parallel_optimizer()

    def _use_parallel_optimizer(self):
        """Check whether the parallel optimizer can be used and initialize the related attributes."""
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
                    and context.get_context("device_target") != "Ascend":
                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False
        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay", "AdaFactor"]:
                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)
        else:
            self.optim_filter = (True,) * self.param_length

    @property
    def unique(self):
        """Whether to make the gradients' sparse indices unique in the optimizer. The value type is bool.
        The property is read-only."""
        return self._unique

    @unique.setter
    def unique(self, value):
        """Set whether the optimizer makes the gradients' sparse indices unique."""
        if not isinstance(value, bool):
            raise TypeError("The value type must be bool, but got {}".format(type(value)))
        self._unique = value

    @property
    def target(self):
        """
        The property is used to determine whether the parameter is updated on host or device. The input type is str
        and can only be 'CPU', 'Ascend' or 'GPU'.
        """
        return self._target

    @target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        raise NotImplementedError

    def _set_base_target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
267 """ 268 if not isinstance(value, str): 269 raise TypeError("The value must be str type, but got value type is {}".format(type(value))) 270 271 if value not in ('CPU', 'Ascend', 'GPU'): 272 raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) 273 274 if self._target == "CPU" and value in ('Ascend', 'GPU'): 275 raise ValueError("In the CPU environment, target cannot be set to 'GPU' or 'Ascend'.") 276 277 if self._target == "Ascend" and value == 'GPU': 278 raise ValueError("In the Ascend environment, target cannot be set to 'GPU'.") 279 280 if self._target == "GPU" and value == 'Ascend': 281 raise ValueError("In the GPU environment, target cannot be set to 'Ascend'.") 282 283 self._is_device = (value != 'CPU') 284 self._target = value 285 286 def decay_weight(self, gradients): 287 """ 288 Weight decay. 289 290 An approach to reduce the overfitting of a deep learning neural network model. 291 292 Args: 293 gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as 294 `self.parameters`. 295 296 Returns: 297 tuple[Tensor], The gradients after weight decay. 298 """ 299 if self.exec_weight_decay: 300 params = self.parameters 301 if self.is_group: 302 gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags, 303 params, gradients) 304 else: 305 gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags, 306 params, gradients) 307 308 return gradients 309 310 def gradients_centralization(self, gradients): 311 """ 312 Gradients centralization. 313 314 A method for optimizing convolutional layer parameters to impore the training speed of a deep learning neural 315 network model. 316 317 Args: 318 gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as 319 `self.parameters`. 320 321 Returns: 322 tuple[Tensor], The gradients after gradients centralization. 323 """ 324 if self.is_group: 325 gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients) 326 327 return gradients 328 329 def scale_grad(self, gradients): 330 """ 331 Loss scale for mixed precision. 332 333 An approach of mixed precision training to improve the speed and energy efficiency of training deep neural 334 network. 335 336 Args: 337 gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as 338 `self.parameters`. 339 340 Returns: 341 tuple[Tensor], The gradients after loss scale. 

        """
        if self.need_scale:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def _grad_sparse_indices_deduplicate(self, gradients):
        """In the case of using big operators, deduplicate the 'indices' in gradients."""
        if self._target != 'CPU' and self._unique:
            gradients = self.map_(F.partial(_indices_deduplicate), gradients)
        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
            return weight_decay
        raise TypeError("Weight decay should be int or float.")

    def _preprocess_grad_centralization(self, grad_centralization):
        """Check grad centralization, it must be bool."""
        if not isinstance(grad_centralization, bool):
            raise TypeError("The gradient centralization should be bool.")
        return grad_centralization

    def _preprocess_single_lr(self, learning_rate):
        """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim > 1:
                raise ValueError("The dim of `Tensor` type learning rate should be 0 or 1, "
                                 f"but got {learning_rate.ndim}.")
            if learning_rate.ndim == 1 and learning_rate.size < 2:
                logger.warning("If you use `Tensor` type dynamic learning rate, please make sure that the number "
                               "of elements in the tensor is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate
        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")

    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _parameters_base_check(self, parameters, param_info):
        """Check the type and value of the input parameters."""
        if parameters is None:
            raise ValueError(f"Optimizer {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"Optimizer {param_info} must be Iterable.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"Optimizer got an empty {param_info} list.")
        return parameters

    def _check_group_params(self, parameters):
        """Check group params."""
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            parameters = self._parameters_base_check(group_param['params'], "group `params`")
            if not all(isinstance(x, Parameter) for x in parameters):
                raise TypeError("The group `params` should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            tensor_lr_length = learning_rate.size
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
                    group_lr_length = group_lr.size
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
        """Initialize learning rate, weight decay or grad centralization in group params."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = []
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']

            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            if 'grad_centralization' in group_param.keys():
                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                for param in group_param['params']:
                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                grad_centralization_ = self.grad_centralization
            else:
                grad_centralization_ = grad_centralization

            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.") 505 506 for param in group_param['params']: 507 validator.check_value_type("parameter", param, [Parameter], self.cls_name) 508 if param.name in params_store: 509 raise RuntimeError(f"The {param.name} parameter already exists in parameter groups, " 510 f"duplicate parameters are not supported.") 511 512 params_store.append(param.name) 513 self.group_lr.append(lr) 514 self.group_weight_decay.append(weight_decay_) 515 self.group_grad_centralization.append(grad_centralization_) 516 517 if self.is_group_params_ordered: 518 self._order_and_adjust_group_params(ordered_parameters) 519 520 def _order_and_adjust_group_params(self, ordered_parameters): 521 """ 522 Order group parameter, learning rate, weight decay and grad centralization in group params. 523 """ 524 params_length = len(self.group_params) 525 if len(ordered_parameters) != len(self.group_params): 526 raise ValueError(f"The value of 'order_params' should be same with all group parameters.") 527 528 ordered_params = [None] * params_length 529 ordered_learning_rate = [None] * params_length 530 ordered_weight_decay = [None] * params_length 531 ordered_grad_centralization = [None] * params_length 532 params_name = [param.name for param in ordered_parameters] 533 534 for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay, 535 self.group_grad_centralization): 536 index = params_name.index(param.name) 537 ordered_params[index] = param 538 ordered_learning_rate[index] = lr 539 ordered_weight_decay[index] = wd 540 ordered_grad_centralization[index] = gc 541 542 self.group_params = ordered_params 543 self.group_lr = ordered_learning_rate 544 self.group_weight_decay = ordered_weight_decay 545 self.group_grad_centralization = ordered_grad_centralization 546 547 def get_lr(self): 548 """ 549 Get the learning rate of current step. 550 551 Returns: 552 float, the learning rate of current step. 553 """ 554 lr = self.learning_rate 555 if self.dynamic_lr: 556 if self.is_group_lr: 557 lr = () 558 for learning_rate in self.learning_rate: 559 current_dynamic_lr = learning_rate(self.global_step) 560 lr += (current_dynamic_lr,) 561 else: 562 lr = self.learning_rate(self.global_step) 563 564 self.assignadd(self.global_step, self.global_step_increase_tensor) 565 return lr 566 567 def get_lr_parameter(self, param): 568 """ 569 Get the learning rate of parameter. 570 571 Args: 572 param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`. 573 574 Returns: 575 Parameter, single `Parameter` or `list[Parameter]` according to the input type. 
576 """ 577 def get_lr_value(learning_rate): 578 if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)): 579 return learning_rate.learning_rate 580 581 return learning_rate 582 583 if isinstance(param, Parameter): 584 param_list = [param] 585 elif isinstance(param, list): 586 param_list = param 587 else: 588 raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") 589 590 lr = [] 591 ids = [id(p) for p in self.parameters] 592 for p in param_list: 593 validator.check_value_type("parameter", p, [Parameter], self.cls_name) 594 if id(p) not in ids: 595 raise ValueError(f"The parameter {p.name} is not in optimizer.") 596 if self.is_group_lr: 597 index = ids.index(id(p)) 598 lr.append(get_lr_value(self.learning_rate[index])) 599 else: 600 lr.append(get_lr_value(self.learning_rate)) 601 602 return lr if isinstance(param, list) else lr[0] 603 604 def _get_parameter_group_id(self): 605 """ 606 Get the parameter partition group id, which is less than the number of devices. 607 608 Returns: 609 tuple, the group id tuple of parameters. 610 """ 611 rank_list = () 612 count = 0 613 for _ in range(self.param_length): 614 rank_list = rank_list + (count,) 615 count = count + 1 616 if count == self.dev_num: 617 count = 0 618 return rank_list 619 620 def broadcast_params(self, optim_result): 621 """ 622 Apply Broadcast operations in the sequential order of parameter groups. 623 624 Returns: 625 bool, the status flag. 626 """ 627 param_group = [] 628 key_group = [] 629 for _ in range(self.dev_num): 630 param_group.append(F.make_tuple()) 631 key_group.append(F.make_tuple()) 632 for i in range(self.param_length): 633 param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],) 634 key = P.MakeRefKey(self.param_names[i])() 635 key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) 636 new_param_group = [] 637 for root in range(self.dev_num): 638 ops = P.Broadcast(root) 639 if root > 0: 640 param_group[root] = F.depend(param_group[root], new_param_group[root-1]) 641 else: 642 param_group[root] = F.depend(param_group[root], optim_result) 643 next_params = ops(param_group[root]) 644 new_param_group.append(next_params) 645 for i in range(F.tuple_len(next_params)): 646 F.assign(key_group[root][i], next_params[i]) 647 return new_param_group 648 649 def construct(self, *hyper_params): 650 raise NotImplementedError 651 652 653op_add = P.AddN() 654op_gather = P.Gather() 655op_mul = P.Mul() 656op_gc = inner.Centralization() 657 658_apply_decay = C.MultitypeFuncGraph("apply_decay") 659_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization") 660 661 662@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor") 663def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): 664 """Get grad with weight_decay.""" 665 if if_apply: 666 indices = gradient.indices 667 values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values)) 668 shape = gradient.dense_shape 669 return RowTensor(indices, values, shape) 670 return gradient 671 672 673@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor") 674def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): 675 """Get grad with weight_decay.""" 676 if if_apply: 677 return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient)) 678 return gradient 679 680 681@_apply_grad_centralization.register("Bool", "RowTensor") 682def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient): 683 
"""Get grad with grad_centralization.""" 684 if if_apply: 685 indices = gradient.indices 686 shape = gradient.dense_shape 687 grad_shape = F.shape(gradient) 688 axis = [] 689 for i in range(1, len(grad_shape)): 690 axis.append(i) 691 if len(axis) >= 1: 692 if grad_shape[1] % 16 != 0: 693 return gradient 694 values = op_gc(gradient.values, axis) 695 return RowTensor(indices, values, shape) 696 return gradient 697 698 699@_apply_grad_centralization.register("Bool", "Tensor") 700def _tensor_apply_grad_centralization(if_apply, gradient): 701 """Get grad with grad_centralization.""" 702 if if_apply: 703 axis = [] 704 grad_shape = F.shape(gradient) 705 for i in range(1, len(grad_shape)): 706 axis.append(i) 707 if len(axis) >= 1: 708 if grad_shape[1] % 16 != 0: 709 return gradient 710 return op_gc(gradient, axis) 711 return gradient 712 713 714_grad_scale = C.MultitypeFuncGraph("grad_scale") 715_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate") 716 717 718@_grad_scale.register("Number", "Tensor") 719def tensor_grad_scale(scale, grad): 720 """Get grad with scale.""" 721 if scale == 1.0: 722 return grad 723 return op_mul(grad, F.cast(scale, F.dtype(grad))) 724 725 726@_grad_scale.register("Tensor", "Tensor") 727def tensor_grad_scale_with_tensor(scale, grad): 728 """Get grad with scale.""" 729 return op_mul(grad, F.cast(scale, F.dtype(grad))) 730 731 732@_grad_scale.register("Tensor", "RowTensor") 733def tensor_grad_scale_with_sparse(scale, grad): 734 """Get grad with scale.""" 735 return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape) 736 737 738@_indices_deduplicate.register("RowTensor") 739def rowtensor_deduplicate_indices_slices(grad): 740 """Unique the indices and sums the 'values' corresponding to the duplicate indices.""" 741 indices = grad.indices 742 values = grad.values 743 744 unique_indices, index_position = P.Unique()(indices) 745 summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0]) 746 747 return RowTensor(unique_indices, summed_values, grad.dense_shape) 748 749 750@_indices_deduplicate.register("Tensor") 751def tensor_deduplicate_indice_slices(grad): 752 """Return the input gradient directly in the dense sences.""" 753 return grad 754 755 756class _ConvertToCell(LearningRateSchedule): 757 """Inner api, convert learning rate of scalar to LearningRateSchedule.""" 758 def __init__(self, learning_rate): 759 super(_ConvertToCell, self).__init__() 760 if not isinstance(learning_rate, Parameter): 761 raise TypeError('Learning rate must be Parameter.') 762 self.learning_rate = learning_rate 763 764 def construct(self, global_step): 765 return self.learning_rate + 1.0 - 1.0 766 767 768class _IteratorLearningRate(LearningRateSchedule): 769 """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule.""" 770 def __init__(self, learning_rate, name): 771 super(_IteratorLearningRate, self).__init__() 772 if isinstance(learning_rate, Tensor): 773 if learning_rate.ndim != 1: 774 raise ValueError("The dim of `Tensor` type dynamic learning rate should be 1, " 775 f"but got {learning_rate.ndim}.") 776 else: 777 raise TypeError("Learning rate should be Tensor.") 778 779 self.learning_rate = Parameter(learning_rate, name) 780 self.gather = P.Gather() 781 782 def construct(self, global_step): 783 return self.gather(self.learning_rate, global_step, 0) 784