# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gradient accumulation."""
from __future__ import absolute_import
from __future__ import division

from mindspore.nn.cell import Cell
from mindspore.common import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


__all__ = ["GradientAccumulation", "gradient_accumulation_op", "gradient_clear_op"]


gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")


@gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def cumulative_grad_process(accumulation_step, cumulative_grad, grad):
    """Add the current gradient, scaled by the accumulation step, into the cumulative grad."""
    P.AssignAdd()(cumulative_grad, grad / accumulation_step)
    return cumulative_grad


gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")


@gradient_clear_op.register("Tensor")
def clear_grad(cumulative_grad):
    """Reset the cumulative grad to zeros."""
    zero_grad = P.ZerosLike()(cumulative_grad)
    return F.assign(cumulative_grad, zero_grad)


class GradientAccumulation(Cell):
    """
    Accumulate the gradients of multiple steps, then call the optimizer to apply the accumulated update.

    Args:
        max_accumulation_step (int): Number of steps over which gradients are accumulated before an update.
        optimizer (Cell): Optimizer used to apply the accumulated gradients.
    """
    def __init__(self, max_accumulation_step, optimizer):
        super(GradientAccumulation, self).__init__()
        self._max_accumulation_step = max_accumulation_step
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.hyper_map = C.HyperMap()
        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        # Add the current (scaled) gradients into the accumulation buffers.
        loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

        if self._accumulation_step >= self._max_accumulation_step:
            # Apply the accumulated gradients and reset the step counter.
            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
            F.assign(self._accumulation_step, 0)

        if self._accumulation_step == 0:
            # Clear the accumulation buffers right after the optimizer update.
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))

        return loss
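

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# training wrapper that computes per-mini-batch gradients and feeds them to
# ``GradientAccumulation``. The class name ``_TrainStepWithGradAccumulation``
# and its structure are assumptions for demonstration, not a public API.
# ----------------------------------------------------------------------------
class _TrainStepWithGradAccumulation(Cell):
    """Compute gradients for each mini-batch and accumulate them before the optimizer update."""
    def __init__(self, network, optimizer, max_accumulation_step):
        super(_TrainStepWithGradAccumulation, self).__init__()
        self.network = network
        self.weights = optimizer.parameters
        # Gradients are taken with respect to the optimizer's parameter list.
        self.grad = C.GradOperation(get_by_list=True)
        self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        grads = self.grad(self.network, self.weights)(*inputs)
        # The optimizer is applied once every ``max_accumulation_step`` calls.
        return self.grad_accumulation(loss, grads)


# Example wiring (assumed names, shown only for illustration; requires
# ``import mindspore.nn as nn``): wrap a loss-bearing network and an optimizer,
# then call the wrapper once per mini-batch.
#
#     net_with_loss = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits())
#     optimizer = nn.Momentum(net.trainable_params(), 0.01, 0.9)
#     train_step = _TrainStepWithGradAccumulation(net_with_loss, optimizer, 4)
#     for data, label in dataset:
#         loss = train_step(data, label)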