1# Copyright 2021-2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""dim_reduce""" 16from __future__ import absolute_import 17 18import math 19import numpy as np 20from mindspore.nn.cell import Cell 21from mindspore.ops import composite as C 22from mindspore.ops import functional as F 23from mindspore.ops import operations as P 24from mindspore.common.tensor import Tensor 25from mindspore.common.parameter import Parameter, ParameterTuple 26from mindspore.common import dtype as mstype 27 28 29__all__ = ["DimReduce"] 30 31 32_scale_grad = C.MultitypeFuncGraph("_scale_grad") 33 34 35@_scale_grad.register("Tensor", "Tensor") 36def _scale_grad_process(scale, grad): 37 grad = F.cast(grad, mstype.float32) 38 grad = P.Div()(grad, scale) 39 return grad 40 41 42_save_weight = C.MultitypeFuncGraph("_save_weight") 43 44 45@_save_weight.register("Tensor", "Tensor") 46def _save_weight_process(parameter, new_parameter): 47 P.Assign()(parameter, new_parameter) 48 return parameter 49 50 51_pca_projection = C.MultitypeFuncGraph("_pca_projection") 52 53 54@_pca_projection.register("Tensor", "Tensor") 55def _pca_projection_process(pca_mat, grad): 56 grad_k = P.MatMul()(pca_mat, F.reshape(grad, (-1, 1))) 57 return grad_k 58 59 60_pca_back_projection = C.MultitypeFuncGraph("_pca_back_projection") 61 62 63@_pca_back_projection.register("Tensor", "Tensor", "Tensor") 64def _pca_back_projection_process(grad_k, pca_mat, grad): 65 grad_proj = P.MatMul()(F.transpose(pca_mat, (1, 0)), grad_k) 66 grad_proj_reshape = F.reshape(grad_proj, F.shape(grad)) 67 return grad_proj_reshape 68 69 70_update_grad_res_momentum = C.MultitypeFuncGraph("_update_grad_res_momentum") 71 72 73@_update_grad_res_momentum.register("Float32", "Float32", "Tensor", "Tensor", "Tensor") 74def _update_grad_res_momentum_process(gamma, alpha, grad_res_momentum, grad, grad_proj): 75 grad_res_momentum_new = gamma * grad_res_momentum + grad - grad_proj 76 P.Assign()(grad_res_momentum, grad_res_momentum_new) 77 res = alpha * grad_res_momentum_new 78 return res 79 80 81_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight") 82 83 84@_get_delta_weight.register("Tensor", "Tensor", "Tensor") 85def _get_delta_weight_process(rho, dn, grad_res_momentum): 86 delta_weight = grad_res_momentum - rho * dn 87 return delta_weight 88 89 90class DimReduce(Cell): 91 r""" 92 The dimension reduce training, is a novel algorithm for accelerating convergence of Deep Learning models. 93 94 .. math:: 95 96 \begin{align} 97 grad\_k &= pca\_mat \cdot grad\\ 98 dk &= - bk \cdot grad\_k\\ 99 sk &= rho ^ m \cdot dk\\ 100 delta\_loss &= sigma \cdot grad\_k.T \cdot sk 101 \end{align} 102 103 Here: 104 105 - pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight. 106 - bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method. 107 108 we need to find the m satisfy: 109 110 .. math:: 111 new\_loss < old\_loss + delta\_loss 112 113 Then, get delta_grad to update the weights for model: 114 115 .. math:: 116 117 \begin{align} 118 grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\ 119 new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\ 120 delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk 121 \end{align} 122 123 Args: 124 network (Cell): The training network. The network only supports single output. 125 optimizer (Union[Cell]): Optimizer for updating the weights. 126 weight (Tuple(Parameter)): Tuple of parameters. 127 pca_mat_local (numpy.ndarray): For PCA operation, k*n, k is part of n_components, n is the size of weight. 128 n_components (int): PCA.components. 129 rho (float): Coefficient. 130 gamma (float): Coefficient. 131 alpha (float): Coefficient. 132 sigma (float): Coefficient. 133 rank (int): Rank number. 134 rank_size (int): Rank size. 135 136 Inputs: 137 - **loss** (Tensor) - Tensor with shape :math:`()`. 138 - **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors. 139 - **weight** (Tuple(Tensor)) - Tuple of parameters. 140 - **weight_clone** (Tuple(Tensor)) - clone of weight 141 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 142 143 Outputs: 144 - **loss** (Tensor) - Tensor with shape :math:`()`. 145 """ 146 def __init__(self, network, optimizer, weight, pca_mat_local, n_components, rho, gamma, alpha, sigma, rank, 147 rank_size): 148 super(DimReduce, self).__init__() 149 self.network = network 150 self.optimizer = optimizer 151 self.rank = rank 152 self.rank_size = rank_size 153 self.gamma = gamma 154 self.alpha = alpha 155 self.sigma = sigma 156 157 self.float_type = mstype.float32 158 self._set_rho_list(rho) 159 self._set_local_pca_mat(pca_mat_local, n_components, weight) 160 self._set_init_parameter(weight) 161 162 self.hyper_map = C.HyperMap() 163 self.concat = P.Concat() 164 self.matmul = P.MatMul() 165 self.mul = P.Mul() 166 self.add = P.Add() 167 168 def construct(self, loss, old_grad, loss_scale, weight, weight_clone, *inputs): 169 gk, old_loss, gk_local = self._generate_gk(weight, loss, old_grad, loss_scale) 170 171 _save_weight(self.gk_last_back, self.gk_last) 172 _save_weight(self.bk_back, self.bk) 173 174 dk = self._apply_quasi_newton_update(gk) 175 if self.dk_pad_flag: 176 dk_pad = self.concat((dk, self.dk_pad_part)) 177 else: 178 dk_pad = dk 179 dk_local = dk_pad[self.start_index: self.end_index, :] 180 181 dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad) 182 grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad) 183 dn = self.dn_init if self.rank_size > 1 else dn_local 184 grad_proj = self.grad_proj_init if self.rank_size > 1 else grad_proj_local 185 if self.rank_size > 1: 186 for broadcast in self.broadcast_list: 187 dn_part = broadcast(dn_local) 188 dn = self.hyper_map(self.add, dn, dn_part) 189 grad_proj_part = broadcast(grad_proj_local) 190 grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part) 191 192 rho, find = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs) 193 if not find: 194 _save_weight(self.gk_last, self.gk_last_back) 195 _save_weight(self.bk, self.bk_back) 196 197 clone = self._res_loss(old_grad, grad_proj, weight, weight_clone, rho, dn) 198 return F.depend(loss, clone) 199 200 def _set_rho_list(self, rho): 201 """set rho list info.""" 202 self.max_search_time = 2 203 self.rho_list = [] 204 for i in range(self.max_search_time): 205 self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type)) 206 self.rho_list.append(Tensor(0, dtype=self.float_type)) 207 208 def _set_local_pca_mat(self, pca_mat_local, n_components, parameter_tuple): 209 """set pca info.""" 210 self.n_components = n_components 211 local_dim = math.ceil(self.n_components // self.rank_size) 212 213 self.start_index = self.rank * local_dim 214 self.end_index = (self.rank + 1) * local_dim 215 216 start = 0 217 self.pca_list_local = () 218 for param in parameter_tuple: 219 size = np.shape(param.asnumpy().reshape((-1, 1)))[0] 220 self.pca_list_local += (Tensor(pca_mat_local[:, start:start + size], dtype=self.float_type),) 221 start += size 222 223 self.dk_pad_flag = False 224 pad_num = self.rank_size * local_dim - self.n_components 225 if pad_num: 226 self.dk_pad_flag = True 227 self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type) 228 229 if self.rank_size > 1: 230 self.broadcast_list = [] 231 for i in range(self.rank_size): 232 broadcast = P.Broadcast(i) 233 self.broadcast_list.append(broadcast) 234 self.allreduce = P.AllReduce() 235 self.allgather = P.AllGather() 236 237 def _set_init_parameter(self, parameter_tuple): 238 """init parameters.""" 239 self.true_flag = Tensor(True) 240 self.false_flag = Tensor(False) 241 self.epsilon = np.power(10.0, -20) 242 self.gk_last = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="gk_last") 243 self.gk_last_init = Parameter(Tensor(False), name="gk_last_init") 244 self.bk = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk") 245 self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk") 246 self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type) 247 self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros") 248 self.gk_last_back = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), 249 name="gk_last_back") 250 self.bk_back = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk_back") 251 self.grad_proj_init = ParameterTuple(parameter_tuple).clone(prefix="grad_proj_init", init="zeros") 252 self.dn_init = ParameterTuple(parameter_tuple).clone(prefix="dn_init", init="zeros") 253 254 def _res_loss(self, old_grad, grad_proj, weight, weight_clone, rho, dn): 255 """update loss""" 256 update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha), 257 self.grad_res_momentum, old_grad, grad_proj) 258 delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad) 259 update = self.optimizer(delta_weight) 260 weight = F.depend(weight, update) 261 clone = self.hyper_map(_save_weight, weight_clone, weight) 262 return clone 263 264 def _generate_gk(self, weight, loss, old_grad, loss_scale): 265 """generate gk""" 266 weight = F.depend(weight, loss) 267 old_grad = F.depend(old_grad, weight) 268 old_grad = self.hyper_map(F.partial(_scale_grad, loss_scale), old_grad) 269 old_loss = self.allreduce(loss) // self.rank_size if self.rank_size > 1 else loss 270 271 gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad) 272 gk_local = F.addn(gk_local) 273 gk_pad = self.allgather(gk_local) if self.rank_size > 1 else gk_local 274 gk_pad = F.reshape(gk_pad, (-1, 1)) 275 gk = gk_pad[0:self.n_components, :] 276 return gk, old_loss, gk_local 277 278 def _line_search(self, gk, dk, dn, old_loss, weight, weight_clone, *inputs): 279 """line search rho.""" 280 res = self.rho_list[-1] 281 find = self.false_flag 282 for i in range(self.max_search_time): 283 find = self._find_rho(gk, dk, dn, old_loss, weight, weight_clone, self.rho_list[i], *inputs) 284 if find: 285 res = self.rho_list[i] 286 break 287 return res, find 288 289 def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs): 290 """search rho.""" 291 res = self.false_flag 292 sn = self.hyper_map(F.partial(self.mul, -1 * rho), dn) 293 sn = F.depend(sn, old_loss) 294 update = self.optimizer(sn) 295 new_loss = F.depend(self.network(*inputs), update) 296 if self.rank_size > 1: 297 new_loss = self.allreduce(new_loss) // self.rank_size 298 old_loss_delta = old_loss + self.sigma * rho * F.squeeze(self.matmul(F.transpose(gk, (1, 0)), dk)) 299 if old_loss_delta > new_loss: 300 _save_weight(self.sk, rho * dk) 301 res = self.true_flag 302 weight_clone = F.depend(weight_clone, old_loss_delta) 303 restore = self.hyper_map(_save_weight, weight, weight_clone) 304 res = F.depend(res, restore) 305 return res 306 307 def _apply_quasi_newton_update(self, gk): 308 """apply quasi_newton update.""" 309 if self.gk_last_init: 310 yk = gk - self.gk_last 311 g = self.matmul(F.transpose(yk, (1, 0)), self.sk) 312 g = F.squeeze(g) 313 if g > self.epsilon: 314 pk = 1. / g 315 t1 = self.eye - self.matmul(pk * yk, F.transpose(self.sk, (1, 0))) 316 new_bk = self.matmul(self.matmul(F.transpose(t1, (1, 0)), self.bk), t1) + \ 317 self.matmul(pk * self.sk, F.transpose(self.sk, (1, 0))) 318 _save_weight(self.bk, new_bk) 319 else: 320 _save_weight(self.gk_last_init, self.true_flag) 321 _save_weight(self.gk_last, gk) 322 dk = -1 * self.matmul(self.bk, gk) 323 return dk 324