# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""grad accumulation"""
16from __future__ import absolute_import
17from __future__ import division
18
19from mindspore.nn.cell import Cell
20from mindspore.common import Parameter, Tensor
21from mindspore.common import dtype as mstype
22from mindspore.ops import composite as C
23from mindspore.ops import functional as F
24from mindspore.ops import operations as P
25
26
27__all__ = ["GradientAccumulation", "gradient_accumulation_op", "gradient_clear_op"]
28
29
30gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
31
32
33@gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
34def cumulative_grad_process(accumulation_step, cumulative_grad, grad):
35    """Apply gradient accumulation to cumulative grad."""
36    P.AssignAdd()(cumulative_grad, grad / accumulation_step)
37    return cumulative_grad
38
39
40gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
41
42
43@gradient_clear_op.register("Tensor")
44def clear_grad(cumulative_grad):
45    """Clear grad."""
46    zero_grad = P.ZerosLike()(cumulative_grad)
47    return F.assign(cumulative_grad, zero_grad)
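# Illustrative sketch (not part of the original module): the per-parameter effect of the
# two ops above, run eagerly in PyNative mode.  The array values and the direct calls to
# the registered functions are assumptions for demonstration only; inside the cell below
# the ops are dispatched through HyperMap instead.
#
#     import numpy as np
#     from mindspore import Tensor, Parameter, context
#     from mindspore.common import dtype as mstype
#
#     context.set_context(mode=context.PYNATIVE_MODE)
#     acc = Parameter(Tensor(np.zeros(2), mstype.float32), name="acc")
#     grad = Tensor(np.array([4.0, 8.0]), mstype.float32)
#     cumulative_grad_process(4, acc, grad)   # acc -> grad / 4 == [1.0, 2.0]
#     clear_grad(acc)                         # acc -> [0.0, 0.0]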


class GradientAccumulation(Cell):
    """
    Accumulate gradients over multiple steps, then call the optimizer to apply the
    accumulated update.

    Args:
        max_accumulation_step (int): Number of steps over which gradients are accumulated.
        optimizer (Cell): Optimizer used to apply the accumulated gradients.
    """
    def __init__(self, max_accumulation_step, optimizer):
        super(GradientAccumulation, self).__init__()
        self._max_accumulation_step = max_accumulation_step
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.hyper_map = C.HyperMap()
        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        # Add grad / max_accumulation_step to each accumulator so that, after
        # max_accumulation_step steps, it holds the mean gradient of those steps.
        loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

        if self._accumulation_step >= self._max_accumulation_step:
            # Apply the accumulated gradients and restart the step counter.
            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
            F.assign(self._accumulation_step, 0)

        if self._accumulation_step == 0:
            # Reset the accumulators to zeros after the optimizer has run.
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))

        return loss
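
# Usage sketch (an assumption, not part of the original module): GradientAccumulation is
# meant to be driven by a training wrapper that computes the loss and the gradients and
# passes both to its ``construct``.  The wrapper cell below and its use of
# ``ops.GradOperation`` are illustrative only.
#
#     from mindspore import nn, ops
#
#     class TrainStepWithAccumulation(nn.Cell):
#         def __init__(self, network, optimizer, max_accumulation_step=4):
#             super(TrainStepWithAccumulation, self).__init__()
#             self.network = network
#             self.weights = optimizer.parameters
#             self.grad = ops.GradOperation(get_by_list=True)
#             self.accumulation = GradientAccumulation(max_accumulation_step, optimizer)
#
#         def construct(self, *inputs):
#             loss = self.network(*inputs)
#             grads = self.grad(self.network, self.weights)(*inputs)
#             # The optimizer only applies an update every ``max_accumulation_step`` calls.
#             return self.accumulation(loss, grads)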