# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""boost"""
from __future__ import absolute_import

import threading
from mindspore.nn.optim import SGD
from mindspore.boost.less_batch_normalization import LessBN
from mindspore.boost.grad_freeze import GradientFreeze
from mindspore.boost.base import OptimizerProcess, ParameterProcess, _get_local_pca_mat_path


__all__ = ["AutoBoost"]

_boost_config_mode = ["auto", "manual", "enable_all", "disable_all"]
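# Boost function defaults per level: "O0" disables everything, "O1" enables
# less_bn and grad_freeze, and "O2" additionally enables adasum.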
_boost_config_level = {
    "O0": {
        "less_bn": False,
        "grad_freeze": False,
        "adasum": False,
        "grad_accumulation": False,
        "dim_reduce": False,
        "loss_scale_group": False},
    "O1": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": False,
        "grad_accumulation": False,
        "dim_reduce": False,
        "loss_scale_group": False},
    "O2": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": True,
        "grad_accumulation": False,
        "dim_reduce": False,
        "loss_scale_group": False}
}


class AutoBoost:
55    r"""
56    Provide auto accelerating for network.
57
58    Args:
59        level (str): Boost config level. Default: ``"O0"`` .
60        boost_config_dict (dict): User config hyperparameter dict, recommended config format:
61
            .. code-block::

                {
                    "boost": {
                        "mode": "auto",
                        "less_bn": False,
                        "grad_freeze": False,
                        "adasum": False,
                        "grad_accumulation": False,
                        "dim_reduce": False,
                        "loss_scale_group": False
                    },
                    "common": {
                        "gradient_split_groups": [50, 100],
                        "device_number": 8
                    },
                    "less_bn": {
                        "fn_flag": True,
                        "gc_flag": True
                    },
                    "grad_freeze": {
                        "param_groups": 10,
                        "freeze_type": 1,
                        "freeze_p": 0.7,
                        "total_steps": 65536
                    },
                    "dim_reduce": {
                        "rho": 0.55,
                        "gamma": 0.9,
                        "alpha": 0.001,
                        "sigma": 0.4,
                        "n_components": 32,
                        "pca_mat_path": None,
                        "weight_load_dir": None,
                        "timeout": 1800
                    }
                }

            Default: ``""`` .

            - boost:

              - mode (str): How to set the boost. Supports ["auto", "manual", "enable_all", "disable_all"].
                Default: ``"auto"`` .

                - auto: Depends on the argument "boost_level" in class Model.
                - manual: Depends on "boost_config_dict".
                - enable_all: Set all boost functions to true.
                - disable_all: Set all boost functions to false.

              - less_bn (bool): Whether to apply the less_bn function. Default: ``False`` .
              - grad_freeze (bool): Whether to apply the grad_freeze function. Default: ``False`` .
              - adasum (bool): Whether to apply the adasum function. Default: ``False`` .
              - grad_accumulation (bool): Whether to apply the grad_accumulation function. Default: ``False`` .
              - dim_reduce (bool): Whether to apply the dim_reduce function. Default: ``False`` .
              - loss_scale_group (bool): Whether to apply the loss_scale_group function. Default: ``False`` .

              If dim_reduce is set to true, all other functions are set to false.
              If grad_freeze is set to true and dim_reduce to false, all remaining functions are set to false.

            - common:

              - gradient_split_groups (list): The gradient split points of this network. Default: ``[50, 100]`` .
              - device_number (int): Device number. Default: ``8`` .

            - less_bn:

              - fn_flag (bool): Whether to change fc to fn. Default: ``True`` .
              - gc_flag (bool): Whether to apply gc (gradient centralization). Default: ``True`` .

            - grad_freeze:

              - param_groups (int): The number of parameter groups. Default: ``10`` .
              - freeze_type (int): Gradient freeze grouping strategy, select from [0, 1]. Default: ``1`` .
              - freeze_p (float): Gradient freezing probability. Default: ``0.7`` .
              - total_steps (int): Total training steps. Default: ``65536`` .

            - dim_reduce:

              The guiding principles of dim_reduce:

              .. math::

                  \begin{align}
                  grad\_k &= pca\_mat \cdot grad\\
                  dk &= - bk \cdot grad\_k\\
                  sk &= rho ^ m \cdot dk\\
                  delta\_loss &= sigma \cdot grad\_k.T \cdot sk
                  \end{align}

              Here:

              - pca_mat (array): Shape :math:`(k*n)`, where k is part of n_components and n is the size of the weight.
              - bk (array): Shape :math:`(k*k)`, the symmetric positive-definite matrix of the quasi-Newton method.

              We need to find the m that satisfies:

              .. math::

                  new\_loss < old\_loss + delta\_loss

              Then, compute delta_grad to update the weights of the model:

              .. math::

                  \begin{align}
                  grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\
                  new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\
                  delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk
                  \end{align}

              - rho (float): Generally, it does not need to be modified. Default: ``0.55`` .
              - gamma (float): Generally, it does not need to be modified. Default: ``0.9`` .
              - alpha (float): Generally, it does not need to be modified. Default: ``0.001`` .
              - sigma (float): Generally, it does not need to be modified. Default: ``0.4`` .
              - n_components (int): PCA component. Default: ``32`` .
              - pca_mat_path (str): The path to load the pca mat. Default: ``None`` .
              - weight_load_dir (str): The directory to load weight files saved as ckpt. Default: ``None`` .
              - timeout (int): Waiting time (in seconds) to load the local pca mat. Default: ``1800`` .
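
              A minimal NumPy sketch of the update rules above (the shapes,
              the step exponent m, and all values are illustrative assumptions,
              not the internal implementation):

              .. code-block::

                  import numpy as np

                  n, k = 6, 3                            # n: weight size, k: components kept
                  rng = np.random.default_rng(0)
                  pca_mat = rng.standard_normal((k, n))  # (k, n) projection matrix
                  bk = np.eye(k)                         # (k, k) SPD quasi-Newton matrix
                  grad = rng.standard_normal(n)
                  old_grad_momentum = np.zeros(n)
                  rho, gamma, alpha, sigma = 0.55, 0.9, 0.001, 0.4

                  grad_k = pca_mat @ grad                # project gradient into PCA space
                  dk = -bk @ grad_k                      # quasi-Newton descent direction
                  m = 2                                  # would be chosen so new_loss < old_loss + delta_loss
                  sk = rho ** m * dk                     # back-tracking step
                  delta_loss = sigma * grad_k @ sk       # expected loss decrease

                  grad_k_proj = pca_mat.T @ grad_k       # back-project to weight space
                  new_grad_momentum = gamma * old_grad_momentum + grad - grad_k_proj
                  delta_grad = alpha * new_grad_momentum - pca_mat.T @ sk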

            Users can load the config from a JSON file or pass the dictionary directly.
            Unconfigured parameters adopt their default values.

    Raises:
        ValueError: If the boost mode is not in ["auto", "manual", "enable_all", "disable_all"].

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.boost import AutoBoost
        >>> #1) when configuring the dict directly:
        >>> boost_config_dict = {"boost": {"mode": "auto"}}
        >>> boost = AutoBoost("O1", boost_config_dict)
        >>>
        >>> #2) when loading the dict from a json file:
        >>> import json
        >>> boost_json = "/path/boost_config.json"
        >>> with open(boost_json, 'r') as fp:
        ...     boost_config_dict = json.load(fp)
        >>> boost = AutoBoost("O1", boost_config_dict)
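        >>>
        >>> #3) when overriding individual switches with "manual" mode
        >>> #   (a sketch: unlisted keys keep the level's defaults):
        >>> boost_config_dict = {"boost": {"mode": "manual", "less_bn": True}}
        >>> boost = AutoBoost("O0", boost_config_dict)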
    """
    _instance_lock = threading.Lock()
    _instance = None

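    # AutoBoost is a process-wide singleton: __new__ uses double-checked
    # locking so concurrent first constructions create only one instance.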
    # pylint: disable=unused-argument
    def __new__(cls, *args, **kwargs):
        if AutoBoost._instance is None:
            with AutoBoost._instance_lock:
                if AutoBoost._instance is None:
                    AutoBoost._instance = object.__new__(cls)
                    AutoBoost._instance.level = None
                    AutoBoost._instance.boost_config_dict = None
        return AutoBoost._instance

    def __init__(self, level="O0", boost_config_dict=""):
        # Unknown levels fall back silently to "O0".
        if level not in _boost_config_level:
            level = "O0"
        # Configure only on the first construction of the singleton.
        if self._instance.level is None:
            self.level = level
            self.boost_config_dict = boost_config_dict
            self._fn_flag = True
            self._gc_flag = True
            self._param_groups = 10
            self._freeze_type = 1
            self._freeze_p = 0.7
            self._total_steps = 65536
            self.gradient_groups = None
            self.device_number = 8
            self.grad_accumulation_step = 1
            self.rho = 0.55
            self.gamma = 0.9
            self.alpha = 0.001
            self.sigma = 0.4
            self.n_components = 32
            self.pca_mat_path = None
            self.weight_load_dir = None
            self.local_pca_mat_path = None
            self.timeout = 1800
            self.boost_config = self._get_configuration(level, self.boost_config_dict)
            self._param_processer = ParameterProcess()

    def network_auto_process_train(self, network, optimizer):
        r"""
        Boost network train.

        Args:
            network (Cell): The training network.
            optimizer (Cell): Optimizer for updating the weights.

        Returns:
            Tuple of (network, optimizer), the processed network and optimizer.
        """
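        # dim_reduce supersedes the user optimizer: a plain SGD with
        # learning_rate=1 is substituted, presumably so the externally
        # computed update (delta_grad) is applied to the weights unchanged.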
        if self.boost_config.get("dim_reduce"):
            self.local_pca_mat_path = _get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path,
                                                              self.n_components, self.device_number, network)
            optimizer = SGD(network.trainable_params(), learning_rate=1)
            setattr(optimizer, "dim_reduce", True)
            return network, optimizer

        if self.boost_config.get("less_bn"):
            network = LessBN(network, fn_flag=self._fn_flag)
            optimizer_process = OptimizerProcess(optimizer)
            group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
                                                                        self.gradient_groups)
            optimizer_process.origin_params = \
                ParameterProcess.generate_group_params(group_params, optimizer_process.origin_params)
            if self._gc_flag:
                optimizer_process.add_grad_centralization(network)
            optimizer = optimizer_process.generate_new_optimizer()

        if self.boost_config.get("grad_freeze"):
            freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
                                              self._freeze_p, self._total_steps)
            network, optimizer = freeze_processer.freeze_generate(network, optimizer)

        if self.boost_config.get("adasum"):
            setattr(optimizer, "adasum", True)
        return network, optimizer

    def network_auto_process_eval(self, network):
        r"""
        Boost network eval.

        Args:
            network (Cell): The inference network.

        Returns:
            The processed inference network.
        """
        if self.boost_config.get("dim_reduce"):
            return network
        if self.boost_config.get("less_bn"):
            network = LessBN(network)

        return network

    def _set_fn_flag(self, fn_flag):
        self._fn_flag = fn_flag

    def _set_gc_flag(self, gc_flag):
        self._gc_flag = gc_flag

    def _set_param_groups(self, param_groups):
        self._param_groups = param_groups

    def _set_freeze_type(self, freeze_type):
        self._freeze_type = freeze_type

    def _set_freeze_p(self, freeze_p):
        self._freeze_p = freeze_p

    def _set_total_steps(self, total_steps):
        self._total_steps = total_steps

    def _set_device_number(self, device_number):
        self.device_number = device_number

    def _set_grad_accumulation_step(self, grad_accumulation_step):
        self.grad_accumulation_step = grad_accumulation_step

    def _set_gradient_split_groups(self, gradient_groups):
        if not isinstance(gradient_groups, (list, int)):
            raise ValueError(f"gradient_groups must be a list or an int, but got `{gradient_groups}`")
        if isinstance(gradient_groups, int):
            # Wrap a single int in a list; list(int) would raise a TypeError.
            gradient_groups = [gradient_groups]
        self.gradient_groups = gradient_groups

    def _set_rho(self, rho):
        self.rho = rho

    def _set_gamma(self, gamma):
        self.gamma = gamma

    def _set_alpha(self, alpha):
        self.alpha = alpha

    def _set_sigma(self, sigma):
        self.sigma = sigma

    def _set_n_components(self, n_components):
        self.n_components = n_components

    def _set_pca_mat_path(self, pca_mat_path):
        self.pca_mat_path = pca_mat_path

    def _set_weight_load_dir(self, weight_load_dir):
        self.weight_load_dir = weight_load_dir

    def _set_timeout(self, timeout):
        self.timeout = timeout

    def _get_configuration(self, level, boost_config_dict):
        """Get the effective boost configuration."""
        # Copy the defaults so per-instance overrides do not mutate the
        # module-level _boost_config_level dict.
        level_config = dict(_boost_config_level.get(level))
        if not boost_config_dict:
            return level_config

        mode = "auto"
        if "boost" in boost_config_dict and "mode" in boost_config_dict["boost"]:
            mode = boost_config_dict["boost"]["mode"]
        if mode not in _boost_config_mode:
            raise ValueError("The boost mode must be in {}, but got {}".format(_boost_config_mode, mode))

        if mode == "manual":
            for key, value in boost_config_dict["boost"].items():
                if key in level_config:
                    level_config[key] = value
        elif mode == "enable_all":
            level_config = {key: True for key in level_config}
        elif mode == "disable_all":
            level_config = {key: False for key in level_config}

        self._do_new_config_func(boost_config_dict, level_config)
        return level_config

    def _do_new_config_func(self, boost_config_dict, level_config):
        """Apply user hyperparameters for enabled boost functions and the "common" section."""
        valid_boost_each_mode_config = []
        for key, boost_each_mode_config in boost_config_dict.items():
            # Keep sub-configs of enabled functions, plus the shared "common" section.
            if (key in level_config and level_config[key]) or key == "common":
                valid_boost_each_mode_config.append(boost_each_mode_config)

        for boost_each_mode_config in valid_boost_each_mode_config:
            for key_s in boost_each_mode_config.keys():
                if key_s in self._boost_config_func_map:
                    self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s])

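    # Maps hyperparameter keys from boost_config_dict sub-sections to their
    # setters. Defined after the methods so the function objects already
    # exist when the class body is evaluated.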
    _boost_config_func_map = {
        "fn_flag": _set_fn_flag,
        "gc_flag": _set_gc_flag,
        "param_groups": _set_param_groups,
        "freeze_type": _set_freeze_type,
        "freeze_p": _set_freeze_p,
        "total_steps": _set_total_steps,
        "device_number": _set_device_number,
        "gradient_split_groups": _set_gradient_split_groups,
        "grad_accumulation_step": _set_grad_accumulation_step,
        "rho": _set_rho,
        "gamma": _set_gamma,
        "alpha": _set_alpha,
        "sigma": _set_sigma,
        "n_components": _set_n_components,
        "pca_mat_path": _set_pca_mat_path,
        "weight_load_dir": _set_weight_load_dir,
        "timeout": _set_timeout
    }