# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""thor"""
16from __future__ import absolute_import
17
18import numpy as np
19
20from mindspore.ops import functional as F, composite as C, operations as P
21from mindspore.common.initializer import initializer
22from mindspore.common.parameter import Parameter, ParameterTuple
23from mindspore.common.tensor import Tensor
24import mindspore.ops as ops
25import mindspore.nn as nn
26import mindspore.common.dtype as mstype
27import mindspore.log as logger
28from mindspore import _checkparam as Validator
29from mindspore.nn.optim.optimizer import Optimizer
30from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
31from mindspore import context
32from mindspore.context import ParallelMode
33from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
34from mindspore.nn.wrap import DistributedGradReducer
35from mindspore.train.train_thor.convert_utils import ConvertNetUtils
36from mindspore.parallel._auto_parallel_context import auto_parallel_context
37
38# Enumerates types of Layer
39Other = -1
40Conv = 1
41FC = 2
42Embedding = 3
43LayerNorm = 4
44BatchNorm = 5
45
46op_add = P.AddN()
47apply_decay = C.MultitypeFuncGraph("apply_decay")
48_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
49
50
51@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
52def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
53    """Get grad with weight_decay."""
54    if if_apply:
55        return op_add((weight * weight_decay, gradient))
56    return gradient
57
58
59@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
60def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
61    """Apply momentum optimizer to the weight parameter using Tensor."""
62    success = True
63    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
64    return success
65
66
67IS_ENABLE_GLOBAL_NORM = False
68GRADIENT_CLIP_TYPE = 1
69GRADIENT_CLIP_VALUE = 1.0
70clip_grad = C.MultitypeFuncGraph("clip_grad")
71hyper_map_op = C.HyperMap()
72
73
74@clip_grad.register("Number", "Number", "Tensor")
75def _clip_grad(clip_type, clip_value, grad):
76    """
77    Clip gradients.
78
79    Inputs:
80        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
81        clip_value (float): Specifies how much to clip.
82        grad (tuple[Tensor]): Gradients.
83
84    Outputs:
85        tuple[Tensor], clipped gradients.
86    """
87    if clip_type not in [0, 1]:
88        return grad
89    dt = F.dtype(grad)
90    if clip_type == 0:
91        new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
92                                     F.cast(F.tuple_to_array((clip_value,)), dt))
93    else:
94        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
95    return new_grad
96
97
98def clip_gradient(enable_clip_grad, gradients):
99    """clip gradients"""
100    if enable_clip_grad:
101        if IS_ENABLE_GLOBAL_NORM:
102            gradients = C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
103        else:
104            gradients = hyper_map_op(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), gradients)
105    return gradients
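# Usage note (illustrative, not part of the original file): with the module defaults above,
# clip_gradient(True, grads) clips each gradient tensor independently by its own L2 norm to
# GRADIENT_CLIP_VALUE (clip type 1). Setting IS_ENABLE_GLOBAL_NORM = True would instead clip
# the whole gradient tuple by a single global norm via C.clip_by_global_norm.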


C0 = 16


def _check_param(momentum, frequency, lr, cls_name):
    """Check param."""
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    if isinstance(momentum, float) and momentum < 0.0:
        raise ValueError("For 'thor', the argument 'momentum' must be at least 0.0, "
                         "but got 'momentum' {}.".format(momentum))
    Validator.check_value_type("frequency", frequency, [int], cls_name)
    if isinstance(frequency, int) and frequency < 2:
        raise ValueError("For 'thor', the argument 'frequency' must be at least 2, "
                         "but got 'frequency' {}.".format(frequency))
    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)


def caculate_device_shape(matrix_dim, channel, is_a):
    if is_a:
        if channel // C0 == 0:
            matrix_dim = (matrix_dim / channel) * C0
    ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
    return ll
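# Worked example (illustrative, not part of the original file): for a 3x3 convolution with
# in_channels=64, matrix_a_dim = 64 * 3 * 3 = 576 and channel // C0 = 4, so no rescaling is
# applied and caculate_device_shape(576, 64, True) returns ((36, 36, 16, 16), 576).
# For a stem convolution with in_channels=3 and a 7x7 kernel, matrix_a_dim = 147 and
# channel // C0 = 0, so the dimension is rescaled to (147 / 3) * 16 = 784, giving
# ((49, 49, 16, 16), 784) -- one of the shapes listed in is_conv_matmul_support_shape below.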


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """Check whether the conv layer matmul shapes are supported."""
    temp = (matrix_g_shape, matrix_a_shape)
    support_shape = [((4, 4, 16, 16), (49, 49, 16, 16)),
                     ((4, 4, 16, 16), (4, 4, 16, 16)),
                     ((4, 4, 16, 16), (36, 36, 16, 16)),
                     ((16, 16, 16, 16), (4, 4, 16, 16)),
                     ((4, 4, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (72, 72, 16, 16)),
                     ((32, 32, 16, 16), (8, 8, 16, 16)),
                     ((32, 32, 16, 16), (16, 16, 16, 16)),
                     ((8, 8, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (144, 144, 16, 16)),
                     ((64, 64, 16, 16), (16, 16, 16, 16)),
                     ((64, 64, 16, 16), (32, 32, 16, 16)),
                     ((16, 16, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (288, 288, 16, 16)),
                     ((128, 128, 16, 16), (32, 32, 16, 16)),
                     ((128, 128, 16, 16), (64, 64, 16, 16)),
                     ((32, 32, 16, 16), (128, 128, 16, 16))]
    if temp in support_shape:
        return True
    return False


def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """get matmul shape"""
    split_dima = split_dim
    split_dimg = split_dim
    if matrix_a_dim % split_dim == 0:
        batch_w = matrix_a_dim // split_dim
    else:
        if matrix_a_dim < split_dim:
            batch_w = 1
            split_dima = matrix_a_dim
        else:
            batch_w = matrix_a_dim // split_dim + 1

    if matrix_g_dim % split_dim == 0:
        batch_h = matrix_g_dim // split_dim
    else:
        if matrix_g_dim < split_dim:
            batch_h = 1
            split_dimg = matrix_g_dim
        else:
            batch_h = matrix_g_dim // split_dim + 1
    matrix_a_shape = (batch_h, batch_w, split_dima, split_dima)
    matrix_g_shape = (batch_h, split_dimg, split_dimg)
    return matrix_a_shape, matrix_g_shape
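# Worked example (illustrative, not part of the original file): with matrix_a_dim=1000,
# matrix_g_dim=500 and split_dim=128, neither dimension divides evenly and both exceed
# split_dim, so batch_w = 1000 // 128 + 1 = 8 and batch_h = 500 // 128 + 1 = 4, giving
# matrix_a_shape = (4, 8, 128, 128) and matrix_g_shape = (4, 128, 128). For dimensions
# smaller than split_dim, e.g. (100, 10, 128), the split sizes shrink to the dimensions
# themselves: matrix_a_shape = (1, 1, 100, 100) and matrix_g_shape = (1, 10, 10).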


def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """get layer type for dense layer and conv layer"""
    if subcell.weight.requires_grad:
        if "rpn_with_loss.rpn_convs_list." not in prefix.lower() \
                or "rpn_with_loss.rpn_convs_list.0." in prefix.lower():
            layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """get net layer type recursively."""
    cells = net.name_cells()
    for name in cells:
        subcell = cells[name]
        prefix = subcell.param_prefix
        if subcell == net:
            continue
        elif isinstance(subcell, Conv2dThor):
            layertype_map.append(Conv)
        elif isinstance(subcell, DenseThor):
            layertype_map.append(FC)
        elif isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
        elif isinstance(subcell, nn.LayerNorm):
            layertype_map.append(LayerNorm)
        elif isinstance(subcell, nn.BatchNorm2d):
            if subcell.gamma.requires_grad:
                layertype_map.append(BatchNorm)
        elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                                  nn.BatchNorm1d, nn.GroupNorm)):
            if isinstance(subcell, (nn.Dense, nn.Conv2d)):
                get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
            else:
                layertype_map.append(Other)
        else:
            find_net_layertype_recur(subcell, layertype_map)


def get_net_layertype_mask(net):
    """Get the layer type list of the network."""
    layertype_map = []
    find_net_layertype_recur(net, layertype_map)
    return layertype_map


def get_layer_counter(layer_type, layer_counter, params, idx):
    """get layer counter"""
    if layer_type in [Conv, FC]:
        if "bias" in params[idx].name.lower():
            layer_counter = layer_counter + 1
        else:
            if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
                layer_counter = layer_counter + 1
    elif layer_type in [LayerNorm, BatchNorm]:
        if "beta" in params[idx].name.lower():
            layer_counter = layer_counter + 1
    else:
        if "bias" in params[idx].name.lower():
            layer_counter = layer_counter + 1
        elif "weight" in params[idx].name.lower():
            if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
                layer_counter = layer_counter + 1
        else:
            layer_counter = layer_counter + 1
    return layer_counter
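# Worked example (illustrative, not part of the original file): for a Conv/FC layer the
# counter only advances once the layer's trainable parameters are exhausted. With
# params = [conv.weight, fc.weight, fc.bias], processing conv.weight advances the counter
# (the next parameter is not a bias, so this conv has no bias of its own), processing
# fc.weight does not (its bias follows), and processing fc.bias does. For LayerNorm and
# BatchNorm layers the counter advances on the 'beta' parameter, which follows 'gamma'.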


def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by the second-order algorithm THOR.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll}
          & \textbf{Parameter:} \: \text{the learning rate } \gamma\text{, the damping parameter }\lambda \\
          & \textbf{Init:} \: \lambda \leftarrow 0 \\
          & A_{i-1}=\mathbb{E}\left[a_{i-1} a_{i-1}^{T}\right] \\
          & G_{i}=\mathbb{E}\left[D_{s_i} D_{s_i}^{T}\right] \\
          & w_{i}^{(k+1)} \leftarrow w_{i}^{(k)}-\gamma\left(\left(A_{i-1}^{(k)}+\lambda I\right)^{-1}
            \otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
        \end{array}

    :math:`a_{i-1}` represents the input of the :math:`i`-th layer, which is the activations of the previous layer.
    :math:`D_{s_i}` represents the derivative of the loss function of the output of the :math:`i`-th layer.
    :math:`I` represents the identity matrix.
    :math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the :math:`i`-th layer.
    :math:`\otimes` represents the Kronecker product, :math:`\gamma` represents the learning rate.

    Note:
        When parameter groups are separated, the 'weight_decay' of each group is applied to the corresponding
        parameters. When the parameters are not grouped, the 'weight_decay' in the optimizer is applied to
        parameters whose names do not contain 'beta' or 'gamma'.
        When separating parameter groups, set grad_centralization to True if you want to centralize the gradients,
        but gradient centralization can only be applied to the parameters of convolution layers.
        If the parameter of a non-convolution layer is set to True, an error will be reported.
        To improve the performance of parameter groups, you can customize the order of the parameters.

    Args:
        net (Cell): The training network.

        learning_rate (Tensor): A value for the learning rate.

        damping (Tensor): A value for the damping.

        momentum (float): Hyper-parameter of type float, meaning the momentum for the moving average. It must be
            at least 0.0.

        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
            Default: ``0.0`` .

        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: ``1.0`` .

        batch_size (int): The size of a batch. Default: ``32`` .

        use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .

        decay_filter (function): A function to determine which layers the weight decay is applied to. It
            only works when `weight_decay` > 0. Default: lambda x: x.name not in []

        split_indices (list): Set the allreduce fusion strategy by A/G layer indices. Only works in distributed
            training. Taking ResNet50 as an example, there are 54 A/G layers respectively; when split_indices is
            set to [26, 53], A/G is divided into two groups to allreduce, one covering layers 0~26 and the other
            layers 27~53. Default: ``None`` .

        enable_clip_grad (bool): Whether to clip the gradients. Default: ``False`` .

        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
            and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: ``100`` .

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not a Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        TypeError: If `frequency` is not an int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], ms.float32)
        >>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ms.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)

    """
    context.set_context(max_call_depth=10000)
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)


class ThorGpu(Optimizer):
    """
    THOR optimizer implementation for the GPU backend.
    """

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        self.thor = True
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """define gpu operator"""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.Gather()
        self.one = Tensor(1, mstype.int32)
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """define gpu reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map"""
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """get matrix A inverse list and matrix G inverse list"""
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm"""
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient
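
    # Note (illustrative, not part of the original file): _process_layernorm applies a
    # diagonal, element-wise natural-gradient step. Per element, the update computed above is
    # g_new = g / (g ** 2 / batch_size + sqrt(damping_step)), i.e. the squared gradient acts
    # as a one-sample Fisher estimate for LayerNorm parameters.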

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """reshape gradient"""
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success


class ThorAscend(Optimizer):
    """THOR optimizer implementation for the Ascend backend."""

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.net = net
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        self._define_ascend_operator()
        self.c0 = 16
        self.device_shape_pad_flag = ()
        self.diag_block_dim = 128
        self.matrix_a = ()
        self.matrix_g = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_conv_idx_map = ()
        self.weight_fim_idx_map = ()
        self.weight_layertype_idx_map = ()
        self.a_split_pad_dim_map = ()
        self.g_split_pad_dim_map = ()
        self.conv_matmul_support_map = ()
        self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
        self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.matrix_max_inv = ()
        for i in range(len(self.matrix_a)):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name='%s%s' % ("matrix_max", str(i)),
                          requires_grad=False),)
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.thor = True
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.damping = damping
        self.batch_size = Tensor(batch_size, mstype.float32)
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_ascend_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _get_pad_dim(self, matrix_dim):
        """get diag split pad dim"""
        split_pad_dim = 0
        if matrix_dim == 64:
            return split_pad_dim
        res = matrix_dim % self.diag_block_dim
        if res != 0:
            split_pad_dim = self.diag_block_dim - res
        return split_pad_dim
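
    # Worked example (illustrative, not part of the original file): with diag_block_dim=128,
    # _get_pad_dim(1001) returns 128 - (1001 % 128) = 23, so a 1001x1001 matrix is padded to
    # 1024x1024 before the block-diagonal Cholesky; _get_pad_dim(2048) and _get_pad_dim(64)
    # both return 0, so those matrices are used as-is.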

    def _define_ascend_operator(self):
        """define ascend operator"""
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.gather = P.Gather()
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.eye = P.Eye()
        self.concat = P.Concat(0)
        self.cholesky = P.CusCholeskyTrsm()
        self.vector_matmul = P.CusBatchMatMul()
        self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
        self.fused_abs_max2 = P.CusFusedAbsMax1()
        self.matrix_combine = P.CusMatrixCombine()
        self.slice = P.Slice()
        self.expand = P.ExpandDims()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.square = P.Square()
        self.inv = P.Inv()
        self.matmul = P.MatMul()
        self.axis = 0
        self.one = Tensor(1, mstype.int32)
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)

    def _define_ascend_reducer(self, split_indices):
        """define ascend reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            if self.conv_layer_count > 0:
                auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
                auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
                self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
                self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)

            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _get_weight_idx_map(self, layer_type, idx, weight_shape):
        """for Ascend, get weight idx map"""
        if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
            self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
            self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
            if layer_type == Embedding:
                a_pad_dim = 0
                g_pad_dim = 0
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
            else:
                out_channels = weight_shape[0]
                g_pad_dim = self._get_pad_dim(out_channels)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
                matrix_a_dim = weight_shape[1]
                if layer_type == Conv:
                    matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                a_pad_dim = self._get_pad_dim(matrix_a_dim)
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)

            self.thor_layer_count = self.thor_layer_count + 1
            if layer_type == Conv:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                self.conv_layer_count = self.conv_layer_count + 1
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
        else:
            self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
            self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            if layer_type == LayerNorm:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
            else:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)

    def _get_fc_matrix(self, weight_shape):
        """for Ascend, get fc matrix_a and matrix_g"""
        out_channels = weight_shape[0]
        in_channels = weight_shape[1]
        if self.conv_layer_count > 0:
            if out_channels == 1001:
                fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
                                        name='matrix_a_inv_' + str(self.thor_layer_count),
                                        requires_grad=False)
                fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
                                        name="matrix_g_inv_" + str(self.thor_layer_count),
                                        requires_grad=False)
            else:
                fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)),
                                        name='matrix_a_inv_' + str(self.thor_layer_count),
                                        requires_grad=False)
                fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)),
                                        name="matrix_g_inv_" + str(self.thor_layer_count),
                                        requires_grad=False)
            self.matrix_a = self.matrix_a + (fc_matrix_a,)
            self.matrix_g = self.matrix_g + (fc_matrix_g,)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for Ascend, process matrix init shape, and get weight idx map"""
        layer_counter = 0
        layer_type_map = get_net_layertype_mask(net)
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type == Conv and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
                matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
                ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape)
                if ret:
                    matrix_a_inv = Parameter(
                        Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)),
                        name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                    matrix_g_inv = Parameter(
                        Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)),
                        name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                    self.conv_matmul_support_map = self.conv_matmul_support_map + (1,)
                else:
                    matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)),
                                             name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                    matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)),
                                             name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                    self.conv_matmul_support_map = self.conv_matmul_support_map + (0,)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                device_shape_pad_flag = False
                if matrix_a_dim != matrix_a_device_dim:
                    device_shape_pad_flag = True
                self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
            elif layer_type == FC and "bias" not in self.params[idx].name.lower():
                self._get_fc_matrix(weight_shape)
            self._get_weight_idx_map(layer_type, idx, weight_shape)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _process_batch_matmul(self, input_matrix):
        """process batch matmul"""
        input_matrix_shape = self.shape(input_matrix)
        if input_matrix_shape[0] in self.batch_matmul_support_list:
            input_matrix = self.vector_matmul(input_matrix, input_matrix)
        else:
            input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix)
        return input_matrix
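
    # Note (an interpretation, not part of the original file): the input here is the block
    # triangular factor produced by the Cholesky op, and squaring it (the TBE fallback uses
    # BatchMatMul with transpose_a=True) is what turns that factor into the inverse of the
    # damped covariance used later. The custom CusBatchMatMul kernel appears to handle only
    # the batch sizes in batch_matmul_support_list, hence the fallback branch.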

    def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
        """process cholesky pad"""
        if pad_dim > 0:
            matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32)
            matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup)
            input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
            input_matrix = self.concat((input_matrix, matrix_sup))
        return input_matrix
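
    # Worked example (illustrative, not part of the original file): for a 1001x1001 matrix A
    # with pad_dim=23, the input is right-padded to 1001x1024, a 23x23 identity is left-padded
    # to 23x1024, and concatenating along axis 0 yields the 1024x1024 block matrix
    # [[A, 0], [0, I]], whose size is a multiple of diag_block_dim=128.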
901
902    def _get_abs_max(self, matrix_inv, origin_dim):
903        """get matrix abs max"""
904        cholesky_shape = self.shape(matrix_inv)
905        if cholesky_shape[0] in self.abs_max_support_list:
906            matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
907            matrix_max = self.fused_abs_max2(matrix_inv_max)
908            matrix_inv = self.matrix_combine(matrix_inv)
909        else:
910            matrix_inv = self.matrix_combine(matrix_inv)
911            matrix_abs = P.Abs()(matrix_inv)
912            matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs)
913        return matrix_max, matrix_inv
914
915    def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
916                          matrix_a_max_allreduce, matrix_g_max_allreduce):
917        """get fc layer ainv and ginv"""
918        thor_layer_count = self.weight_fim_idx_map[index]
919        g = gradients[index]
920        matrix_a = self.matrix_a_cov[thor_layer_count]
921        matrix_g = self.matrix_g_cov[thor_layer_count]
922        matrix_a = F.depend(matrix_a, g)
923        matrix_g = F.depend(matrix_g, g)
924        a_shape = self.shape(matrix_a)
925        a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
926        g_shape = self.shape(matrix_g)
927        g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
928        damping = self.sqrt(damping_step)
929        matrix_a = matrix_a + damping * a_eye
930        a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
931        matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
932        matrix_a_inv = self.cholesky(matrix_a)
933        matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
934
935        weight_shape = self.shape(self.params[index])
936        out_channels = weight_shape[0]
937        in_channels = weight_shape[1]
938        if out_channels == 2:
939            matrix_a_inv = self.matrix_combine(matrix_a_inv)
940            matrix_g_inv = g_eye
941        else:
942            matrix_g = self.mul(matrix_g, self.loss_scale)
943            matrix_g = self.mul(matrix_g, self.batch_size_scale)
944            matrix_g = matrix_g + damping * g_eye
945            g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
946            matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
947            matrix_g_inv = self.cholesky(matrix_g)
948            matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
949            if self.conv_layer_count > 0:
950                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
951                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
952                a_max = F.depend(a_max, g)
953                g_max = F.depend(g_max, g)
954                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
955                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
956            else:
957                matrix_a_inv = self.matrix_combine(matrix_a_inv)
958                matrix_g_inv = self.matrix_combine(matrix_g_inv)
959
960            if a_pad_dim > 0:
961                matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
962            if g_pad_dim > 0:
963                matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
964            matrix_a_inv_shape = self.shape(matrix_a_inv)
965            matrix_g_combine_shape = self.shape(matrix_g_inv)
966            if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
967                matrix_a_inv = self.reshape(matrix_a_inv,
968                                            (matrix_a_inv_shape[0] // 16, 16,
969                                             matrix_a_inv_shape[0] // 16, 16))
970                matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
971                matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
972
973                matrix_g_inv_shape = self.shape(matrix_g_inv)
974                matrix_g_inv = self.reshape(matrix_g_inv,
975                                            (matrix_g_inv_shape[0] // 16, 16,
976                                             matrix_g_inv_shape[0] // 16, 16))
977                matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
978
979        matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
980        matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
981        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
982
983    def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
984        """process conv matmul device pad"""
985        if self.device_shape_pad_flag[conv_layer_count]:
986            kernel_hw = weight_shape[2] * weight_shape[3]
987            in_channels = weight_shape[1]
988            matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
989            matrix_a_inv = P.Pad(((0, 0), (0, self.c0 - in_channels), (0, 0),
990                                  (0, self.c0 - in_channels)))(matrix_a_inv)
991        return matrix_a_inv
992
993    def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
994                                      matrix_a_max_allreduce, matrix_g_max_allreduce):
995        """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
996        for i in range(len(self.params)):
997            thor_layer_count = self.weight_fim_idx_map[i]
998            conv_layer_count = self.weight_conv_idx_map[i]
999            layer_type = self.weight_layertype_idx_map[i]
1000            weight_shape = self.shape(self.params[i])
1001            out_channels = weight_shape[0]
1002            if layer_type == Conv:
1003                g = gradients[i]
1004                matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
1005                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
1006                matrix_a = self.matrix_a_cov[thor_layer_count]
1007                matrix_g = self.matrix_g_cov[thor_layer_count]
1008                matrix_a = F.depend(matrix_a, g)
1009                matrix_g = F.depend(matrix_g, g)
1010                a_shape = self.shape(matrix_a)
1011                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
1012                g_shape = self.shape(matrix_g)
1013                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
1014                a_normalizer = self.a_normalizer[conv_layer_count]
1015                g_normalizer = self.g_normalizer[conv_layer_count]
1016                a_normalizer = F.depend(a_normalizer, g)
1017                g_normalizer = F.depend(g_normalizer, g)
1018                damping_a = self.mul(damping_step, self.batch_size / a_normalizer)
1019                damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
1020                damping_a = self.sqrt(damping_a)
1021                matrix_a = matrix_a + damping_a * a_eye
1022                a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
1023                matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
1024                matrix_a_inv = self.cholesky(matrix_a)
1025                matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
1026                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)
1027
1028                damping_g = self.sqrt(damping_g)
1029                matrix_g = self.mul(matrix_g, self.loss_scale)
1030                matrix_g = self.mul(matrix_g, self.batch_size_scale)
1031                matrix_g = matrix_g + damping_g * g_eye
1032                g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
1033                matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
1034                matrix_g_inv = self.cholesky(matrix_g)
1035                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
1036                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
1037
1038                if a_pad_dim > 0:
1039                    matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
1040                if g_pad_dim > 0:
1041                    matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
1042
                if matmul_support_flag == 1:
                    matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
                    matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
                    matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                                  matrix_a_inv_shape[1], matrix_a_inv_shape[3])
                    matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
                    matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                    matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
                    matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
                                                  matrix_g_inv_shape[1], matrix_g_inv_shape[3])
                    matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
                    matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

                a_max = F.depend(a_max, g)
                g_max = F.depend(g_max, g)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
            elif layer_type == FC:
                matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                    self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                           matrix_a_max_allreduce, matrix_g_max_allreduce)
            elif layer_type == Embedding:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
                damping = self.sqrt(damping_step)
                a_eye = P.OnesLike()(matrix_a)
                matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                matrix_a = matrix_a + damping * a_eye
                matrix_a_inv = self.inv(matrix_a)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping * g_eye
                matrix_g_inv = self.cholesky(matrix_g)
                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
                matrix_g_inv = self.matrix_combine(matrix_g_inv)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

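    # LayerNorm parameters use a diagonal Fisher approximation instead of Kronecker factors:
    # fim = g ** 2 / batch_size + sqrt(damping), and the gradient is divided element-wise by it.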
    def _process_layernorm(self, damping_step, gradient):
        """Precondition a LayerNorm layer's gradient with a damped diagonal Fisher estimate."""
        damping = self.sqrt(damping_step)
        normalizer = self.cast(self.batch_size, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

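    # Fully connected layers use the full Kronecker-factored preconditioner: the gradient matrix
    # is multiplied by G^-1 on the left and A^-1 on the right, in float16 to keep the matmuls cheap.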
    def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g):
        """Precondition a fully connected layer's gradient in the THOR graph path."""
        temp_a = matrix_a_allreduce[thor_layer_count]
        temp_g = matrix_g_allreduce[thor_layer_count]
        self.assign(self.matrix_a_cov[thor_layer_count], temp_a)
        self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
        temp_a = self.cast(temp_a, mstype.float16)
        temp_g = self.cast(temp_g, mstype.float16)
        g = self.cast(g, mstype.float16)
        g = self.matmul(temp_g, g)
        g = self.matmul(g, temp_a)
        g = self.cast(g, mstype.float32)
        return g

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """Compute the second-order gradients for every parameter when conv layers are present."""
        for i in range(params_len):
            g = gradients[i]
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
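            # The hard-coded 1001 presumably matches the ImageNet classifier head (1000 classes plus
            # background) of the reference scripts; only that shape takes the dedicated cube matmul
            # kernels, while other FC shapes fall back to plain float16 matmuls.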
            if layer_type == FC:
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            new_grads = new_grads + (g,)
        return new_grads

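    # The non-conv path below preconditions each gradient directly from the cached covariance
    # inverses: FC layers apply G^-1 g A^-1, embeddings scale rows by the diagonal A factor and
    # right-multiply by G^-1, and LayerNorm parameters use the diagonal Fisher fallback above.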
    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """Compute the second-order (preconditioned) gradients for THOR."""
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type == Embedding:
                    temp_a_ori = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.expand(temp_a_ori, 1)
                    g = self.mul(temp_a, g)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(g, temp_g)
                    g = self.cast(g, mstype.float32)
                elif layer_type == FC:
                    temp_a = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        return new_grads

    def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
        """Compute the second-order gradient of a conv or FC layer by matmul."""
        conv_layer_count = self.weight_conv_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        grad_shape = self.shape(g)
        if layer_type == FC:
            if grad_shape[0] == 1001:
                g = self.cube_matmul_left_fc(temp_g, g)
                g = self.cube_matmul_right_fc(g, temp_a, temp_max)
            else:
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
        elif layer_type == Conv:
            a_normalizer = self.a_normalizer[conv_layer_count]
            a_normalizer = F.depend(a_normalizer, g)
            temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
            matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
            if matmul_support_flag == 1:
                g = self.cube_matmul_left(temp_g, g)
                g = self.cube_matmul_right_mul(g, temp_a, temp_max)
            else:
                g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
                g = self.reshape(g, grad_shape)
        return g, temp_max

    def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
        """Compute the second-order gradient of an embedding, FC or LayerNorm layer by layer type."""
        thor_layer_count = self.weight_fim_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        if layer_type == Embedding:
            temp_a_ori = matrix_a_allreduce[thor_layer_count]
            temp_g = matrix_g_allreduce[thor_layer_count]
            self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
            self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
            temp_a = self.expand(temp_a_ori, 1)
            g = self.mul(temp_a, g)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(g, temp_g)
            g = self.cast(g, mstype.float32)
        elif layer_type == FC:
            g = self._process_thor_fc(thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g)
        elif layer_type == LayerNorm:
            g = self._process_layernorm(damping_step, g)
        return g

    def construct(self, gradients):
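        """Apply one THOR step: when `self.thor` is set, refresh the Kronecker-factored inverses
        from the current gradients, precondition the gradients with them, then apply weight decay,
        optional gradient clipping and the momentum update."""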
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            matrix_a_allreduce = ()
            matrix_g_allreduce = ()
            matrix_a_max_allreduce = ()
            matrix_g_max_allreduce = ()
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                                   matrix_a_max_allreduce, matrix_g_max_allreduce)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
                if self.conv_layer_count > 0:
                    matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                    matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

            new_grads = ()
            if self.conv_layer_count > 0:
                for i in range(len(self.params)):
                    g = gradients[i]
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
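                    # Normalize each inverse by its max absolute value: exp(-log(max)) is 1 / max
                    # written with graph-friendly primitives; the product of the two max values is
                    # folded back in later through matrix_max_inv.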
                    matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    temp_max = self.mul(matrix_a_max_allreduce[thor_layer_count],
                                        matrix_g_max_allreduce[thor_layer_count])
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                    new_grads = new_grads + (g,)
                gradients = new_grads
            else:
                for i in range(len(self.params)):
                    g = gradients[i]
                    g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                    new_grads = new_grads + (g,)
                gradients = new_grads
        else:
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        self.cov_step = self.cov_step + self.one
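        # Shared tail for both paths: weight decay on the flagged parameters, optional gradient
        # clipping, then the ApplyMomentum update of every parameter.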
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
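# A minimal usage sketch (illustrative only, not part of this module). It assumes the high-level
# `mindspore.nn.thor` helper wraps this optimizer and that a suitable network constructor exists;
# the hyper-parameter names and values below are placeholders, not recommendations.
#
#     import mindspore.nn as nn
#     net = resnet50(num_classes=1001)        # hypothetical model constructor
#     opt = nn.thor(net, learning_rate=0.1, damping=0.03, momentum=0.9,
#                   loss_scale=128, batch_size=32, frequency=100)
#     # Convert the network with ConvertNetUtils/ConvertModelUtils and train through Model as usual.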