# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""dim_reduce"""
16from __future__ import absolute_import
17
18import math
19import numpy as np
20from mindspore.nn.cell import Cell
21from mindspore.ops import composite as C
22from mindspore.ops import functional as F
23from mindspore.ops import operations as P
24from mindspore.common.tensor import Tensor
25from mindspore.common.parameter import Parameter, ParameterTuple
26from mindspore.common import dtype as mstype
27
28
29__all__ = ["DimReduce"]
30
31
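# The helpers below are MindSpore MultitypeFuncGraph objects: each `register` call
# adds an overload for a given input-type signature, and the resulting graph is
# applied element-wise over tuples of tensors through C.HyperMap, e.g.
# `hyper_map(F.partial(_scale_grad, loss_scale), old_grad)` as done in DimReduce below.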
_scale_grad = C.MultitypeFuncGraph("_scale_grad")


@_scale_grad.register("Tensor", "Tensor")
def _scale_grad_process(scale, grad):
    grad = F.cast(grad, mstype.float32)
    grad = P.Div()(grad, scale)
    return grad


_save_weight = C.MultitypeFuncGraph("_save_weight")


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(parameter, new_parameter):
    P.Assign()(parameter, new_parameter)
    return parameter


_pca_projection = C.MultitypeFuncGraph("_pca_projection")


@_pca_projection.register("Tensor", "Tensor")
def _pca_projection_process(pca_mat, grad):
    grad_k = P.MatMul()(pca_mat, F.reshape(grad, (-1, 1)))
    return grad_k


_pca_back_projection = C.MultitypeFuncGraph("_pca_back_projection")


@_pca_back_projection.register("Tensor", "Tensor", "Tensor")
def _pca_back_projection_process(grad_k, pca_mat, grad):
    grad_proj = P.MatMul()(F.transpose(pca_mat, (1, 0)), grad_k)
    grad_proj_reshape = F.reshape(grad_proj, F.shape(grad))
    return grad_proj_reshape


_update_grad_res_momentum = C.MultitypeFuncGraph("_update_grad_res_momentum")


@_update_grad_res_momentum.register("Float32", "Float32", "Tensor", "Tensor", "Tensor")
def _update_grad_res_momentum_process(gamma, alpha, grad_res_momentum, grad, grad_proj):
    grad_res_momentum_new = gamma * grad_res_momentum + grad - grad_proj
    P.Assign()(grad_res_momentum, grad_res_momentum_new)
    res = alpha * grad_res_momentum_new
    return res


_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")


@_get_delta_weight.register("Tensor", "Tensor", "Tensor")
def _get_delta_weight_process(rho, dn, grad_res_momentum):
    delta_weight = grad_res_momentum - rho * dn
    return delta_weight


class DimReduce(Cell):
    r"""
    Dimension reduce training is a novel algorithm for accelerating the convergence of deep learning models.

    .. math::

            \begin{align}
            grad\_k &= pca\_mat \cdot grad\\
            dk &= - bk \cdot grad\_k\\
            sk &= rho ^ m \cdot dk\\
            delta\_loss &= sigma \cdot grad\_k.T \cdot sk
            \end{align}

    Here:

    - pca_mat (array): Shape :math:`(k, n)`, where k is part of n_components and n is the size of the
      flattened weights.
    - bk (array): Shape :math:`(k, k)`, the symmetric positive definite matrix used in the quasi-Newton method.

    We need to find the m that satisfies:

    .. math::
            new\_loss < old\_loss + delta\_loss

    Then, get delta_grad to update the weights of the model:

    .. math::

            \begin{align}
            grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\
            new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\
            delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk
            \end{align}

    Args:
        network (Cell): The training network. Only single-output networks are supported.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        weight (Tuple(Parameter)): Tuple of parameters.
        pca_mat_local (numpy.ndarray): The local k*n slice of the PCA projection matrix, where k is this rank's
            share of n_components and n is the size of the flattened weights.
        n_components (int): The number of principal components kept by PCA.
        rho (float): Line-search coefficient; the candidate step factors are powers of rho.
        gamma (float): Momentum coefficient for the residual gradient.
        alpha (float): Scale coefficient applied to the residual-gradient momentum.
        sigma (float): Coefficient of the sufficient-decrease condition used in the line search.
        rank (int): Rank number.
        rank_size (int): Rank size.

    Inputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
        - **loss_scale** (Tensor) - The loss-scale factor; the gradients are divided by it before projection.
        - **weight** (Tuple(Tensor)) - Tuple of parameters.
        - **weight_clone** (Tuple(Tensor)) - Clone of *weight*.
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
    """
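    # A minimal construction sketch (the network, optimizer, and numeric values below
    # are illustrative placeholders, not taken from this file): the PCA projection
    # matrix is prepared outside this class, split across ranks, and the local block
    # is passed in as `pca_mat_local`.
    #
    #   weights = ParameterTuple(net.trainable_params())
    #   dim_reduce = DimReduce(net, opt, weights, pca_mat_local=local_block,
    #                          n_components=32, rho=0.55, gamma=0.9, alpha=0.001,
    #                          sigma=0.4, rank=rank_id, rank_size=device_num)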
    def __init__(self, network, optimizer, weight, pca_mat_local, n_components, rho, gamma, alpha, sigma, rank,
                 rank_size):
        super(DimReduce, self).__init__()
        self.network = network
        self.optimizer = optimizer
        self.rank = rank
        self.rank_size = rank_size
        self.gamma = gamma
        self.alpha = alpha
        self.sigma = sigma

        self.float_type = mstype.float32
        self._set_rho_list(rho)
        self._set_local_pca_mat(pca_mat_local, n_components, weight)
        self._set_init_parameter(weight)

        self.hyper_map = C.HyperMap()
        self.concat = P.Concat()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.add = P.Add()

    def construct(self, loss, old_grad, loss_scale, weight, weight_clone, *inputs):
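        # One dimension-reduce step (descriptive outline of the code below):
        # 1. unscale the gradients and project them onto the local PCA rows (gk);
        # 2. compute a quasi-Newton descent direction dk = -bk @ gk;
        # 3. back-project dk and gk to weight space (dn, grad_proj), summing the
        #    per-rank pieces via broadcast when rank_size > 1;
        # 4. line-search rho so the trial step passes the sufficient-decrease test,
        #    rolling back the gk_last/bk state if no rho is found;
        # 5. apply the residual-momentum update through the optimizer and sync
        #    weight_clone with the updated weights.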
        gk, old_loss, gk_local = self._generate_gk(weight, loss, old_grad, loss_scale)

        _save_weight(self.gk_last_back, self.gk_last)
        _save_weight(self.bk_back, self.bk)

        dk = self._apply_quasi_newton_update(gk)
        if self.dk_pad_flag:
            dk_pad = self.concat((dk, self.dk_pad_part))
        else:
            dk_pad = dk
        dk_local = dk_pad[self.start_index: self.end_index, :]

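        # Back-project the local slices of dk and gk into weight space; with more
        # than one rank, each rank broadcasts its partial result and the pieces are
        # summed so every rank ends up with the full dn and grad_proj tuples.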
        dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
        grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)
        dn = self.dn_init if self.rank_size > 1 else dn_local
        grad_proj = self.grad_proj_init if self.rank_size > 1 else grad_proj_local
        if self.rank_size > 1:
            for broadcast in self.broadcast_list:
                dn_part = broadcast(dn_local)
                dn = self.hyper_map(self.add, dn, dn_part)
                grad_proj_part = broadcast(grad_proj_local)
                grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)

        rho, find = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
        if not find:
            _save_weight(self.gk_last, self.gk_last_back)
            _save_weight(self.bk, self.bk_back)

        clone = self._res_loss(old_grad, grad_proj, weight, weight_clone, rho, dn)
        return F.depend(loss, clone)

    def _set_rho_list(self, rho):
        """set rho list info."""
        self.max_search_time = 2
        self.rho_list = []
        for i in range(self.max_search_time):
            self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))
        self.rho_list.append(Tensor(0, dtype=self.float_type))

    def _set_local_pca_mat(self, pca_mat_local, n_components, parameter_tuple):
        """set pca info."""
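        # Split the principal components evenly across ranks: this rank handles rows
        # [start_index, end_index) of the global dk vector, and the local PCA matrix
        # is sliced column-wise into one block per parameter so projection can be
        # applied per-tensor via hyper_map.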
        self.n_components = n_components
        local_dim = math.ceil(self.n_components / self.rank_size)

        self.start_index = self.rank * local_dim
        self.end_index = (self.rank + 1) * local_dim

        start = 0
        self.pca_list_local = ()
        for param in parameter_tuple:
            size = np.shape(param.asnumpy().reshape((-1, 1)))[0]
            self.pca_list_local += (Tensor(pca_mat_local[:, start:start + size], dtype=self.float_type),)
            start += size

        self.dk_pad_flag = False
        pad_num = self.rank_size * local_dim - self.n_components
        if pad_num:
            self.dk_pad_flag = True
            self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)

        if self.rank_size > 1:
            self.broadcast_list = []
            for i in range(self.rank_size):
                broadcast = P.Broadcast(i)
                self.broadcast_list.append(broadcast)
            self.allreduce = P.AllReduce()
            self.allgather = P.AllGather()

    def _set_init_parameter(self, parameter_tuple):
        """init parameters."""
        self.true_flag = Tensor(True)
        self.false_flag = Tensor(False)
        self.epsilon = np.power(10.0, -20)
        self.gk_last = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="gk_last")
        self.gk_last_init = Parameter(Tensor(False), name="gk_last_init")
        self.bk = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk")
        self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
        self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
        self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")
        self.gk_last_back = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type),
                                      name="gk_last_back")
        self.bk_back = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk_back")
        self.grad_proj_init = ParameterTuple(parameter_tuple).clone(prefix="grad_proj_init", init="zeros")
        self.dn_init = ParameterTuple(parameter_tuple).clone(prefix="dn_init", init="zeros")

    def _res_loss(self, old_grad, grad_proj, weight, weight_clone, rho, dn):
        """Apply the residual-momentum update to the weights and sync weight_clone."""
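        # delta_weight combines the momentum of the residual gradient (the part of
        # old_grad not captured by the PCA subspace) with the accepted step rho * dn;
        # the optimizer applies it, then weight_clone is refreshed from weight.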
        update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
                                     self.grad_res_momentum, old_grad, grad_proj)
        delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
        update = self.optimizer(delta_weight)
        weight = F.depend(weight, update)
        clone = self.hyper_map(_save_weight, weight_clone, weight)
        return clone

    def _generate_gk(self, weight, loss, old_grad, loss_scale):
        """Generate gk by unscaling the gradients and projecting them onto the PCA components."""
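        # gk_local is this rank's slice of the projected gradient; with more than one
        # rank, the slices are allgathered (and the padding trimmed) to form the full
        # n_components-dimensional gk. old_loss is the loss aggregated across ranks.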
        weight = F.depend(weight, loss)
        old_grad = F.depend(old_grad, weight)
        old_grad = self.hyper_map(F.partial(_scale_grad, loss_scale), old_grad)
        old_loss = self.allreduce(loss) // self.rank_size if self.rank_size > 1 else loss

        gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
        gk_local = F.addn(gk_local)
        gk_pad = self.allgather(gk_local) if self.rank_size > 1 else gk_local
        gk_pad = F.reshape(gk_pad, (-1, 1))
        gk = gk_pad[0:self.n_components, :]
        return gk, old_loss, gk_local

    def _line_search(self, gk, dk, dn, old_loss, weight, weight_clone, *inputs):
        """line search rho."""
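        # Try rho^0, rho^1, ... in order and keep the first value that passes the
        # sufficient-decrease test in _find_rho; the trailing 0 in rho_list means
        # "no step" when every candidate fails.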
        res = self.rho_list[-1]
        find = self.false_flag
        for i in range(self.max_search_time):
            find = self._find_rho(gk, dk, dn, old_loss, weight, weight_clone, self.rho_list[i], *inputs)
            if find:
                res = self.rho_list[i]
                break
        return res, find

    def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
        """Take a trial step with the given rho and test the sufficient-decrease condition."""
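        # Apply the trial step sn = -rho * dn through the optimizer, re-evaluate the
        # network, and accept rho (saving sk = rho * dk) only if
        # new_loss < old_loss + sigma * rho * gk.T @ dk, an Armijo-style test.
        # The weights are restored from weight_clone afterwards in either case.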
        res = self.false_flag
        sn = self.hyper_map(F.partial(self.mul, -1 * rho), dn)
        sn = F.depend(sn, old_loss)
        update = self.optimizer(sn)
        new_loss = F.depend(self.network(*inputs), update)
        if self.rank_size > 1:
            new_loss = self.allreduce(new_loss) // self.rank_size
        old_loss_delta = old_loss + self.sigma * rho * F.squeeze(self.matmul(F.transpose(gk, (1, 0)), dk))
        if old_loss_delta > new_loss:
            _save_weight(self.sk, rho * dk)
            res = self.true_flag
        weight_clone = F.depend(weight_clone, old_loss_delta)
        restore = self.hyper_map(_save_weight, weight, weight_clone)
        res = F.depend(res, restore)
        return res

    def _apply_quasi_newton_update(self, gk):
        """Apply a BFGS-style quasi-Newton update to bk and return the descent direction."""
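        # bk approximates the inverse Hessian in the reduced space. After the first
        # step it is refreshed with the BFGS inverse update
        #   t1 = I - pk * yk @ sk.T,  bk <- t1.T @ bk @ t1 + pk * sk @ sk.T,
        # where pk = 1 / (yk.T @ sk), guarded by the curvature check yk.T @ sk > epsilon.
        # The descent direction is then dk = -bk @ gk.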
        if self.gk_last_init:
            yk = gk - self.gk_last
            g = self.matmul(F.transpose(yk, (1, 0)), self.sk)
            g = F.squeeze(g)
            if g > self.epsilon:
                pk = 1. / g
                t1 = self.eye - self.matmul(pk * yk, F.transpose(self.sk, (1, 0)))
                new_bk = self.matmul(self.matmul(F.transpose(t1, (1, 0)), self.bk), t1) + \
                         self.matmul(pk * self.sk, F.transpose(self.sk, (1, 0)))
                _save_weight(self.bk, new_bk)
        else:
            _save_weight(self.gk_last_init, self.true_flag)
        _save_weight(self.gk_last, gk)
        dk = -1 * self.matmul(self.bk, gk)
        return dk
324