1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""base process""" 16from __future__ import absolute_import 17 18import os 19import time 20import math 21import copy 22import numpy as np 23from scipy import linalg as la 24from mindspore.context import ParallelMode 25import mindspore.nn as nn 26from mindspore.nn.optim import LARS 27from mindspore import log as logger 28from mindspore.common import Parameter 29from mindspore.communication.management import get_group_size 30from mindspore.train.serialization import load_checkpoint 31from mindspore.parallel._utils import _get_global_rank 32from mindspore.parallel._auto_parallel_context import auto_parallel_context 33from mindspore.boost.less_batch_normalization import CommonHeadLastFN 34 35 36__all__ = ["OptimizerProcess", "ParameterProcess"] 37 38 39class OptimizerProcess: 40 r""" 41 Process optimizer for Boost. Currently, this class supports adding GC(grad centralization) tags 42 and creating new optimizers. 43 44 Args: 45 opt (Cell): Optimizer used. 46 47 Examples: 48 >>> import numpy as np 49 >>> from mindspore import Tensor, Parameter, nn 50 >>> from mindspore import ops 51 >>> from mindspore.boost import OptimizerProcess 52 >>> 53 >>> class Net(nn.Cell): 54 ... def __init__(self, in_features, out_features): 55 ... super(Net, self).__init__() 56 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 57 ... name='weight') 58 ... self.matmul = ops.MatMul() 59 ... 60 ... def construct(self, x): 61 ... output = self.matmul(x, self.weight) 62 ... return output 63 ... 64 >>> size, in_features, out_features = 16, 16, 10 65 >>> network = Net(in_features, out_features) 66 >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) 67 >>> optimizer_process = OptimizerProcess(optimizer) 68 >>> optimizer_process.add_grad_centralization(network) 69 >>> optimizer = optimizer_process.generate_new_optimizer() 70 """ 71 def __init__(self, opt): 72 if isinstance(opt, LARS): 73 self.is_lars = True 74 self.single_opt = opt.opt 75 self.opt_class = type(opt.opt) 76 self.opt_init_args = opt.opt.init_args 77 self.lars_init_args = opt.init_args 78 self.learning_rate = opt.opt.init_learning_rate 79 else: 80 self.is_lars = False 81 self.single_opt = opt 82 self.opt_class = type(opt) 83 self.opt_init_args = opt.init_args 84 self.learning_rate = opt.init_learning_rate 85 self.origin_params = opt.init_params["params"] 86 87 @staticmethod 88 def build_params_dict(network): 89 r""" 90 Build the parameter's dict of the network. 91 92 Args: 93 network (Cell): The training network. 94 """ 95 cells = network.cells_and_names() 96 params_dict = {} 97 for _, cell in cells: 98 for par in cell.get_parameters(expand=False): 99 params_dict[id(par)] = cell 100 return params_dict 101 102 @staticmethod 103 def build_gc_params_group(params_dict, parameters): 104 r""" 105 Build the parameter's group with grad centralization. 106 107 Args: 108 params_dict (dict): The network's parameter dict. 109 parameters (list): The network's parameter list. 110 """ 111 group_params = [] 112 for group_param in parameters: 113 if 'order_params' in group_param.keys(): 114 group_params.append(group_param) 115 continue 116 params_gc_value = [] 117 params_value = [] 118 for param in group_param['params']: 119 if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: 120 param_cell = params_dict[id(param)] 121 if (isinstance(param_cell, nn.Conv2d) and param_cell.group > 1) or \ 122 isinstance(param_cell, CommonHeadLastFN): 123 params_value.append(param) 124 else: 125 params_gc_value.append(param) 126 else: 127 params_value.append(param) 128 if params_gc_value: 129 new_group_param = copy.deepcopy(group_param) 130 new_group_param['params'] = params_gc_value 131 new_group_param['grad_centralization'] = True 132 group_params.append(new_group_param) 133 if params_value: 134 new_group_param = copy.deepcopy(group_param) 135 new_group_param['params'] = params_value 136 group_params.append(new_group_param) 137 return group_params 138 139 def add_grad_centralization(self, network): 140 r""" 141 Add gradient centralization. 142 143 Args: 144 network (Cell): The training network. 145 """ 146 params_dict = OptimizerProcess.build_params_dict(network) 147 148 if self.origin_params is not None and not isinstance(self.origin_params, list): 149 parameters = list(self.origin_params) 150 else: 151 parameters = self.origin_params 152 153 if not parameters: 154 raise ValueError("Optimizer got an empty parameter list.") 155 156 if not isinstance(parameters[0], (dict, Parameter)): 157 raise TypeError("Only a list of Parameter or dict can be supported.") 158 159 if isinstance(parameters[0], Parameter): 160 logger.warning("Only group parameters support gradient centralization.") 161 return 162 163 self.origin_params = OptimizerProcess.build_gc_params_group(params_dict, parameters) 164 165 def generate_new_optimizer(self): 166 """Generate new optimizer.""" 167 if self.learning_rate is None: 168 self.learning_rate = self.single_opt.learning_rate 169 if not self.is_lars: 170 opt = self.opt_class(params=self.origin_params, learning_rate=self.learning_rate, **self.opt_init_args) 171 else: 172 opt = LARS(self.opt_class(params=self.origin_params, learning_rate=self.learning_rate, \ 173 **self.opt_init_args), **self.lars_init_args) 174 175 return opt 176 177 178class ParameterProcess: 179 r""" 180 Process parameter for Boost. Currently, this class supports creating group parameters 181 and automatically setting gradient segmentation point. 182 183 Examples: 184 >>> import numpy as np 185 >>> from mindspore import Tensor, Parameter, nn 186 >>> from mindspore import ops 187 >>> from mindspore.boost import ParameterProcess 188 >>> 189 >>> class Net(nn.Cell): 190 ... def __init__(self, in_features, out_features): 191 ... super(Net, self).__init__() 192 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 193 ... name='weight') 194 ... self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 195 ... name='weight2') 196 ... self.matmul = ops.MatMul() 197 ... self.matmul2 = ops.MatMul() 198 ... 199 ... def construct(self, x): 200 ... output = self.matmul(x, self.weight) 201 ... output2 = self.matmul2(x, self.weight2) 202 ... return output + output2 203 ... 204 >>> size, in_features, out_features = 16, 16, 10 205 >>> network = Net(in_features, out_features) 206 >>> new_parameter = network.trainable_params()[:1] 207 >>> group_params = ParameterProcess.generate_group_params(new_parameter, network.trainable_params()) 208 """ 209 def __init__(self): 210 self._parameter_indices = 1 211 212 @staticmethod 213 def generate_group_params(parameters, origin_params): 214 r""" 215 Generate group parameters. 216 217 Args: 218 parameters (list): The network's parameter list. 219 origin_params (list): The network's origin parameter list. 220 """ 221 origin_params_copy = origin_params 222 if origin_params_copy is not None: 223 if not isinstance(origin_params_copy, list): 224 origin_params_copy = list(origin_params_copy) 225 226 if not origin_params_copy: 227 raise ValueError("Optimizer got an empty parameter list.") 228 229 if not isinstance(origin_params_copy[0], (dict, Parameter)): 230 raise TypeError("Only a list of Parameter or dict can be supported.") 231 232 if isinstance(origin_params_copy[0], Parameter): 233 group_params = [{"params": parameters}] 234 return group_params 235 236 return ParameterProcess._generate_new_group_params(parameters, origin_params_copy) 237 238 @staticmethod 239 def _generate_new_group_params(parameters, origin_params_copy): 240 r""" 241 Generate new group parameters. 242 243 Args: 244 parameters (list): The network's parameter list. 245 origin_params_copy (list): Copy from origin parameter list. 246 """ 247 group_params = [] 248 params_name = [param.name for param in parameters] 249 new_params_count = copy.deepcopy(params_name) 250 new_params_clone = {} 251 max_key_number = 0 252 for group_param in origin_params_copy: 253 if 'order_params' in group_param.keys(): 254 new_group_param = copy.deepcopy(group_param) 255 new_group_param['order_params'] = parameters 256 group_params.append(new_group_param) 257 continue 258 params_value = [] 259 for param in group_param['params']: 260 if param.name in params_name: 261 index = params_name.index(param.name) 262 params_value.append(parameters[index]) 263 new_params_count.remove(param.name) 264 new_group_param = copy.deepcopy(group_param) 265 new_group_param['params'] = params_value 266 group_params.append(new_group_param) 267 if len(group_param.keys()) > max_key_number: 268 max_key_number = len(group_param.keys()) 269 new_params_clone = copy.deepcopy(group_param) 270 if new_params_count: 271 params_value = [] 272 for param in new_params_count: 273 index = params_name.index(param) 274 params_value.append(parameters[index]) 275 if new_params_clone: 276 new_params_clone['params'] = params_value 277 group_params.append(new_params_clone) 278 else: 279 group_params.append({"params": params_value}) 280 return group_params 281 282 def assign_parameter_group(self, parameters, split_point=None): 283 r""" 284 Assign parameter group. 285 286 Args: 287 parameters (list): The network's parameter list. 288 split_point (list): The gradient split point of this network. Default: ``None``. 289 """ 290 if not isinstance(parameters, (list, tuple)) or not parameters: 291 return parameters 292 293 parameter_len = len(parameters) 294 if split_point: 295 split_parameter_index = split_point 296 else: 297 split_parameter_index = [parameter_len // 2] 298 for i in range(parameter_len): 299 if i in split_parameter_index: 300 self._parameter_indices += 1 301 parameters[i].comm_fusion = self._parameter_indices 302 return parameters 303 304 305def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number, network): 306 """ 307 get local pca mat path. 308 309 Args: 310 weight_load_dir (str): The weight(ckpt) file directory to be load. 311 pca_mat_path (str): the path to load pca mat. Default: ``None``. 312 n_component (int): pca component. 313 device_number (int): device number. 314 network (Cell): The network. 315 """ 316 if pca_mat_path is not None and os.path.exists(pca_mat_path) and os.path.isfile(pca_mat_path) and \ 317 pca_mat_path.endswith(".npy"): 318 full_pca_mat_path = pca_mat_path 319 pca_mat_exist = True 320 321 else: 322 if weight_load_dir is None or not os.path.exists(weight_load_dir) or not os.path.isdir(weight_load_dir): 323 raise ValueError("The weight_load_dir: {} is None / not exists / not directory.".format(weight_load_dir)) 324 325 full_pca_mat_path = os.path.join(weight_load_dir, "pca_mat_temp.npy") 326 pca_mat_exist = False 327 328 save_pca_end_path = os.path.join(os.path.dirname(full_pca_mat_path), "save_pca_end.txt") 329 if os.path.exists(save_pca_end_path): 330 os.remove(save_pca_end_path) 331 332 rank = _get_global_rank() 333 local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank) + ".npy" 334 if os.path.exists(local_pca_mat_path): 335 os.remove(local_pca_mat_path) 336 if rank % device_number != 0: 337 return local_pca_mat_path 338 339 if pca_mat_exist: 340 pca_mat = np.load(full_pca_mat_path) 341 else: 342 data = _load_weights(weight_load_dir, network) 343 pca_mat = _compute_pca_mat(data, n_component) 344 np.save(full_pca_mat_path, pca_mat) 345 _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component) 346 return local_pca_mat_path 347 348 349def _load_weights(weight_load_dir, network): 350 """ 351 load weights. 352 353 Args: 354 weight_load_dir (str): The weight(ckpt) file directory to be load. 355 network (Cell): The network. 356 """ 357 param_requires_grad_list = [] 358 for param in network.trainable_params(): 359 param_requires_grad_list.append(param.name) 360 361 param_mat_tuple = () 362 weight_file_list = os.listdir(weight_load_dir) 363 for file in weight_file_list: 364 if not file.endswith('.ckpt'): 365 continue 366 file_path = os.path.join(weight_load_dir, file) 367 param_dict = load_checkpoint(file_path) 368 param_tuple = () 369 for key, value in param_dict.items(): 370 if key in param_requires_grad_list: 371 param_tuple += (value.asnumpy().reshape((1, -1)),) 372 param = np.concatenate(param_tuple, axis=1) 373 param_mat_tuple += (param,) 374 param_mat = np.concatenate(param_mat_tuple, axis=0) 375 return param_mat 376 377 378def _compute_pca_mat(data, n_component, randomized=True): 379 """ 380 compute pca mat. 381 382 Args: 383 data (array): array-like of shape (n_samples, n_features) 384 Training data, where `n_samples` is the number of samples 385 and `n_features` is the number of features. 386 n_component (int): pca component. 387 randomized (bool) if use randomized svd. 388 """ 389 if data.shape[0] < n_component: 390 raise ValueError("The samples: {} is less than: n_component {}.".format(data.shape[0], n_component)) 391 392 if randomized: 393 components = _randomized_svd(data, n_component) 394 else: 395 components = _full_svd(data, n_component) 396 397 return components 398 399 400def _randomized_svd(data, n_component, n_oversample=10, n_iter=1): 401 """ 402 compute pca mat use randomized svd. 403 404 Args: 405 data (array): array-like of shape (n_samples, n_features) 406 Training data, where `n_samples` is the number of samples 407 and `n_features` is the number of features. 408 n_component (int): pca component. 409 n_oversample (int): oversample num 410 n_iter (int): iteration count 411 """ 412 data -= np.mean(data, axis=0) 413 n_random = n_component + n_oversample 414 n_samples, n_features = data.shape 415 transpose = n_samples < n_features 416 if transpose: 417 data = data.T 418 q_mat = _randomized_range_finder(data, n_random, n_iter) 419 b_mat = q_mat.T @ data 420 u_hat, _, vt_mat = la.svd(b_mat, full_matrices=False) 421 del b_mat 422 u_mat = np.dot(q_mat, u_hat) 423 u_mat, vt_mat = _svd_flip(u_mat, vt_mat, transpose) 424 if transpose: 425 components = u_mat[:, :n_component].T 426 else: 427 components = vt_mat[:n_component, :] 428 return components 429 430 431def _full_svd(data, n_component): 432 """ 433 compute pca mat use full svd. 434 435 Args: 436 data (array): array-like of shape (n_samples, n_features) 437 Training data, where `n_samples` is the number of samples 438 and `n_features` is the number of features. 439 n_component (int): pca component. 440 """ 441 mean = np.mean(data, axis=0) 442 data -= mean 443 u, _, v = la.svd(data, full_matrices=False) 444 _, v = _svd_flip(u, v) 445 components = v[:n_component] 446 return components 447 448 449def _randomized_range_finder(data, size, n_iter=1): 450 """ 451 compute pca mat use randomized svd. 452 453 Args: 454 data (array): array-like of shape (n_samples, n_features) 455 Training data, where `n_samples` is the number of samples 456 and `n_features` is the number of features. 457 size (int): n_component + n_oversample. 458 n_iter (int): iteration count 459 """ 460 q_mat = np.random.normal(size=(data.shape[1], size)) 461 462 for _ in range(n_iter): 463 q_mat, _ = la.lu(data @ q_mat, permute_l=True) 464 q_mat, _ = la.lu(data.T @ q_mat, permute_l=True) 465 466 q_mat, _ = la.qr(data @ q_mat, mode="economic") 467 return q_mat 468 469 470def _svd_flip(u, v, transpose=True): 471 """ 472 svd flip. 473 474 Args: 475 u (ndarray): the output of `linalg.svd`. 476 v (ndarray): the output of `linalg.svd`. 477 transpose (bool): if data is transposed. 478 """ 479 if not transpose: 480 max_abs_cols = np.argmax(np.abs(u), axis=0) 481 signs = np.sign(u[max_abs_cols, range(u.shape[1])]) 482 u *= signs 483 v *= signs[:, np.newaxis] 484 else: 485 max_abs_rows = np.argmax(np.abs(v), axis=1) 486 signs = np.sign(v[range(v.shape[0]), max_abs_rows]) 487 u *= signs 488 v *= signs[:, np.newaxis] 489 return u, v 490 491 492def _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component): 493 """ 494 save pca mat. 495 496 Args: 497 pca_mat (numpy.ndarray): pca mat to be saved. 498 full_pca_mat_path (str): the path of full pca mat. 499 n_component (int): pca component. 500 """ 501 parallel_mode = auto_parallel_context().get_parallel_mode() 502 rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size() 503 local_dim = math.ceil(n_component // rank_size) 504 for rank_id in range(rank_size): 505 start_index = rank_id * local_dim 506 end_index = (rank_id + 1) * local_dim 507 pca_start_index = min(n_component, start_index) 508 pca_end_index = min(n_component, end_index) 509 p_local = np.zeros([local_dim, pca_mat.shape[1]]) 510 if pca_start_index != pca_end_index: 511 p_local[0: pca_end_index - pca_start_index, :] = pca_mat[pca_start_index: pca_end_index, :] 512 local_pca_mat_path = "{}_rank_{}.npy".format(full_pca_mat_path[:-4], str(rank_id)) 513 np.save(local_pca_mat_path, p_local) 514 save_pca_end_path = os.path.join(os.path.dirname(full_pca_mat_path), "save_pca_end.txt") 515 os.mknod(save_pca_end_path) 516 517 518def _load_local_pca_mat(local_pca_mat_path, timeout): 519 """ 520 load pca mat. 521 522 Args: 523 local_pca_mat_path (str): local pca mat file path. 524 """ 525 save_pca_end_path = os.path.join(os.path.dirname(local_pca_mat_path), "save_pca_end.txt") 526 start_time = time.time() 527 while True: 528 current_time = time.time() 529 if (current_time - start_time) > timeout: 530 raise RuntimeError("the time of waiting to load local pca mat is larger than {} second.".format(timeout)) 531 if os.path.exists(save_pca_end_path): 532 break 533 time.sleep(5) 534 pca_mat = np.load(local_pca_mat_path) 535 return pca_mat 536