# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""

import numpy as np

from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import LARS
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F

from .base import ParameterProcess
from .grad_accumulation import GradientAccumulation

__all__ = ['GradientFreeze', 'FreezeOpt', 'freeze_cell']


# Identifiers for the two freezing-schedule generators in
# GradientFreeze.generate_freeze_index_sequence.
CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Wraps a regular optimizer and builds one sub-optimizer per parameter
    group so that, at each step, only the selected group's parameters are
    updated (the rest are "frozen").

    Args:
        opt (Cell): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(
                f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        # LARS is a wrapper optimizer: clone from the inner optimizer's class
        # and keep LARS' own init args so _generate_new_optimizer can re-wrap.
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
            self.parameters = opt.opt.parameters
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
            self.parameters = opt.parameters
        self.opts = []

        if train_parameter_groups is None:
            # Default split: group i trains parameters[i * step:], i.e. each
            # successive group freezes `step` more leading parameters.
            # NOTE(review): with fewer than groups_num * step parameters some
            # trailing groups are empty — presumably callers pass explicit
            # groups in that case; verify against callers.
            groups_num = 10
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [self._generate_new_optimizer(
                params) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError(
                    "The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")

                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(self._generate_new_optimizer(params))

        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError(
                        "The element in train_strategy should be int number")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError(
                "The specified 'train_strategy' should be None, tuple, list or Tensor")

    def _generate_new_optimizer(self, params):
        """Generate a new optimizer instance of the wrapped class over `params`,
        re-applying the LARS wrapper when the original optimizer used one."""
        if not self.is_lars:
            opt = self.opt_class(params=params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=params, **self.opt_init_args),
                       **self.lars_init_args)
        return opt


class _TrainFreezeCell(Cell):
    """
    Gradient freezing training network.

    One instance trains a single parameter group with its own optimizer
    (and optional gradient accumulation / distributed reduction).

    Args:
        net (Cell): The training network.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
        grad (tuple(Tensor)): The gradients of network parameters and inputs.
        grad_reducer (Cell): Constructs a gradient reducer Cell, which applies communication and average operations on
            single-process gradient values.
        use_grad_accumulation (bool): Whether use grad accumulation.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        max_accumulation_step (numbers.Number): Max grad accumulation steps. Default: 1

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, optimizer, max_accumulation_step=1):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        self.max_accumulation_step = max_accumulation_step
        if use_grad_accumulation:
            # Fix: the optimizer attribute is `self.opt` (assigned above);
            # the previous `self.optimizer` raised AttributeError here.
            self.grad_accumulation = GradientAccumulation(
                self.max_accumulation_step, self.opt)

    def construct(self, *inputs):
        """Forward, backward (scaled by `sens`), reduce, then either
        accumulate gradients or apply the optimizer; returns the loss."""
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            # depend() keeps the optimizer update in the graph's execution order.
            loss = F.depend(loss, self.opt(grads))
        return loss


class GradientFreeze:
    """
    Freezing the gradients of some layers randomly. The number and
    probability of frozen layers can be configured by users

    Args:
        param_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
        freeze_type (int): Strategy of gradients freezing training.
        freeze_p (float): probability of gradients freezing training.
        total_steps (numbers.Number): Steps of the whole training.

    Examples:
        >>> gradient_freeze_class = acc.GradientFreeze(10, 1, 0.5, 2000)
        >>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)
    """
    def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
        self._param_groups = param_groups
        self._freeze_type = freeze_type
        self._freeze_p = freeze_p
        self._total_steps = total_steps
        self.grad_reducer = F.identity
        self._param_processer = ParameterProcess()

    def split_parameters_groups(self, net, freeze_para_groups_number):
        """Split parameter groups for gradients freezing training.

        Trainable parameters are first clustered so that 'bn'/'bias'
        parameters stay with the preceding 'conv'-like parameter, then the
        clusters are merged into `freeze_para_groups_number` nested suffix
        groups (group i contains clusters i*stride onward).
        """
        grouped_params = []
        tmp = []
        for para in net.trainable_params():
            name = para.name
            # ensure 'bn' after 'conv' is not split
            if 'bn' in name or 'bias' in name:
                tmp.append(para)
            elif len(tmp) >= 3:
                grouped_params.append(tmp)
                tmp = [para]
            else:
                tmp.append(para)
        if tmp:
            grouped_params.append(tmp)
        # NOTE(review): stride is 0 when there are fewer clusters than
        # requested groups, making every group identical — confirm intended.
        stride = len(grouped_params) // freeze_para_groups_number
        freeze_grouped_params = [sum(grouped_params[i * stride:], [])
                                 for i in range(freeze_para_groups_number)]
        return freeze_grouped_params

    def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
        """Generate index sequence for gradient freezing training.

        Returns a list (length >= total_steps) of group indices; index 0
        means "train everything", higher indices select a freeze group.
        """
        # 1% head-room so the schedule never runs out before training ends.
        total_step = int(total_steps * 1.01)
        if parameter_groups_number <= 1:
            return [0 for _ in range(total_step)]
        # local continuous freezing training strategy, as '00001234'
        if freeze_strategy == CONTINUOUS_STRATEGY:
            # zero_cnt chosen so the fraction of non-zero entries ~= freeze_p.
            zero_cnt = int(
                freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
            sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
            freeze_idxes = []
            while len(freeze_idxes) < total_step:
                freeze_idxes += sub_idx
            return freeze_idxes
        # interval freezing training strategy, as '01020304'
        if freeze_strategy == INTERVAL_STRATEGY:
            index_all = list(range(1, parameter_groups_number))
            prob = [x / sum(index_all) for x in index_all]
            freeze_idxes = [0]
            zero_cnt = 1
            freeze_cnt = 0
            while len(freeze_idxes) < total_step:
                # Keep the running frozen fraction near (1 - freeze_p).
                freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
                if freeze_p_cur < 1 - freeze_p:
                    # NOTE(review): `prob` is aligned with index_all but the
                    # candidates are reversed, so smaller indices get higher
                    # weight — presumably deliberate; confirm.
                    freeze_idxes.append(
                        int(np.random.choice(index_all[::-1], p=prob)))
                    freeze_cnt += 1
                else:
                    freeze_idxes.append(0)
                    zero_cnt += 1
            return freeze_idxes
        raise ValueError(
            f"Unsupported freezing training strategy '{freeze_strategy}'")

    def freeze_generate(self, network, optimizer):
        """Generate freeze network and optimizer.

        Returns the (unchanged) network and a FreezeOpt wrapping `optimizer`
        with the split parameter groups and the generated freeze schedule.
        """
        train_para_groups = self.split_parameters_groups(
            network, self._param_groups)
        for i in range(self._param_groups):
            # assumes `optimizer` exposes init_params['params'] — TODO confirm
            # against the Optimizer base class used by callers.
            train_para_groups[i] = self._param_processer.generate_group_params(train_para_groups[i],
                                                                               optimizer.init_params['params'])
        train_strategy = self.generate_freeze_index_sequence(
            self._param_groups, self._freeze_type, self._freeze_p, self._total_steps)
        optimizer = FreezeOpt(optimizer, train_para_groups, train_strategy)

        return network, optimizer


def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                max_accumulation_step=1):
    """Provide freeze network cell.

    Builds one _TrainFreezeCell per sub-optimizer in `optimizer.opts`,
    attaching a DistributedGradReducer per group when `reducer_flag` is set,
    otherwise using identity reduction.
    """
    if reducer_flag:
        param_processer = ParameterProcess()
        grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),
                                                mean, degree) for opt in optimizer.opts)
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, reducer,
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for reducer, opt in zip(grad_reducers, optimizer.opts))
    else:
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, F.identity,
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for opt in optimizer.opts)
    return freeze_nets