# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""boost"""
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess


__all__ = ["AutoBoost"]

24_boost_config_level = {
25    "O0": {
26        "less_bn": False,
27        "grad_freeze": False,
28        "adasum": False},
29    "O1": {
30        "less_bn": True,
31        "grad_freeze": True,
32        "adasum": False},
33    "O2": {
34        "less_bn": True,
35        "grad_freeze": True,
36        "adasum": True}}


class AutoBoost:
    """
    Provide auto accelerating for network.

    Args:
        level (str): Boost config level, one of "O0", "O1", "O2".
            Any unrecognized level silently falls back to "O0"
            (all accelerations disabled).
        kwargs (dict): Additional configuration parameters related to boost.
            Recognized keys are the ones listed in ``_boost_config_func_map``
            (``fn_flag``, ``gc_flag``, ``param_groups``, ``freeze_type``,
            ``freeze_p``, ``total_steps``, ``gradient_groups``); unknown
            keys are ignored.
    """
    def __init__(self, level, kwargs):
        # Unknown levels degrade to "O0" instead of raising.
        if level not in _boost_config_level:
            level = 'O0'
        self.level = level
        self._boost_config = _boost_config_level[level]
        # Defaults below may be overridden by entries in `kwargs`
        # via _get_configuration().
        self._fn_flag = True
        self._gc_flag = True
        self._param_groups = 10
        self._freeze_type = 1
        self._freeze_p = 0.7
        self._total_steps = 65536
        self._gradient_groups = None
        self._get_configuration(kwargs)
        self._param_processer = ParameterProcess()

    def _get_configuration(self, kwargs):
        """Apply recognized configuration entries from `kwargs`; ignore the rest."""
        for key, val in kwargs.items():
            if key not in self._boost_config_func_map:
                continue
            # Map values are plain (unbound) functions, so pass `self` explicitly.
            self._boost_config_func_map[key](self, val)

    def network_auto_process_train(self, network, optimizer):
        """Network train.

        Wrap `network` and `optimizer` according to the enabled accelerations.

        Args:
            network: The training network to accelerate.
            optimizer: The optimizer paired with `network`.

        Returns:
            tuple: The (possibly wrapped) network and optimizer.
        """
        if self._boost_config["less_bn"]:
            network = LessBN(network, fn_flag=self._fn_flag)
            optimizer_process = OptimizerProcess(optimizer)
            group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
                                                                        self._gradient_groups)
            optimizer_process.origin_params = \
                self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
            if self._gc_flag:
                optimizer_process.add_grad_centralization(network)
            optimizer = optimizer_process.generate_new_optimizer()

        if self._boost_config["grad_freeze"]:
            freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
                                              self._freeze_p, self._total_steps)
            network, optimizer = freeze_processer.freeze_generate(network, optimizer)

        if self._boost_config["adasum"]:
            # Flag the optimizer so downstream components see adasum is enabled.
            setattr(optimizer, "adasum", True)
        return network, optimizer

    def network_auto_process_eval(self, network):
        """Network eval: wrap `network` with LessBN when less_bn is enabled."""
        if self._boost_config["less_bn"]:
            network = LessBN(network)

        return network

    def set_fn_flag(self, fn_flag):
        """Set the fn_flag passed to LessBN during training wrapping."""
        self._fn_flag = fn_flag

    def set_gc_flag(self, gc_flag):
        """Enable/disable gradient centralization for the rebuilt optimizer."""
        self._gc_flag = gc_flag

    def set_param_groups(self, param_groups):
        """Set the number of parameter groups used by GradientFreeze."""
        self._param_groups = param_groups

    def set_freeze_type(self, freeze_type):
        """Set the freeze type used by GradientFreeze."""
        self._freeze_type = freeze_type

    def set_freeze_p(self, freeze_p):
        """Set the freeze probability used by GradientFreeze."""
        self._freeze_p = freeze_p

    def set_total_steps(self, total_steps):
        """Set the total training steps used by GradientFreeze."""
        self._total_steps = total_steps

    def set_gradient_groups(self, gradient_groups):
        """Set gradient groups; a single int is wrapped into a one-element list.

        Raises:
            ValueError: If `gradient_groups` is neither a list nor an int.
        """
        if not isinstance(gradient_groups, (list, int)):
            raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
        if isinstance(gradient_groups, int):
            # Bug fix: `list(int)` raises TypeError because an int is not
            # iterable; the intent is to wrap the scalar in a list.
            gradient_groups = [gradient_groups]
        self._gradient_groups = gradient_groups

    # Dispatch table mapping a kwargs key to the setter that consumes it.
    _boost_config_func_map = {
        "fn_flag": set_fn_flag,
        "gc_flag": set_gc_flag,
        "param_groups": set_param_groups,
        "freeze_type": set_freeze_type,
        "freeze_p": set_freeze_p,
        "total_steps": set_total_steps,
        "gradient_groups": set_gradient_groups
    }