• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import mindspore
18from mindspore import nn, Tensor
19from mindspore.ops import operations as P
20from mindspore.nn.optim import ASGD
21from mindspore.nn.optim import Rprop
22from mindspore.nn.optim import AdaMax
23
# Seed NumPy's global random generator with a fixed value.
np.random.seed(1024)

# Initial parameters handed to the two Dense layers of FakeNet:
# fc1 is a (4, 8) weight with a 4-element bias, fc2 is a (1, 4) weight with a
# 1-element bias. Everything is stored as float32.
fc1_weight = np.array(
    [
        [0.72346634, 0.95608497, 0.4084163, 0.18627149, 0.6942514, 0.39767185, 0.24918061, 0.4548748],
        [0.7203382, 0.19086994, 0.76286614, 0.87920564, 0.3169892, 0.9462494, 0.62827677, 0.27504718],
        [0.3544535, 0.2524781, 0.5370583, 0.8313121, 0.6670143, 0.0488653, 0.62225235, 0.7546456],
        [0.17985944, 0.05106374, 0.31064633, 0.4863033, 0.848814, 0.5523157, 0.20295663, 0.7213356],
    ],
    dtype=np.float32,
)

fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525], dtype=np.float32)

fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]], dtype=np.float32)

fc2_bias = np.array([0.09996348], dtype=np.float32)
40
41
def make_fake_data():
    """
    Build 20 constant-valued (data, label) samples.

    Returns:
        tuple: ``(data, label)``, two lists of 20 float32 Tensors each. Sample
        ``i`` has a (2, 8) data tensor filled with ``i`` and a (2, 1) label
        tensor filled with ``i + 1``.
    """
    sample_ids = range(20)
    data = [
        mindspore.Tensor(np.array(np.ones((2, 8)) * idx, dtype=np.float32))
        for idx in sample_ids
    ]
    label = [
        mindspore.Tensor(np.array(np.ones((2, 1)) * (idx + 1), dtype=np.float32))
        for idx in sample_ids
    ]
    return data, label
51
52
class NetWithLoss(nn.Cell):
    """
    Cell that chains a network with a loss function.

    Args:
        network: callable applied to the input ``x`` to produce the output.
        loss_fn: callable applied to ``(output, label)`` to produce the loss.
    """

    def __init__(self, network, loss_fn):
        super().__init__()
        self.network = network
        self.loss = loss_fn

    def construct(self, x, label):
        """Run the wrapped network on ``x`` and return the loss of its output against ``label``."""
        prediction = self.network(x)
        return self.loss(prediction, label)
67
68
class FakeNet(nn.Cell):
    """
    Small two-layer dense network: Dense(8 -> 4), ReLU, Dense(4 -> 1).

    Both Dense layers start from the module-level fc1/fc2 weight and bias arrays.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(
            in_channels=8,
            out_channels=4,
            weight_init=Tensor(fc1_weight),
            bias_init=Tensor(fc1_bias),
        )
        self.fc2 = nn.Dense(
            in_channels=4,
            out_channels=1,
            weight_init=Tensor(fc2_weight),
            bias_init=Tensor(fc2_bias),
        )
        self.relu = nn.ReLU()
        # Not used by construct(); kept as an attribute of the cell.
        self.reducemean = P.ReduceMean()

    def construct(self, x):
        """Return fc2(relu(fc1(x)))."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

    def _initialize_weights(self):
        """
        Call init_parameters_data(), then overwrite the weight and bias of the
        sub-cells named 'fc1' and 'fc2' with the module-level preset arrays.
        """
        self.init_parameters_data()
        presets = {
            'fc1': (fc1_weight, fc1_bias),
            'fc2': (fc2_weight, fc2_bias),
        }
        for cell_name, cell in self.cells_and_names():
            if cell_name not in presets:
                continue
            weight, bias = presets[cell_name]
            cell.weight.set_data(Tensor(weight))
            cell.bias.set_data(Tensor(bias))
