# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""
from __future__ import absolute_import
from __future__ import division

import numpy as np
from mindspore import nn
from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import LARS
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F

from mindspore.boost.base import ParameterProcess
from mindspore.boost.grad_accumulation import GradientAccumulation

__all__ = ['GradientFreeze', 'FreezeOpt', 'freeze_cell']


CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1
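# Freeze-strategy identifiers used by GradientFreeze.generate_freeze_index_sequence:
# CONTINUOUS_STRATEGY repeats a fixed block of group indices (e.g. '00001234'), while
# INTERVAL_STRATEGY interleaves index 0 with randomly chosen groups (e.g. '01020304').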


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Args:
        opt (Cell): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
            Default: ``None`` .
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.
            Default: ``None`` .

    Supported Platforms:
        ``Ascend``
    """
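    # Minimal usage sketch (illustrative only; `net` and the group split below are
    # hypothetical and not part of this module):
    #
    #     params = net.trainable_params()
    #     groups = [tuple(params), tuple(params[2:])]      # suffix parameter groups
    #     strategy = [0, 1, 0, 1]                          # group index to train at each step
    #     opt = FreezeOpt(nn.Momentum(params, learning_rate=0.1, momentum=0.9),
    #                     train_parameter_groups=groups, train_strategy=strategy)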
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(
                f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
            self.single_opt = opt.opt
            self.parameters = opt.opt.parameters
            self.learning_rate = opt.opt.init_learning_rate
            self.dynamic_lr = opt.opt.dynamic_lr
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
            self.single_opt = opt
            self.parameters = opt.parameters
            self.learning_rate = opt.init_learning_rate
            self.dynamic_lr = opt.dynamic_lr

        self.opts = []
        if train_parameter_groups is None:
            self.groups_num = 1
            step = 1
            parameters = opt.parameters
            train_parameter_groups = (tuple(parameters[(i * step):]) for i in range(self.groups_num))
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError(
                    "The specified 'train_parameter_groups' should be tuple or list")
            self.groups_num = len(train_parameter_groups)

        self._init_train_strategy(train_strategy)
        self._create_new_group_learning_rate()

        self.opt_index = 0
        for params in train_parameter_groups:
            if not isinstance(params, (tuple, list)):
                raise TypeError("Each element of 'train_parameter_groups' should be a tuple or list "
                                "of Parameter")
            # generate one-to-one opt corresponding to the parameter group
            self.opts.append(self._generate_new_optimizer(params))
            self.opt_index += 1

    def _init_train_strategy(self, train_strategy):
        """Init train strategy for gradient freeze."""
        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError(
                        "Each element in train_strategy should be an int")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError(
                "The specified 'train_strategy' should be None, tuple, list or Tensor")

    def _create_new_group_learning_rate(self):
        """Create new learning rate for different global step."""
        self.dynamic_learning_rate = [[] for _ in range(self.groups_num)]
        if self.learning_rate is None:
            self.learning_rate = self.single_opt.learning_rate
            return
        if self.dynamic_lr and isinstance(self.learning_rate, list) and isinstance(self.train_strategy, Tensor):
            train_strategy = list(self.train_strategy.asnumpy())
            if len(self.learning_rate) <= len(train_strategy):
                for i, lr in enumerate(self.learning_rate):
                    self.dynamic_learning_rate[train_strategy[i]].append(lr)
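        # Worked example (illustrative values): with a per-step learning-rate list
        # [0.1, 0.1, 0.05, 0.05] and train_strategy [0, 1, 0, 1], step i's learning
        # rate is routed to the group trained at step i, so group 0 collects
        # [0.1, 0.05] and group 1 collects [0.1, 0.05].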

    def _generate_new_optimizer(self, params):
        """Generate new optimizer."""
        if self.dynamic_learning_rate[self.opt_index]:
            lr = self.dynamic_learning_rate[self.opt_index]
        else:
            lr = self.learning_rate
        if not self.is_lars:
            opt = self.opt_class(params=params, learning_rate=lr, **self.opt_init_args)
            opt._update_local_parameters_name("boost_{}".format(self.opt_index))  # pylint: disable=W0212
        else:
            opt = LARS(self.opt_class(params=params, learning_rate=lr, **self.opt_init_args),
                       **self.lars_init_args)
            opt.opt._update_local_parameters_name("boost_{}".format(self.opt_index))  # pylint: disable=W0212
            opt._update_local_parameters_name("boost_{}".format(self.opt_index))  # pylint: disable=W0212
        return opt
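    # Note: each generated per-group optimizer tags its locally created parameters
    # (optimizer state) with "boost_<group index>" via _update_local_parameters_name,
    # so the states of different groups do not clash with each other.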


class _TrainFreezeCell(Cell):
    r"""
    Gradient freezing training network.

    Args:
        net (Cell): The training network.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.
        grad (GradOperation): The gradient operation used to compute the gradients of the network parameters.
        grad_reducer (Cell): The gradient reducer cell, which applies communication and averaging operations
            on the single-process gradient values.
        use_grad_accumulation (bool): Whether to use gradient accumulation.
        optimizer (Cell): Optimizer for updating the weights.
        max_accumulation_step (int): Maximum number of gradient accumulation steps. Default: 1.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, optimizer, max_accumulation_step=1):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        self.max_accumulation_step = max_accumulation_step
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(
                self.max_accumulation_step, self.opt)

    def construct(self, *inputs):
        # forward pass to compute the loss
        loss = self.net(*inputs)
        # sensitivity tensor (initial gradient of the loss) filled with the scaling value
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        # backward pass over this group's parameters, then reduce (identity on a single device)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        # either accumulate the gradients or apply the optimizer update immediately
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class GradientFreeze:
    r"""
    Gradients freezing algorithm that randomly freezes the gradients of some layers
    to improve network training performance. The number and probability of frozen
    layers can be configured by users.

    Args:
        param_groups (int): The number of parameter groups for gradients freezing training.
        freeze_type (int): Strategy of gradients freezing training; ``0`` for the continuous
            strategy and ``1`` for the interval strategy.
        freeze_p (float): Probability of gradients freezing training.
        total_steps (int): Total number of training steps.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> from mindspore import ops
        >>> from mindspore.nn import WithLossCell
        >>> from mindspore import dtype as mstype
        >>> from mindspore import boost
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>> size, in_features, out_features = 16, 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = WithLossCell(net, loss)
        >>> gradient_freeze_class = boost.GradientFreeze(10, 1, 0.5, 2000)
        >>> network, optimizer = gradient_freeze_class.freeze_generate(net_with_loss, optimizer)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> output = network(inputs, label)
    """
    def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
        self._param_groups = param_groups
        self._freeze_type = freeze_type
        self._freeze_p = freeze_p
        self._total_steps = total_steps
        self.grad_reducer = nn.Identity()

    def split_parameters_groups(self, net, freeze_para_groups_number):
        r"""
        Split parameter groups for gradients freezing training.

        Args:
            net (Cell): The training network.
            freeze_para_groups_number (int): The number of gradient freeze groups.
        """
        grouped_params = []
        tmp = []
        for para in net.trainable_params():
            name = para.name
            # ensure 'bn' after 'conv' is not split
            if 'bn' in name or 'bias' in name:
                tmp.append(para)
            elif len(tmp) >= 3:
                grouped_params.append(tmp)
                tmp = [para]
            else:
                tmp.append(para)
        if tmp:
            grouped_params.append(tmp)
        stride = len(grouped_params) // freeze_para_groups_number
        freeze_grouped_params = [sum(grouped_params[i * stride:], [])
                                 for i in range(freeze_para_groups_number)]
        return freeze_grouped_params
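        # Worked example (hypothetical parameter blocks b0..b3): with 4 blocks and
        # freeze_para_groups_number = 2, stride is 2, so group 0 contains b0+b1+b2+b3
        # (all parameters) and group 1 contains b2+b3; each freeze group therefore
        # trains a suffix of the network's parameters.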

    def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
        r"""
        Generate index sequence for gradient freezing training.

        Args:
            parameter_groups_number (int): The number of parameter groups.
            freeze_strategy (int): Gradient freeze grouping strategy, select from [0, 1].
            freeze_p (float): Gradient freezing probability.
            total_steps (int): Total training steps.
        """
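        # Worked example (illustrative numbers): with parameter_groups_number = 5 and
        # freeze_p = 0.5, the continuous strategy (0) computes zero_cnt = 4 and repeats
        # the block [0, 0, 0, 0, 1, 2, 3, 4]; the interval strategy (1) roughly
        # alternates 0 with a randomly drawn non-zero group index, e.g. 0, 3, 0, 1, ...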
        total_step = int(total_steps * 1.01)
        if parameter_groups_number <= 1:
            return [0 for _ in range(total_step)]
        # local continuous freezing training strategy, as '00001234'
        if freeze_strategy == CONTINUOUS_STRATEGY:
            zero_cnt = int(
                freeze_p * (parameter_groups_number - 1) // (1 - freeze_p) + 0.5)
            sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
            freeze_idxes = []
            while len(freeze_idxes) < total_step:
                freeze_idxes += sub_idx
            return freeze_idxes
        # interval freezing training strategy, as '01020304'
        if freeze_strategy == INTERVAL_STRATEGY:
            index_all = list(range(1, parameter_groups_number))
            prob = [x / sum(index_all) for x in index_all]
            freeze_idxes = [0]
            zero_cnt = 1
            freeze_cnt = 0
            while len(freeze_idxes) < total_step:
                freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
                if freeze_p_cur < 1 - freeze_p:
                    freeze_idxes.append(
                        int(np.random.choice(index_all[::-1], p=prob)))
                    freeze_cnt += 1
                else:
                    freeze_idxes.append(0)
                    zero_cnt += 1
            return freeze_idxes
        raise ValueError(
            f"Unsupported freezing training strategy '{freeze_strategy}'")

    def freeze_generate(self, network, optimizer):
        r"""
        Generate freeze network and optimizer.

        Args:
            network (Cell): The training network.
            optimizer (Cell): Optimizer for updating the weights.
        """
        train_para_groups = self.split_parameters_groups(
            network, self._param_groups)
        for i in range(self._param_groups):
            train_para_groups[i] = ParameterProcess.generate_group_params(train_para_groups[i],
                                                                          optimizer.init_params['params'])
        train_strategy = self.generate_freeze_index_sequence(
            self._param_groups, self._freeze_type, self._freeze_p, self._total_steps)
        optimizer = FreezeOpt(optimizer, train_para_groups, train_strategy)

        return network, optimizer


def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                max_accumulation_step=1):
    r"""
    Generate freeze network and optimizer.

    Args:
        reducer_flag (bool): Whether to apply a distributed gradient reducer.
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (numbers.Number): The scaling number.
        grad (GradOperation): The gradient operation used to compute the gradients.
        use_grad_accumulation (bool): Whether to use gradient accumulation.
        mean (bool): Whether to average gradients in the reducer. Default: ``None`` .
        degree (int): Device number. Default: ``None`` .
        max_accumulation_step (int): Max accumulation steps. Default: ``1`` .

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> from mindspore import ops
        >>> from mindspore.boost.grad_freeze import freeze_cell, FreezeOpt
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> network = Net(in_features, out_features)
        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> optimizer = FreezeOpt(optimizer)
        >>> grad = ops.GradOperation(get_by_list=True, sens_param=True)
        >>> freeze_nets = freeze_cell(False, network, optimizer, 1.0, grad, False, None, None, 1)
    """
    if reducer_flag:
        param_processer = ParameterProcess()
        grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),
                                                mean, degree) for opt in optimizer.opts)
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, reducer,
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for reducer, opt in zip(grad_reducers, optimizer.opts))
    else:
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, nn.Identity(),
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for opt in optimizer.opts)
    return freeze_nets
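# Usage note (a sketch of the intended call pattern, inferred from this module rather than
# guaranteed by it): freeze_cell returns one _TrainFreezeCell per parameter group; a training
# wrapper would then run freeze_nets[i] at each step, with i taken from the train_strategy
# sequence produced by GradientFreeze.generate_freeze_index_sequence.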