1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""grad freeze"""
16
17import numpy as np
18
19from mindspore.nn.cell import Cell
20from mindspore.nn.optim import Optimizer
21from mindspore.common import Tensor
22from mindspore.common import dtype as mstype
23from mindspore.nn.optim import LARS
24from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
25from mindspore.ops import functional as F
26
27from .base import ParameterProcess
28from .grad_accumulation import GradientAccumulation
29
30__all__ = ['GradientFreeze', 'FreezeOpt', 'freeze_cell']
31
32
33CONTINUOUS_STRATEGY = 0
34INTERVAL_STRATEGY = 1
35
36
class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Wraps a regular optimizer and clones it once per trainable parameter
    group, so that a freezing schedule can select which clone runs each step.

    Args:
        opt (Cell): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(
                f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")

        # Record how to rebuild the wrapped optimizer (and the LARS wrapper,
        # when present) so one clone per parameter group can be created below.
        self.is_lars = isinstance(opt, LARS)
        if self.is_lars:
            inner_opt = opt.opt
            self.opt_class = type(inner_opt)
            self.opt_init_args = inner_opt.init_args
            self.lars_init_args = opt.init_args
            self.parameters = inner_opt.parameters
        else:
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
            self.parameters = opt.parameters

        self.opts = []
        if train_parameter_groups is None:
            # Default split: 10 nested groups, group i drops the first 6*i parameters.
            groups_num = 10
            step = 6
            parameters = opt.parameters
            self.opts = [self._generate_new_optimizer(parameters[(idx * step):])
                         for idx in range(groups_num)]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError(
                    "The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")
                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(self._generate_new_optimizer(params))

        self.train_strategy = self._normalize_train_strategy(train_strategy)

    @staticmethod
    def _normalize_train_strategy(train_strategy):
        """Validate ``train_strategy`` and return it as an int32 Tensor (or None)."""
        if train_strategy is None:
            return None
        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError(
                        "The element in train_strategy should be int number")
            return Tensor(train_strategy, mstype.int32)
        if isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            return train_strategy
        raise TypeError(
            "The specified 'train_strategy' should be None, tuple, list or Tensor")

    def _generate_new_optimizer(self, params):
        """Build a fresh optimizer (re-wrapped in LARS when needed) for one group."""
        if self.is_lars:
            return LARS(self.opt_class(params=params, **self.opt_init_args),
                        **self.lars_init_args)
        return self.opt_class(params=params, **self.opt_init_args)
114
115
class _TrainFreezeCell(Cell):
    """
    Gradient freezing training network.

    Args:
        net (Cell): The training network.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
        grad (tuple(Tensor)): The gradients of network parameters and inputs.
        grad_reducer (Cell): Constructs a gradient reducer Cell, which applies communication and average operations on
    single-process gradient values.
        use_grad_accumulation (bool): Whether use grad accumulation.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        max_accumulation_step (numbers.Number): Max grad accumulation steps. Default: 1.0

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, optimizer, max_accumulation_step=1):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        self.max_accumulation_step = max_accumulation_step
        if use_grad_accumulation:
            # Bug fix: the optimizer is stored as `self.opt` above; the previous
            # reference to the non-existent `self.optimizer` raised AttributeError
            # whenever gradient accumulation was enabled.
            self.grad_accumulation = GradientAccumulation(
                self.max_accumulation_step, self.opt)

    def construct(self, *inputs):
        # Forward pass, then backprop scaled by `sens`; reduce gradients across
        # devices (identity in standalone mode) before applying them.
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            # F.depend ensures the optimizer update is executed before loss is returned.
            loss = F.depend(loss, self.opt(grads))
        return loss
