# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
from __future__ import absolute_import

import numpy as np

from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore import _checkparam as Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Enumerates types of Layer
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5

op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


IS_ENABLE_GLOBAL_NORM = False
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
hyper_map_op = C.HyperMap()


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                     F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


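# Note on the defaults above (illustrative, per-tensor norm clipping): with
# GRADIENT_CLIP_TYPE = 1 and GRADIENT_CLIP_VALUE = 1.0, clip_gradient() below clips each
# gradient tensor by its own L2 norm, so a gradient whose norm is 2.0 is scaled by
# 1.0 / 2.0 = 0.5 while a gradient whose norm is already below 1.0 is left unchanged.
# Setting IS_ENABLE_GLOBAL_NORM = True switches to clipping by the global norm of all
# gradients instead.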
def clip_gradient(enable_clip_grad, gradients):
    """clip gradients"""
    if enable_clip_grad:
        if IS_ENABLE_GLOBAL_NORM:
            gradients = C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
        else:
            gradients = hyper_map_op(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), gradients)
    return gradients


C0 = 16


def _check_param(momentum, frequency, lr, cls_name):
    """Check param."""
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    if isinstance(momentum, float) and momentum < 0.0:
        raise ValueError("For 'thor', the argument 'momentum' must be at least 0.0, "
                         "but got 'momentum' {}.".format(momentum))
    Validator.check_value_type("frequency", frequency, [int], cls_name)
    if isinstance(frequency, int) and frequency < 2:
        raise ValueError("For 'thor', the argument 'frequency' must be at least 2, "
                         "but got 'frequency' {}.".format(frequency))
    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)


def caculate_device_shape(matrix_dim, channel, is_a):
    """calculate device shape and device dim of a matrix"""
    if is_a:
        if channel // C0 == 0:
            matrix_dim = (matrix_dim / channel) * C0
    ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
    return ll


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """is conv layer matmul support shape"""
    temp = (matrix_g_shape, matrix_a_shape)
    support_shape = [((4, 4, 16, 16), (49, 49, 16, 16)),
                     ((4, 4, 16, 16), (4, 4, 16, 16)),
                     ((4, 4, 16, 16), (36, 36, 16, 16)),
                     ((16, 16, 16, 16), (4, 4, 16, 16)),
                     ((4, 4, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (72, 72, 16, 16)),
                     ((32, 32, 16, 16), (8, 8, 16, 16)),
                     ((32, 32, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (144, 144, 16, 16)),
                     ((64, 64, 16, 16), (16, 16, 16, 16)),
                     ((64, 64, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (288, 288, 16, 16)),
                     ((128, 128, 16, 16), (32, 32, 16, 16)),
                     ((128, 128, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (128, 128, 16, 16))]
    if temp in support_shape:
        return True
    return False


def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """get matmul shape"""
    split_dima = split_dim
    split_dimg = split_dim
    if matrix_a_dim % split_dim == 0:
        batch_w = matrix_a_dim // split_dim
    else:
        if matrix_a_dim < split_dim:
            batch_w = 1
            split_dima = matrix_a_dim
        else:
            batch_w = matrix_a_dim // split_dim + 1

    if matrix_g_dim % split_dim == 0:
        batch_h = matrix_g_dim // split_dim
    else:
        if matrix_g_dim < split_dim:
            batch_h = 1
            split_dimg = matrix_g_dim
        else:
            batch_h = matrix_g_dim // split_dim + 1
    matrix_a_shape = (batch_h, batch_w, split_dima, split_dima)
    matrix_g_shape = (batch_h, split_dimg, split_dimg)
    return matrix_a_shape, matrix_g_shape


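# Illustrative example (hypothetical layer shape, chosen only for explanation): with
# split_dim = 128, a fully-connected weight of shape (256, 2304) gives matrix_a_dim = 2304
# and matrix_g_dim = 256, so caculate_matmul_shape(2304, 256, 128) returns
# matrix_a_shape = (2, 18, 128, 128) and matrix_g_shape = (2, 128, 128): both dimensions
# divide evenly into 128 x 128 diagonal blocks. A dimension smaller than split_dim is kept
# whole instead (batch of 1 with split_dima/split_dimg equal to the dimension itself).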
def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """get layer type for dense layer and conv layer"""
    if subcell.weight.requires_grad:
        if "rpn_with_loss.rpn_convs_list." not in prefix.lower() \
                or "rpn_with_loss.rpn_convs_list.0." in prefix.lower():
            layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """get net layer type recursively."""
    cells = net.name_cells()
    for name in cells:
        subcell = cells[name]
        prefix = subcell.param_prefix
        if subcell == net:
            continue
        elif isinstance(subcell, Conv2dThor):
            layertype_map.append(Conv)
        elif isinstance(subcell, DenseThor):
            layertype_map.append(FC)
        elif isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
        elif isinstance(subcell, nn.LayerNorm):
            layertype_map.append(LayerNorm)
        elif isinstance(subcell, nn.BatchNorm2d):
            if subcell.gamma.requires_grad:
                layertype_map.append(BatchNorm)
        elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                                  nn.BatchNorm1d, nn.GroupNorm)):
            if isinstance(subcell, (nn.Dense, nn.Conv2d)):
                get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
            else:
                layertype_map.append(Other)
        else:
            find_net_layertype_recur(subcell, layertype_map)


def get_net_layertype_mask(net):
    """get the layer type mask of the network"""
    layertype_map = []
    find_net_layertype_recur(net, layertype_map)
    return layertype_map


def get_layer_counter(layer_type, layer_counter, params, idx):
    """get layer counter"""
    if layer_type in [Conv, FC]:
        if "bias" in params[idx].name.lower():
            layer_counter = layer_counter + 1
        else:
            if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
                layer_counter = layer_counter + 1
    elif layer_type in [LayerNorm, BatchNorm]:
        if "beta" in params[idx].name.lower():
            layer_counter = layer_counter + 1
    else:
        if "bias" in params[idx].name.lower():
            layer_counter = layer_counter + 1
        elif "weight" in params[idx].name.lower():
            if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
                layer_counter = layer_counter + 1
        else:
            layer_counter = layer_counter + 1
    return layer_counter


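# Illustrative trace (parameter names here are hypothetical): for a Conv/FC layer whose
# trainable parameters appear as [xxx.weight, xxx.bias], get_layer_counter() advances the
# counter on the bias parameter; if the layer has no bias, it advances on the weight because
# the next parameter is not a bias. For LayerNorm/BatchNorm layers ([gamma, beta]) the
# counter advances on beta. This keeps every parameter of one layer mapped to the same entry
# of the mask returned by get_net_layertype_mask().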
def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by the THOR second-order optimization algorithm.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll}
          & \textbf{Parameter:} \: \text{the learning rate } \gamma\text{, the damping parameter }\lambda \\
          & \textbf{Init:} \: \lambda \leftarrow 0 \\
          & A_{i-1}=\mathbb{E}\left[a_{i-1} a_{i-1}^{T}\right] \\
          & G_{i}=\mathbb{E}\left[D_{s_i} D_{s_i}^{T}\right] \\
          & w_{i}^{(k+1)} \leftarrow w_{i}^{(k)}-\gamma\left(\left(A_{i-1}^{(k)}+\lambda I\right)^{-1}
            \otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
        \end{array}

    :math:`a_{i-1}` represents the input of the :math:`i`-th layer, which is the activation of the previous layer.
    :math:`D_{s_i}` represents the derivative of the loss function with respect to the output of the :math:`i`-th
    layer. :math:`I` represents the identity matrix. :math:`\lambda` represents the `damping` value, and
    :math:`g_i` represents the gradients of the :math:`i`-th layer.
    :math:`\otimes` represents the Kronecker product, and :math:`\gamma` represents the learning rate.

    Note:
        When parameter groups are separated, the 'weight_decay' of each group is applied to the corresponding
        parameters. When parameter groups are not separated, the 'weight_decay' in the optimizer is applied to the
        parameters whose names do not contain 'beta' or 'gamma'.
        When separating parameter groups, set grad_centralization to True if you want to centralize gradients.
        Gradient centralization can only be applied to the parameters of convolution layers; if it is set to True
        for parameters of other layers, an error will be reported.
        To improve the performance of parameter groups, you can customize the order of the parameters.

    Args:
        net (Cell): The training network.

        learning_rate (Tensor): A value for the learning rate.

        damping (Tensor): A value for the damping.

        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at
            least 0.0.

        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
            Default: ``0.0`` .

        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: ``1.0`` .

        batch_size (int): The size of a batch. Default: ``32`` .

        use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .

        decay_filter (function): A function to determine to which layers the weight decay is applied. It
            only works when weight_decay > 0. Default: lambda x: x.name not in []

        split_indices (list): Set the allreduce fusion strategy by A/G layer indices. Only works in distributed
            computing. Taking ResNet50 as an example, there are 54 layers of A/G respectively; when split_indices
            is set to [26, 53], A/G is divided into two groups for allreduce, one is layers 0~26, and the other
            is layers 27~53. Default: ``None`` .

        enable_clip_grad (bool): Whether to clip the gradients. Default: ``False`` .

        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
            and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: ``100`` .

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not a Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        TypeError: If `frequency` is not an int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], ms.float32)
        >>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...                  amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ms.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                    loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                    amp_level="O2", keep_batchnorm_fp32=False)

    """
    context.set_context(max_call_depth=10000)
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)


class ThorGpu(Optimizer):
    """
    ThorGpu
    """

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        self.thor = True
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """define gpu operator"""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.Gather()
        self.one = Tensor(1, mstype.int32)
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """define gpu reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map"""
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

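    # Illustrative example (the layer shape below is hypothetical, chosen only to show the
    # mapping): for a conv weight of shape (64, 3, 7, 7), matrix_a_dim = 3 * 7 * 7 = 147 and
    # matrix_g_dim = 64, so with split_dim = 128 the inverse buffers are initialized with
    # matrix_a_shape = (1, 2, 128, 128) and matrix_g_shape = (1, 64, 64). The corresponding
    # entries of weight_fim_idx_map / weight_conv_idx_map let construct() locate these
    # buffers from the flat parameter index of the weight.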
    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """get matrixA inverse list and matrixG inverse list"""
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm"""
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """reshape gradient"""
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success


class ThorAscend(Optimizer):
    """ThorAscend"""

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.net = net
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        self._define_ascend_operator()
        self.c0 = 16
        self.device_shape_pad_flag = ()
        self.diag_block_dim = 128
        self.matrix_a = ()
        self.matrix_g = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_conv_idx_map = ()
        self.weight_fim_idx_map = ()
        self.weight_layertype_idx_map = ()
        self.a_split_pad_dim_map = ()
        self.g_split_pad_dim_map = ()
        self.conv_matmul_support_map = ()
        self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
        self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.matrix_max_inv = ()
        for i in range(len(self.matrix_a)):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name='%s%s' % ("matrix_max", str(i)),
                          requires_grad=False),)
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.thor = True
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.damping = damping
        self.batch_size = Tensor(batch_size, mstype.float32)
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_ascend_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _get_pad_dim(self, matrix_dim):
        """get diag split pad dim"""
        split_pad_dim = 0
        if matrix_dim == 64:
            return split_pad_dim
        res = matrix_dim % self.diag_block_dim
        if res != 0:
            split_pad_dim = self.diag_block_dim - res
        return split_pad_dim

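    # Illustrative example (the dimension below is hypothetical): with diag_block_dim = 128,
    # _get_pad_dim(1001) returns 128 - (1001 % 128) = 23, so a 1001 x 1001 covariance matrix
    # is padded to 1024 x 1024 before the blocked Cholesky and the padding is sliced away
    # again after the inverse is combined; matrix_dim == 64 is special-cased so that it is
    # not padded.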
    def _define_ascend_operator(self):
        """define ascend operator"""
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.gather = P.Gather()
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.eye = P.Eye()
        self.concat = P.Concat(0)
        self.cholesky = P.CusCholeskyTrsm()
        self.vector_matmul = P.CusBatchMatMul()
        self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
        self.fused_abs_max2 = P.CusFusedAbsMax1()
        self.matrix_combine = P.CusMatrixCombine()
        self.slice = P.Slice()
        self.expand = P.ExpandDims()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.square = P.Square()
        self.inv = P.Inv()
        self.matmul = P.MatMul()
        self.axis = 0
        self.one = Tensor(1, mstype.int32)
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)

    def _define_ascend_reducer(self, split_indices):
        """define ascend reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            if self.conv_layer_count > 0:
                auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
                auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
                self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
                self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)

            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _get_weight_idx_map(self, layer_type, idx, weight_shape):
        """for Ascend, get weight idx map"""
        if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
            self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
            self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
            if layer_type == Embedding:
                a_pad_dim = 0
                g_pad_dim = 0
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
            else:
                out_channels = weight_shape[0]
                g_pad_dim = self._get_pad_dim(out_channels)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
                matrix_a_dim = weight_shape[1]
                if layer_type == Conv:
                    matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                a_pad_dim = self._get_pad_dim(matrix_a_dim)
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)

            self.thor_layer_count = self.thor_layer_count + 1
            if layer_type == Conv:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                self.conv_layer_count = self.conv_layer_count + 1
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
        else:
            self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
            self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            if layer_type == LayerNorm:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
            else:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)

    def _get_fc_matrix(self, weight_shape):
        """for Ascend, get fc matrix_a and matrix_g"""
        out_channels = weight_shape[0]
        in_channels = weight_shape[1]
        if self.conv_layer_count > 0:
            if out_channels == 1001:
                fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
                                        name='matrix_a_inv_' + str(self.thor_layer_count),
                                        requires_grad=False)
                fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
                                        name="matrix_g_inv_" + str(self.thor_layer_count),
                                        requires_grad=False)
            else:
                fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)),
                                        name='matrix_a_inv_' + str(self.thor_layer_count),
                                        requires_grad=False)
                fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)),
                                        name="matrix_g_inv_" + str(self.thor_layer_count),
                                        requires_grad=False)
            self.matrix_a = self.matrix_a + (fc_matrix_a,)
            self.matrix_g = self.matrix_g + (fc_matrix_g,)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for Ascend, process matrix init shape, and get weight idx map"""
        layer_counter = 0
        layer_type_map = get_net_layertype_mask(net)
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type == Conv and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
                matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
                ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape)
                if ret:
                    matrix_a_inv = Parameter(
                        Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)),
                        name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                    matrix_g_inv = Parameter(
                        Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)),
                        name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                    self.conv_matmul_support_map = self.conv_matmul_support_map + (1,)
                else:
                    matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)),
                                             name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                    matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)),
                                             name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                    self.conv_matmul_support_map = self.conv_matmul_support_map + (0,)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                device_shape_pad_flag = False
                if matrix_a_dim != matrix_a_device_dim:
                    device_shape_pad_flag = True
                self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
            elif layer_type == FC and "bias" not in self.params[idx].name.lower():
                self._get_fc_matrix(weight_shape)
            self._get_weight_idx_map(layer_type, idx, weight_shape)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _process_batch_matmul(self, input_matrix):
        """process batch matmul"""
        input_matrix_shape = self.shape(input_matrix)
        if input_matrix_shape[0] in self.batch_matmul_support_list:
            input_matrix = self.vector_matmul(input_matrix, input_matrix)
        else:
            input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix)
        return input_matrix

    def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
        """process cholesky pad"""
        if pad_dim > 0:
            matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32)
            matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup)
            input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
            input_matrix = self.concat((input_matrix, matrix_sup))
        return input_matrix

    def _get_abs_max(self, matrix_inv, origin_dim):
        """get matrix abs max"""
        cholesky_shape = self.shape(matrix_inv)
        if cholesky_shape[0] in self.abs_max_support_list:
            matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
            matrix_max = self.fused_abs_max2(matrix_inv_max)
            matrix_inv = self.matrix_combine(matrix_inv)
        else:
            matrix_inv = self.matrix_combine(matrix_inv)
            matrix_abs = P.Abs()(matrix_inv)
            matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs)
        return matrix_max, matrix_inv

    def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                          matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get fc layer ainv and ginv"""
        thor_layer_count = self.weight_fim_idx_map[index]
        g = gradients[index]
        matrix_a = self.matrix_a_cov[thor_layer_count]
        matrix_g = self.matrix_g_cov[thor_layer_count]
        matrix_a = F.depend(matrix_a, g)
        matrix_g = F.depend(matrix_g, g)
        a_shape = self.shape(matrix_a)
        a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
        g_shape = self.shape(matrix_g)
        g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
        damping = self.sqrt(damping_step)
        matrix_a = matrix_a + damping * a_eye
        a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
        matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
        matrix_a_inv = self.cholesky(matrix_a)
        matrix_a_inv = self._process_batch_matmul(matrix_a_inv)

        weight_shape = self.shape(self.params[index])
        out_channels = weight_shape[0]
        in_channels = weight_shape[1]
        if out_channels == 2:
            matrix_a_inv = self.matrix_combine(matrix_a_inv)
            matrix_g_inv = g_eye
        else:
            matrix_g = self.mul(matrix_g, self.loss_scale)
            matrix_g = self.mul(matrix_g, self.batch_size_scale)
            matrix_g = matrix_g + damping * g_eye
            g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
            matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
            matrix_g_inv = self.cholesky(matrix_g)
            matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
            if self.conv_layer_count > 0:
                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
                a_max = F.depend(a_max, g)
                g_max = F.depend(g_max, g)
                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
            else:
                matrix_a_inv = self.matrix_combine(matrix_a_inv)
                matrix_g_inv = self.matrix_combine(matrix_g_inv)

            if a_pad_dim > 0:
                matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
            if g_pad_dim > 0:
                matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
            matrix_a_inv_shape = self.shape(matrix_a_inv)
            matrix_g_combine_shape = self.shape(matrix_g_inv)
            if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
                matrix_a_inv = self.reshape(matrix_a_inv,
                                            (matrix_a_inv_shape[0] // 16, 16,
                                             matrix_a_inv_shape[0] // 16, 16))
                matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)

                matrix_g_inv_shape = self.shape(matrix_g_inv)
                matrix_g_inv = self.reshape(matrix_g_inv,
                                            (matrix_g_inv_shape[0] // 16, 16,
                                             matrix_g_inv_shape[0] // 16, 16))
                matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

        matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
        matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
        """process conv matmul device pad"""
        if self.device_shape_pad_flag[conv_layer_count]:
            kernel_hw = weight_shape[2] * weight_shape[3]
            in_channels = weight_shape[1]
            matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
            matrix_a_inv = P.Pad(((0, 0), (0, self.c0 - in_channels), (0, 0),
                                  (0, self.c0 - in_channels)))(matrix_a_inv)
        return matrix_a_inv

    def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                      matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            weight_shape = self.shape(self.params[i])
            out_channels = weight_shape[0]
            if layer_type == Conv:
                g = gradients[i]
                matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
                a_normalizer = self.a_normalizer[conv_layer_count]
                g_normalizer = self.g_normalizer[conv_layer_count]
                a_normalizer = F.depend(a_normalizer, g)
                g_normalizer = F.depend(g_normalizer, g)
                damping_a = self.mul(damping_step, self.batch_size / a_normalizer)
                damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
                damping_a = self.sqrt(damping_a)
                matrix_a = matrix_a + damping_a * a_eye
                a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
                matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
                matrix_a_inv = self.cholesky(matrix_a)
                matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)

                damping_g = self.sqrt(damping_g)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
                matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
                matrix_g_inv = self.cholesky(matrix_g)
                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)

                if a_pad_dim > 0:
                    matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
                if g_pad_dim > 0:
                    matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))

                if matmul_support_flag == 1:
                    matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
                    matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
                    matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                                  matrix_a_inv_shape[1], matrix_a_inv_shape[3])
                    matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
                    matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                    matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
                    matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
                                                  matrix_g_inv_shape[1], matrix_g_inv_shape[3])
                    matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
                    matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

                a_max = F.depend(a_max, g)
                g_max = F.depend(g_max, g)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
            elif layer_type == FC:
                matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                    self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                           matrix_a_max_allreduce, matrix_g_max_allreduce)
            elif layer_type == Embedding:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
                damping = self.sqrt(damping_step)
                a_eye = P.OnesLike()(matrix_a)
                matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                matrix_a = matrix_a + damping * a_eye
                matrix_a_inv = self.inv(matrix_a)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping * g_eye
                matrix_g_inv = self.cholesky(matrix_g)
                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
                matrix_g_inv = self.matrix_combine(matrix_g_inv)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm layer for thor"""
        damping = self.sqrt(damping_step)
        normalizer = self.cast(self.batch_size, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g):
        """process thor graph fc layer"""
        temp_a = matrix_a_allreduce[thor_layer_count]
        temp_g = matrix_g_allreduce[thor_layer_count]
        self.assign(self.matrix_a_cov[thor_layer_count], temp_a)
        self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
        temp_a = self.cast(temp_a, mstype.float16)
        temp_g = self.cast(temp_g, mstype.float16)
        g = self.cast(g, mstype.float16)
        g = self.matmul(temp_g, g)
        g = self.matmul(g, temp_a)
        g = self.cast(g, mstype.float32)
        return g

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """get second gradients one"""
        for i in range(params_len):
            g = gradients[i]
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
            if layer_type == FC:
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            new_grads = new_grads + (g,)
        return new_grads

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """get second gradients for thor"""
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type == Embedding:
                    temp_a_ori = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.expand(temp_a_ori, 1)
                    g = self.mul(temp_a, g)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(g, temp_g)
                    g = self.cast(g, mstype.float32)
                elif layer_type == FC:
                    temp_a = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        return new_grads

    def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
        """get second gradient by matmul"""
        conv_layer_count = self.weight_conv_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        grad_shape = self.shape(g)
        if layer_type == FC:
            if grad_shape[0] == 1001:
                g = self.cube_matmul_left_fc(temp_g, g)
                g = self.cube_matmul_right_fc(g, temp_a, temp_max)
            else:
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
        elif layer_type == Conv:
            a_normalizer = self.a_normalizer[conv_layer_count]
            a_normalizer = F.depend(a_normalizer, g)
            temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
            matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
            if matmul_support_flag == 1:
                g = self.cube_matmul_left(temp_g, g)
                g = self.cube_matmul_right_mul(g, temp_a, temp_max)
            else:
                g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
                g = self.reshape(g, grad_shape)
        return g, temp_max

    def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
        """get second gradient by layertype"""
        thor_layer_count = self.weight_fim_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        if layer_type == Embedding:
            temp_a_ori = matrix_a_allreduce[thor_layer_count]
            temp_g = matrix_g_allreduce[thor_layer_count]
            self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
            self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
            temp_a = self.expand(temp_a_ori, 1)
            g = self.mul(temp_a, g)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(g, temp_g)
            g = self.cast(g, mstype.float32)
        elif layer_type == FC:
            g = self._process_thor_fc(thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g)
        elif layer_type == LayerNorm:
            g = self._process_layernorm(damping_step, g)
        return g

    def construct(self, gradients):
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            matrix_a_allreduce = ()
            matrix_g_allreduce = ()
            matrix_a_max_allreduce = ()
            matrix_g_max_allreduce = ()
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                                   matrix_a_max_allreduce, matrix_g_max_allreduce)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
                if self.conv_layer_count > 0:
                    matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                    matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

            new_grads = ()
            if self.conv_layer_count > 0:
                for i in range(len(self.params)):
                    g = gradients[i]
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
                    matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    temp_max = self.mul(matrix_a_max_allreduce[thor_layer_count],
                                        matrix_g_max_allreduce[thor_layer_count])
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                    new_grads = new_grads + (g,)
                gradients = new_grads
            else:
                for i in range(len(self.params)):
                    g = gradients[i]
                    g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                    new_grads = new_grads + (g,)
                gradients = new_grads
        else:
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success