98
99
def build_network(opt_config, net, is_group=None, loss_fn=None):
    """
    Train ``net`` on the fake data with the optimizer described by ``opt_config``.

    One training step is run per sample returned by ``make_fake_data()``.

    Args:
        opt_config (dict): optimizer settings. ``'name'`` selects the optimizer
            and must be one of ``'ASGD'``, ``'Rprop'``, ``'adamax'`` or ``'SGD'``.
            The remaining keys read depend on the optimizer:
            ASGD reads ``'lr'``, ``'lambd'``, ``'alpha'``, ``'t0'``, ``'weight_decay'``;
            Rprop reads ``'lr'``, ``'etas'``, ``'step_sizes'``;
            adamax reads ``'lr'``, ``'beta1'``, ``'beta2'``, ``'eps'``;
            SGD reads ``'weight_decay'``.
        net: network wrapped together with the loss function and trained.
        is_group (bool): when truthy, the parameters whose name contains
            ``'fc1'`` and all remaining parameters are passed to the optimizer
            as two parameter groups with the per-group settings hard-coded
            below. Default: None (no grouping).
        loss_fn: loss callable taking ``(output, label)``. Default: None, in
            which case ``nn.L1Loss(reduction='mean')`` is used.

    Returns:
        tuple: ``(losses, net_opt)`` where ``losses`` is a NumPy array holding
        the loss of every training step and ``net_opt`` is the optimizer.

    Raises:
        ValueError: if ``opt_config['name']`` is not a supported optimizer name.
    """
    opt_name = opt_config['name']
    # Fail fast: without this check an unknown name would leave the optimizer
    # unassigned and only surface later as an UnboundLocalError.
    supported_names = ('ASGD', 'Rprop', 'adamax', 'SGD')
    if opt_name not in supported_names:
        raise ValueError(
            "Unsupported optimizer name {!r}; expected one of {}.".format(opt_name, supported_names)
        )

    if loss_fn is None:
        loss_fn = nn.L1Loss(reduction='mean')
    networkwithloss = NetWithLoss(net, loss_fn)
    networkwithloss.set_train()

    if is_group:
        trainable = networkwithloss.trainable_params()
        fc1_params = [param for param in trainable if 'fc1' in param.name]
        fc2_params = [param for param in trainable if 'fc1' not in param.name]
        if opt_name == 'ASGD':
            params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.1}]
        elif opt_name == 'adamax':
            params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}]
        elif opt_name == 'SGD':
            params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}]
        else:
            # Only 'Rprop' reaches this branch now that the name is validated.
            params = [{'params': fc1_params, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.01}]
    else:
        params = networkwithloss.trainable_params()

    if opt_name == 'ASGD':
        net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'],
                       t0=opt_config['t0'], weight_decay=opt_config['weight_decay'])
    elif opt_name == 'Rprop':
        net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
                        step_sizes=opt_config['step_sizes'], weight_decay=0.0)
    elif opt_name == 'adamax':
        net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'],
                         beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0)
    else:
        # opt_name == 'SGD'
        net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1)

    trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
    data, label = make_fake_data()
    losses = []
    for sample, target in zip(data, label):
        loss = trainonestepcell(sample, target)
        losses.append(loss.asnumpy())
    return np.array(losses), net_opt
145
146
# Reference float32 parameter values for the ASGD optimizer.
# NOTE(review): presumably compared against trained fc1/fc2 parameters by tests
# outside this chunk; the default / no_default / no_default_group prefixes look
# like they name the optimizer configuration used - confirm against the callers.
default_fc1_weight_asgd = np.array(
    [
        [0.460443, 0.693057, 0.145399, -0.076741, 0.431228, 0.134655, -0.013833, 0.191857],
        [0.391073, -0.138385, 0.433600, 0.549937, -0.012268, 0.616980, 0.299013, -0.054209],
        [0.064144, -0.037829, 0.246745, 0.540993, 0.376698, -0.241438, 0.331937, 0.464328],
        [-0.066224, -0.195017, 0.064560, 0.240214, 0.602717, 0.306225, -0.043127, 0.475241],
    ],
    dtype=np.float32,
)
default_fc1_bias_asgd = np.array([0.740427, 0.091827, 0.624849, 0.851911], dtype=np.float32)
default_fc2_weight_asgd = np.array([[0.585555, 0.512303, 0.424419, 0.323499]], dtype=np.float32)
default_fc2_bias_asgd = np.array([0.059962], dtype=np.float32)

no_default_fc1_weight_asgd = np.array(
    [
        [0.645291, 0.877900, 0.330253, 0.108117, 0.616077, 0.319509, 0.171024, 0.376710],
        [0.687056, 0.157610, 0.729583, 0.845918, 0.283724, 0.912958, 0.594999, 0.241783],
        [0.328432, 0.226461, 0.511030, 0.805272, 0.640981, 0.022857, 0.596221, 0.728608],
        [0.165102, 0.036311, 0.295884, 0.471533, 0.834030, 0.537543, 0.188198, 0.706556],
    ],
    dtype=np.float32,
)
no_default_fc1_bias_asgd = np.array([0.785650, 0.131580, 0.658614, 0.878328], dtype=np.float32)
no_default_fc2_weight_asgd = np.array([[0.374859, -0.049370, -0.068307, -0.115195]], dtype=np.float32)
no_default_fc2_bias_asgd = np.array([0.083960], dtype=np.float32)

