# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
import copy
import mindspore.nn as nn
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter
from .less_batch_normalization import CommonHeadLastFN


__all__ = ["OptimizerProcess", "ParameterProcess"]


class OptimizerProcess:
    """
    Process optimizer for Boost. Currently, this class supports adding GC (grad centralization) tags
    and creating new optimizers.

    Args:
        opt (Cell): Optimizer used.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import OptimizerProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> optimizer_process = OptimizerProcess(optimizer)
        >>> optimizer_process.add_grad_centralization(network)
        >>> optimizer = optimizer_process.generate_new_optimizer()
    """
    def __init__(self, opt):
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
        self.origin_params = opt.init_params["params"]

    def build_params_dict(self, network):
        """Build a dict mapping each parameter id to the cell that owns it."""
        cells = network.cells_and_names()
        params_dict = {}
        for _, cell in cells:
            for par in cell.get_parameters(expand=False):
                params_dict[id(par)] = cell
        return params_dict

    def build_gc_params_group(self, params_dict, parameters):
        """Build the params group that needs gradient centralization."""
        group_params = []
        for group_param in parameters:
            if 'order_params' in group_param.keys():
                group_params.append(group_param)
                continue
            params_gc_value = []
            params_value = []
            for param in group_param['params']:
                # Bias and batch-norm (beta/gamma) parameters are excluded from gradient centralization.
                if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                    param_cell = params_dict[id(param)]
                    # Grouped convolutions and CommonHeadLastFN layers are also excluded.
                    if (isinstance(param_cell, nn.Conv2d) and param_cell.group > 1) or \
                            isinstance(param_cell, CommonHeadLastFN):
                        params_value.append(param)
                    else:
                        params_gc_value.append(param)
                else:
                    params_value.append(param)
            if params_gc_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_gc_value
                new_group_param['grad_centralization'] = True
                group_params.append(new_group_param)
            if params_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_value
                group_params.append(new_group_param)
        return group_params

    def add_grad_centralization(self, network):
        """Add gradient centralization to the optimizer's parameter groups."""
        params_dict = self.build_params_dict(network)

        parameters = self.origin_params
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(parameters[0], Parameter):
            logger.warning("Only group parameters support gradient centralization.")
            return

        self.origin_params = self.build_gc_params_group(params_dict, parameters)

    def generate_new_optimizer(self):
        """Generate a new optimizer from the (possibly regrouped) parameters."""
        if not self.is_lars:
            opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args), **self.lars_init_args)

        return opt


class ParameterProcess:
    """
    Process parameters for Boost. Currently, this class supports creating group parameters
    and automatically setting gradient segmentation points.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import ParameterProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight2')
        ...         self.matmul = ops.MatMul()
        ...         self.matmul2 = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         output2 = self.matmul2(x, self.weight2)
        ...         return output + output2
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> new_parameter = network.trainable_params()[:1]
        >>> parameter_process = ParameterProcess()
        >>> group_params = parameter_process.generate_group_params(new_parameter, network.trainable_params())
    """
    def __init__(self):
        self._parameter_indices = 1

    def assign_parameter_group(self, parameters, split_point=None):
        """Assign the communication fusion index of each parameter."""
        if not isinstance(parameters, (list, tuple)) or not parameters:
            return parameters

        parameter_len = len(parameters)
        if split_point:
            split_parameter_index = split_point
        else:
            split_parameter_index = [parameter_len // 2]
        for i in range(parameter_len):
            # Bump the fusion index at every split point so gradients are fused in separate groups.
            if i in split_parameter_index:
                self._parameter_indices += 1
            parameters[i].comm_fusion = self._parameter_indices
        return parameters

    def generate_group_params(self, parameters, origin_params):
        """Generate group parameters."""
        origin_params_copy = origin_params
        if origin_params_copy is not None:
            if not isinstance(origin_params_copy, list):
                origin_params_copy = list(origin_params_copy)

        if not origin_params_copy:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(origin_params_copy[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(origin_params_copy[0], Parameter):
            group_params = [{"params": parameters}]
            return group_params

        group_params = []
        params_name = [param.name for param in parameters]
        new_params_count = copy.deepcopy(params_name)
        new_params_clone = {}
        max_key_number = 0
        for group_param in origin_params_copy:
            if 'order_params' in group_param.keys():
                new_group_param = copy.deepcopy(group_param)
                new_group_param['order_params'] = parameters
                group_params.append(new_group_param)
                continue
            params_value = []
            for param in group_param['params']:
                if param.name in params_name:
                    index = params_name.index(param.name)
                    params_value.append(parameters[index])
                    new_params_count.remove(param.name)
            new_group_param = copy.deepcopy(group_param)
            new_group_param['params'] = params_value
            group_params.append(new_group_param)
            # Remember the group with the most keys as a template for leftover parameters.
            if len(group_param.keys()) > max_key_number:
                max_key_number = len(group_param.keys())
                new_params_clone = copy.deepcopy(group_param)
        if new_params_count:
            # Parameters not covered by any original group are collected into one extra group.
            params_value = []
            for param in new_params_count:
                index = params_name.index(param)
                params_value.append(parameters[index])
            if new_params_clone:
                new_params_clone['params'] = params_value
                group_params.append(new_params_clone)
            else:
                group_params.append({"params": params_value})
        return group_params