157
158
class GradientFreeze:
    """
    Freezing the gradients of some layers randomly. The number and
    probability of frozen layers can be configured by users

    Args:
        param_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
        freeze_type (int): Strategy of gradients freezing training.
        freeze_p (float): probability of gradients freezing training.
        total_steps (numbers.Number): Steps of the whole training.

    Examples:
        >>> gradient_freeze_class = acc.GradientFreeze(10, 1, 0.5, 2000)
        >>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)
    """
    def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
        self._param_groups = param_groups
        self._freeze_type = freeze_type
        self._freeze_p = freeze_p
        self._total_steps = total_steps
        self.grad_reducer = F.identity
        self._param_processer = ParameterProcess()

    def split_parameters_groups(self, net, freeze_para_groups_number):
        """Split parameter groups for gradients freezing training."""
        grouped_params = []
        current = []
        for param in net.trainable_params():
            # 'bn'/'bias' parameters stay attached to the preceding 'conv' group;
            # a new group only starts on another kind of parameter once the
            # current group holds at least 3 entries.
            attached = 'bn' in param.name or 'bias' in param.name
            if not attached and len(current) >= 3:
                grouped_params.append(current)
                current = [param]
            else:
                current.append(param)
        if current:
            grouped_params.append(current)
        stride = len(grouped_params) // freeze_para_groups_number
        # Group i is the (flattened) suffix starting at i * stride — nested groups.
        return [sum(grouped_params[i * stride:], [])
                for i in range(freeze_para_groups_number)]

    def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
        """Generate index sequence for gradient freezing training."""
        # Produce ~1% more indexes than steps so the schedule never runs short.
        seq_len = int(total_steps * 1.01)
        if parameter_groups_number <= 1:
            return [0] * seq_len
        # local continuous freezing training strategy, as '00001234'
        if freeze_strategy == CONTINUOUS_STRATEGY:
            zero_cnt = int(
                freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
            pattern = [0] * zero_cnt + list(range(1, parameter_groups_number))
            freeze_idxes = []
            while len(freeze_idxes) < seq_len:
                freeze_idxes += pattern
            return freeze_idxes
        # interval freezing training strategy, as '01020304'
        if freeze_strategy == INTERVAL_STRATEGY:
            index_all = list(range(1, parameter_groups_number))
            prob = [x / sum(index_all) for x in index_all]
            freeze_idxes = [0]
            zero_cnt, freeze_cnt = 1, 0
            while len(freeze_idxes) < seq_len:
                if 1.0 * freeze_cnt / (zero_cnt + freeze_cnt) < 1 - freeze_p:
                    # larger group indexes are drawn with lower probability
                    freeze_idxes.append(
                        int(np.random.choice(index_all[::-1], p=prob)))
                    freeze_cnt += 1
                else:
                    freeze_idxes.append(0)
                    zero_cnt += 1
            return freeze_idxes
        raise ValueError(
            f"Unsupported freezing training strategy '{freeze_strategy}'")

    def freeze_generate(self, network, optimizer):
        """Generate freeze network and optimizer."""
        raw_groups = self.split_parameters_groups(network, self._param_groups)
        original_params = optimizer.init_params['params']
        train_para_groups = [
            self._param_processer.generate_group_params(group, original_params)
            for group in raw_groups
        ]
        train_strategy = self.generate_freeze_index_sequence(
            self._param_groups, self._freeze_type, self._freeze_p, self._total_steps)
        return network, FreezeOpt(optimizer, train_para_groups, train_strategy)
249
250
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                max_accumulation_step=1):
    """Provide freeze network cell.

    Builds one `_TrainFreezeCell` per sub-optimizer in `optimizer.opts`.
    When `reducer_flag` is set, each cell gets its own distributed gradient
    reducer; otherwise gradients pass through `F.identity` unchanged.
    """
    if not reducer_flag:
        return tuple(
            _TrainFreezeCell(network, sens, grad, F.identity,
                             use_grad_accumulation, opt, max_accumulation_step)
            for opt in optimizer.opts)
    processer = ParameterProcess()
    reducers = [DistributedGradReducer(processer.assign_parameter_group(opt.parameters), mean, degree)
                for opt in optimizer.opts]
    return tuple(
        _TrainFreezeCell(network, sens, grad, reducer,
                         use_grad_accumulation, opt, max_accumulation_step)
        for opt, reducer in zip(optimizer.opts, reducers))
266