no_default_group_fc1_weight_asgd = np.array(
    [
        [0.197470, 0.429578, -0.116887, -0.338544, 0.168320, -0.127608, -0.275773, -0.070531],
        [0.119964, -0.408341, 0.162399, 0.278482, -0.282498, 0.345379, 0.028105, -0.324348],
        [-0.168310, -0.270062, 0.013893, 0.307500, 0.143563, -0.473227, 0.098900, 0.231002],
        [-0.254349, -0.382861, -0.123849, 0.051422, 0.413136, 0.117289, -0.231302, 0.285938],
    ],
    dtype=np.float32,
)
no_default_group_fc1_bias_asgd = np.array([0.706595, 0.042866, 0.579553, 0.811499], dtype=np.float32)
no_default_group_fc2_weight_asgd = np.array([[-0.076689, -0.092399, -0.072100, -0.054189]], dtype=np.float32)
no_default_group_fc2_bias_asgd = np.array([0.698678], dtype=np.float32)
182
# Reference float32 values for the SGD optimizer: a (4, 8) fc1 weight and a
# (1, 4) fc2 weight.
# NOTE(review): presumably compared against trained parameters by tests outside
# this chunk - confirm against the callers.
default_fc1_weight_sgd = np.array(
    [
        [0.00533873, 0.03210080, -0.03090680, -0.05646387, 0.00197765, -0.03214293, -0.04922638, -0.02556189],
        [-0.00658702, -0.06750072, -0.00169432, 0.01169018, -0.05299109, 0.01940336, -0.01717841, -0.05781638],
        [-0.03723934, -0.04897130, -0.01623122, 0.01762178, -0.00128018, -0.07239634, -0.00642990, 0.00880153],
        [-0.04421479, -0.05903235, -0.02916817, -0.00895938, 0.03274637, -0.00136485, -0.04155754, 0.01808037],
    ],
    dtype=np.float32,
)
default_fc2_weight_sgd = np.array([[-0.01070179, -0.00702989, -0.00210839, 0.00160410]], dtype=np.float32)
192
# Reference float32 values for the AdaMax optimizer. Every (4, 8) matrix below
# has rows made of one value repeated eight times, so each row is written as
# that value times 8.
# NOTE(review): presumably compared against optimizer-related tensors by tests
# outside this chunk; what exactly they represent is not visible here - confirm.
default_fc1_weight_adamax = np.array(
    [[value] * 8 for value in (0.00000000, 11.18415642, -6.70855522, 0.00000000)],
    dtype=np.float32,
)
default_fc1_bias_adamax = np.array([0.00000000, 0.86349380, -0.51633584, 0.00000000], dtype=np.float32)

no_default_fc1_weight_adamax = np.array(
    [[value] * 8 for value in (0.00000000, -4.02891350, 3.10859227, 0.00000000)],
    dtype=np.float32,
)
no_default_fc1_bias_adamax = np.array([0.00000000, -0.04809491, 0.06205747, 0.00000000], dtype=np.float32)

default_group_fc1_weight_adamax = np.array(
    [[value] * 8 for value in (0.00000000, 11.07278919, -6.81674862, 0.00000000)],
    dtype=np.float32,
)
default_group_fc1_bias_adamax = np.array([0.00000000, 0.85614461, -0.52348828, 0.00000000], dtype=np.float32)
222
# Reference float32 values for the Rprop optimizer. Every (4, 8) matrix below
# has rows made of one value repeated eight times, so each row is written as
# that value times 8.
# NOTE(review): presumably compared against optimizer-related tensors by tests
# outside this chunk; what exactly they represent is not visible here - confirm.
default_fc1_weight_rprop = np.array(
    [[value] * 8 for value in (9.10877514, 2.68465400, 1.04377401, -1.33468997)],
    dtype=np.float32,
)
default_fc1_bias_rprop = np.array([0.47940922, 0.14129758, 0.05493547, -0.07024684], dtype=np.float32)

no_default_fc1_weight_rprop = np.array(
    [[value] * 8 for value in (8.41605091, 0.00000000, 0.00000000, 0.00000000)],
    dtype=np.float32,
)
no_default_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32)

default_group_fc1_weight_rprop = np.array(
    [[value] * 8 for value in (8.41605091, 0.00000000, 0.00000000, 0.00000000)],
    dtype=np.float32,
)
default_group_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32)
252