# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""boost"""
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess


__all__ = ["AutoBoost"]


# Feature switches enabled by each boost level: "O0" disables everything,
# "O1" enables LessBN and gradient freezing, "O2" additionally enables adasum.
_boost_config_level = {
    "O0": {
        "less_bn": False,
        "grad_freeze": False,
        "adasum": False},
    "O1": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": False},
    "O2": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": True}}


class AutoBoost:
    """
    Provide automatic acceleration for a network.

    Args:
        level (str): Boost config level, one of "O0", "O1" or "O2". Any other
            value silently falls back to "O0".
        kwargs (dict): Additional boost configuration parameters. Recognized keys
            are "fn_flag", "gc_flag", "param_groups", "freeze_type", "freeze_p",
            "total_steps" and "gradient_groups"; unrecognized keys are ignored.
            None is treated as an empty dict. Default: None.
    """
    def __init__(self, level, kwargs=None):
        if level not in _boost_config_level:
            level = 'O0'
        self.level = level
        # Shares the module-level dict; this class only reads from it.
        self._boost_config = _boost_config_level[level]
        # Defaults below may be overridden by entries in `kwargs`.
        self._fn_flag = True
        self._gc_flag = True
        self._param_groups = 10
        self._freeze_type = 1
        self._freeze_p = 0.7
        self._total_steps = 65536
        self._gradient_groups = None
        self._get_configuration(kwargs if kwargs is not None else {})
        self._param_processer = ParameterProcess()

    def _get_configuration(self, kwargs):
        """Apply every recognized key of `kwargs` through its setter; skip unknown keys."""
        for key, val in kwargs.items():
            setter = self._boost_config_func_map.get(key)
            if setter is None:
                continue
            # Map entries are plain functions, so `self` is passed explicitly.
            setter(self, val)

    def network_auto_process_train(self, network, optimizer):
        """
        Wrap a network and optimizer with the training accelerations enabled by the level.

        Args:
            network: The network to accelerate.
            optimizer: The optimizer used to train `network`.

        Returns:
            tuple: The (possibly wrapped) network and (possibly regenerated) optimizer.
        """
        if self._boost_config["less_bn"]:
            network = LessBN(network, fn_flag=self._fn_flag)
            optimizer_process = OptimizerProcess(optimizer)
            # Re-group the trainable parameters and rebuild the optimizer from them.
            group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
                                                                       self._gradient_groups)
            optimizer_process.origin_params = \
                self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
            if self._gc_flag:
                optimizer_process.add_grad_centralization(network)
            optimizer = optimizer_process.generate_new_optimizer()

        if self._boost_config["grad_freeze"]:
            freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
                                              self._freeze_p, self._total_steps)
            network, optimizer = freeze_processer.freeze_generate(network, optimizer)

        if self._boost_config["adasum"]:
            optimizer.adasum = True
        return network, optimizer

    def network_auto_process_eval(self, network):
        """
        Wrap a network with the evaluation accelerations enabled by the level.

        Args:
            network: The network to evaluate.

        Returns:
            The network, wrapped in LessBN when the level enables "less_bn".
        """
        if self._boost_config["less_bn"]:
            # NOTE(review): unlike training, `fn_flag` is not forwarded here;
            # confirm this asymmetry is intended.
            network = LessBN(network)

        return network

    def set_fn_flag(self, fn_flag):
        """Set the `fn_flag` forwarded to LessBN during training."""
        self._fn_flag = fn_flag

    def set_gc_flag(self, gc_flag):
        """Set whether gradient centralization is added to the rebuilt optimizer."""
        self._gc_flag = gc_flag

    def set_param_groups(self, param_groups):
        """Set the `param_groups` argument passed to GradientFreeze."""
        self._param_groups = param_groups

    def set_freeze_type(self, freeze_type):
        """Set the `freeze_type` argument passed to GradientFreeze."""
        self._freeze_type = freeze_type

    def set_freeze_p(self, freeze_p):
        """Set the `freeze_p` argument passed to GradientFreeze."""
        self._freeze_p = freeze_p

    def set_total_steps(self, total_steps):
        """Set the `total_steps` argument passed to GradientFreeze."""
        self._total_steps = total_steps

    def set_gradient_groups(self, gradient_groups):
        """
        Set the gradient groups passed to `ParameterProcess.assign_parameter_group`.

        Args:
            gradient_groups (Union[list, int]): A list, or a single int which is
                wrapped into a one-element list.

        Raises:
            ValueError: If `gradient_groups` is not a list or an int (bool is rejected).
        """
        # bool is a subclass of int but is never a meaningful group value.
        if isinstance(gradient_groups, bool) or not isinstance(gradient_groups, (list, int)):
            raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
        if isinstance(gradient_groups, int):
            # `list(int)` raises TypeError (int is not iterable); wrap the value instead.
            gradient_groups = [gradient_groups]
        self._gradient_groups = gradient_groups

    # Maps each recognized configuration key to its setter (plain functions,
    # invoked as `setter(self, value)` by `_get_configuration`).
    _boost_config_func_map = {
        "fn_flag": set_fn_flag,
        "gc_flag": set_gc_flag,
        "param_groups": set_param_groups,
        "freeze_type": set_freeze_type,
        "freeze_p": set_freeze_p,
        "total_steps": set_total_steps,
        "gradient_groups": set_gradient_groups
    }