# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""base process"""
16import copy
17import mindspore.nn as nn
18from mindspore.nn.optim import LARS
19from mindspore import log as logger
20from mindspore.common import Parameter
21from .less_batch_normalization import CommonHeadLastFN
22
23
24__all__ = ["OptimizerProcess", "ParameterProcess"]
25
26
class OptimizerProcess:
    """
    Process optimizer for Boost. Currently, this class supports adding GC (gradient centralization)
    tags and creating new optimizers.

    Args:
        opt (Cell): Optimizer used.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import OptimizerProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> optimizer_process = OptimizerProcess(optimizer)
        >>> optimizer_process.add_grad_centralization(network)
        >>> optimizer = optimizer_process.generate_new_optimizer()
    """
    def __init__(self, opt):
        if isinstance(opt, LARS):
            # LARS wraps an inner optimizer; record both the inner optimizer's init
            # arguments and the LARS init arguments so the pair can be rebuilt later.
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
        self.origin_params = opt.init_params["params"]

    def build_params_dict(self, network):
        """Build a dict that maps each parameter's id to the cell owning it."""
        cells = network.cells_and_names()
        params_dict = {}
        for _, cell in cells:
            for par in cell.get_parameters(expand=False):
                params_dict[id(par)] = cell
        return params_dict

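    # For illustration only (not part of the original API): the returned dict lets a
    # parameter's owning cell be looked up by id, e.g.
    #
    #   params_dict = optimizer_process.build_params_dict(network)
    #   owner_cell = params_dict[id(network.trainable_params()[0])]
    #
    # where `optimizer_process` and `network` follow the class docstring example above.
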
    def build_gc_params_group(self, params_dict, parameters):
        """Build the parameter groups, tagging the groups that need gradient centralization."""
        group_params = []
        for group_param in parameters:
            if 'order_params' in group_param.keys():
                group_params.append(group_param)
                continue
            params_gc_value = []
            params_value = []
            for param in group_param['params']:
                # Gradient centralization is applied only to weight parameters; normalization
                # parameters (beta/gamma) and bias parameters are excluded, as are weights of
                # grouped convolutions and CommonHeadLastFN layers.
                if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                    param_cell = params_dict[id(param)]
                    if (isinstance(param_cell, nn.Conv2d) and param_cell.group > 1) or \
                        isinstance(param_cell, CommonHeadLastFN):
                        params_value.append(param)
                    else:
                        params_gc_value.append(param)
                else:
                    params_value.append(param)
            if params_gc_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_gc_value
                new_group_param['grad_centralization'] = True
                group_params.append(new_group_param)
            if params_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_value
                group_params.append(new_group_param)
        return group_params

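    # A minimal sketch of what build_gc_params_group produces, using hypothetical group
    # parameters (the names `conv_weight` and `bn_gamma` are illustrative only):
    #
    #   groups = [{'params': [conv_weight, bn_gamma], 'lr': 0.1}]
    #   new_groups = optimizer_process.build_gc_params_group(params_dict, groups)
    #   # -> [{'params': [conv_weight], 'lr': 0.1, 'grad_centralization': True},
    #   #     {'params': [bn_gamma], 'lr': 0.1}]
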
    def add_grad_centralization(self, network):
        """Add gradient centralization tags to the optimizer's parameter groups."""
        params_dict = self.build_params_dict(network)

        parameters = self.origin_params
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict is supported.")

        if isinstance(parameters[0], Parameter):
            logger.warning("Only group parameters support gradient centralization.")
            return

        self.origin_params = self.build_gc_params_group(params_dict, parameters)

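    # Note (illustrative, based on the checks above): add_grad_centralization only rewrites
    # grouped, dict-style parameters. With a flat Parameter list such as
    #
    #   optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    #   OptimizerProcess(optimizer).add_grad_centralization(network)
    #
    # a warning is logged and the original parameter list is left unchanged.
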
    def generate_new_optimizer(self):
        """Generate new optimizer."""
        if not self.is_lars:
            opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args), **self.lars_init_args)

        return opt


class ParameterProcess:
    """
    Process parameters for Boost. Currently, this class supports creating group parameters
    and automatically setting gradient segmentation points.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import ParameterProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight2')
        ...         self.matmul = ops.MatMul()
        ...         self.matmul2 = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         output2 = self.matmul2(x, self.weight2)
        ...         return output + output2
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> new_parameter = network.trainable_params()[:1]
        >>> parameter_process = ParameterProcess()
        >>> group_params = parameter_process.generate_group_params(new_parameter, network.trainable_params())
    """
    def __init__(self):
        self._parameter_indices = 1

    def assign_parameter_group(self, parameters, split_point=None):
        """Assign a communication-fusion group index to each parameter."""
        if not isinstance(parameters, (list, tuple)) or not parameters:
            return parameters

        parameter_len = len(parameters)
        if split_point:
            split_parameter_index = split_point
        else:
            # By default split the parameter list in the middle, giving two fusion groups.
            split_parameter_index = [parameter_len // 2]
        for i in range(parameter_len):
            if i in split_parameter_index:
                self._parameter_indices += 1
            parameters[i].comm_fusion = self._parameter_indices
        return parameters

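    # A minimal sketch, assuming a network with eight trainable parameters and no explicit
    # split point (variable names are illustrative only):
    #
    #   process = ParameterProcess()
    #   params = process.assign_parameter_group(list(network.trainable_params()))
    #   # params[0]..params[3] get comm_fusion == 1 and params[4]..params[7] get comm_fusion == 2,
    #   # so gradient all-reduce communication is fused into two groups at the default midpoint.
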
    def generate_group_params(self, parameters, origin_params):
        """Generate group parameters."""
        origin_params_copy = origin_params
        if origin_params_copy is not None:
            if not isinstance(origin_params_copy, list):
                origin_params_copy = list(origin_params_copy)

        if not origin_params_copy:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(origin_params_copy[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict is supported.")

        if isinstance(origin_params_copy[0], Parameter):
            group_params = [{"params": parameters}]
            return group_params

        group_params = []
        params_name = [param.name for param in parameters]
        new_params_count = copy.deepcopy(params_name)
        new_params_clone = {}
        max_key_number = 0
        for group_param in origin_params_copy:
            if 'order_params' in group_param.keys():
                new_group_param = copy.deepcopy(group_param)
                new_group_param['order_params'] = parameters
                group_params.append(new_group_param)
                continue
            params_value = []
            for param in group_param['params']:
                if param.name in params_name:
                    index = params_name.index(param.name)
                    params_value.append(parameters[index])
                    new_params_count.remove(param.name)
            new_group_param = copy.deepcopy(group_param)
            new_group_param['params'] = params_value
            group_params.append(new_group_param)
            # Remember the group with the most keys; leftover parameters inherit its settings.
            if len(group_param.keys()) > max_key_number:
                max_key_number = len(group_param.keys())
                new_params_clone = copy.deepcopy(group_param)
        if new_params_count:
            # Parameters that did not appear in any original group form one extra group.
            params_value = []
            for param in new_params_count:
                index = params_name.index(param)
                params_value.append(parameters[index])
            if new_params_clone:
                new_params_clone['params'] = params_value
                group_params.append(new_params_clone)
            else:
                group_params.append({"params": params_value})
        return group_params
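
    # A minimal sketch of generate_group_params with dict-style origin parameters, using
    # hypothetical names (`w1`, `w2` are new Parameter objects matched to the originals by name):
    #
    #   origin = [{'params': [old_w1], 'lr': 0.1}, {'params': [old_w2], 'weight_decay': 1e-4}]
    #   groups = ParameterProcess().generate_group_params([w1, w2], origin)
    #   # -> [{'params': [w1], 'lr': 0.1}, {'params': [w2], 'weight_decay': 1e-4}]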