# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Shared fixture module for MindSpore optimizer tests (ASGD, Rprop, AdaMax, SGD).

Provides a small deterministic network, fake training data, a training driver
(`build_network`) and the golden parameter values that optimizer tests compare
against after 20 training steps.
"""

import numpy as np
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim import ASGD
from mindspore.nn.optim import Rprop
from mindspore.nn.optim import AdaMax

np.random.seed(1024)

# Fixed initial parameters so every optimizer run starts from the same state.
fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
                        0.6942514, 0.39767185, 0.24918061, 0.4548748],
                       [0.7203382, 0.19086994, 0.76286614, 0.87920564,
                        0.3169892, 0.9462494, 0.62827677, 0.27504718],
                       [0.3544535, 0.2524781, 0.5370583, 0.8313121,
                        0.6670143, 0.0488653, 0.62225235, 0.7546456],
                       [0.17985944, 0.05106374, 0.31064633, 0.4863033,
                        0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")

fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")

fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")

fc2_bias = np.array([0.09996348]).astype("float32")


def make_fake_data():
    """Generate 20 deterministic (data, label) Tensor pairs.

    ``data[i]`` is a (2, 8) float32 tensor filled with ``i``; ``label[i]`` is a
    (2, 1) float32 tensor filled with ``i + 1``.

    Returns:
        tuple[list[Tensor], list[Tensor]]: the data tensors and label tensors.
    """
    data, label = [], []
    for step in range(20):
        data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * step, dtype=np.float32)))
        label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (step + 1), dtype=np.float32)))
    return data, label


class NetWithLoss(nn.Cell):
    """Wrap a network and a loss function into a single loss-producing cell."""

    def __init__(self, network, loss_fn):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = loss_fn

    def construct(self, x, label):
        out = self.network(x)
        loss = self.loss(out, label)
        return loss


class FakeNet(nn.Cell):
    """Small two-layer MLP (8 -> 4 -> 1) with fixed, deterministic initial weights."""

    def __init__(self):
        super(FakeNet, self).__init__()
        self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias))
        self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias))
        self.relu = nn.ReLU()
        # NOTE(review): reducemean is never used in construct(); kept only so the
        # cell's attribute set stays identical for any external introspection.
        self.reducemean = P.ReduceMean()

    def construct(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def _initialize_weights(self):
        """Reset fc1/fc2 parameters back to the module-level initial values."""
        self.init_parameters_data()
        for name, m in self.cells_and_names():
            if name == 'fc1':
                m.weight.set_data(Tensor(fc1_weight))
                m.bias.set_data(Tensor(fc1_bias))
            elif name == 'fc2':
                m.weight.set_data(Tensor(fc2_weight))
                m.bias.set_data(Tensor(fc2_bias))


def build_network(opt_config, net, is_group=None, loss_fn=None):
    """Train ``net`` for 20 steps with the optimizer described by ``opt_config``.

    Args:
        opt_config (dict): must contain ``'name'`` — one of ``'ASGD'``,
            ``'Rprop'``, ``'adamax'`` or ``'SGD'`` — plus the hyper-parameters
            that optimizer reads (e.g. ``'lr'``, ``'weight_decay'``, ...).
        net (nn.Cell): the network to train.
        is_group (bool, optional): if True, split parameters into fc1/fc2 groups
            with per-group settings. Defaults to False.
        loss_fn (nn.Cell, optional): loss cell. Defaults to
            ``nn.L1Loss(reduction='mean')``.

    Returns:
        tuple: (np.ndarray of the 20 per-step loss values, the optimizer instance).

    Raises:
        ValueError: if ``opt_config['name']`` is not a supported optimizer.
    """
    if is_group is None:
        is_group = False
    if loss_fn is None:
        loss_fn = nn.L1Loss(reduction='mean')
    losses = []
    networkwithloss = NetWithLoss(net, loss_fn)
    networkwithloss.set_train()

    if is_group:
        fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params()))
        fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params()))
        if opt_config['name'] == 'ASGD':
            params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.1}]
        elif opt_config['name'] == 'adamax':
            params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}]
        elif opt_config['name'] == 'SGD':
            params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}]
        else:
            params = [{'params': fc1_params, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.01}]
    else:
        params = networkwithloss.trainable_params()

    if opt_config['name'] == 'ASGD':
        net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'],
                       t0=opt_config['t0'], weight_decay=opt_config['weight_decay'])
    elif opt_config['name'] == 'Rprop':
        net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
                        step_sizes=opt_config['step_sizes'], weight_decay=0.0)
    elif opt_config['name'] == 'adamax':
        net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'],
                         beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0)
    elif opt_config['name'] == 'SGD':
        net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1)
    else:
        # Previously an unknown name fell through and raised UnboundLocalError
        # at the TrainOneStepCell call below; fail fast with a clear message.
        raise ValueError("Unsupported optimizer name: {}".format(opt_config['name']))

    trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
    data, label = make_fake_data()
    # Iterate the pairs directly instead of re-hard-coding the dataset length.
    for x, y in zip(data, label):
        loss = trainonestepcell(x, y)
        losses.append(loss.asnumpy())
    return np.array(losses), net_opt


# ---------------------------------------------------------------------------
# Golden parameter values expected after 20 training steps, per optimizer and
# configuration (default hyper-parameters / non-default / grouped parameters).
# ---------------------------------------------------------------------------

default_fc1_weight_asgd = np.array([[0.460443, 0.693057, 0.145399, -0.076741, 0.431228, 0.134655,
                                     -0.013833, 0.191857],
                                    [0.391073, -0.138385, 0.433600, 0.549937, -0.012268, 0.616980,
                                     0.299013, -0.054209],
                                    [0.064144, -0.037829, 0.246745, 0.540993, 0.376698, -0.241438,
                                     0.331937, 0.464328],
                                    [-0.066224, -0.195017, 0.064560, 0.240214, 0.602717, 0.306225,
                                     -0.043127, 0.475241]], dtype=np.float32)
default_fc1_bias_asgd = np.array([0.740427, 0.091827, 0.624849, 0.851911], dtype=np.float32)
default_fc2_weight_asgd = np.array([[0.585555, 0.512303, 0.424419, 0.323499]], dtype=np.float32)
default_fc2_bias_asgd = np.array([0.059962], dtype=np.float32)

no_default_fc1_weight_asgd = np.array([[0.645291, 0.877900, 0.330253, 0.108117, 0.616077, 0.319509, 0.171024,
                                        0.376710],
                                       [0.687056, 0.157610, 0.729583, 0.845918, 0.283724, 0.912958, 0.594999,
                                        0.241783],
                                       [0.328432, 0.226461, 0.511030, 0.805272, 0.640981, 0.022857, 0.596221,
                                        0.728608],
                                       [0.165102, 0.036311, 0.295884, 0.471533, 0.834030, 0.537543, 0.188198,
                                        0.706556]], dtype=np.float32)
no_default_fc1_bias_asgd = np.array([0.785650, 0.131580, 0.658614, 0.878328], dtype=np.float32)
no_default_fc2_weight_asgd = np.array([[0.374859, -0.049370, -0.068307, -0.115195]], dtype=np.float32)
no_default_fc2_bias_asgd = np.array([0.083960], dtype=np.float32)

no_default_group_fc1_weight_asgd = np.array([[0.197470, 0.429578, -0.116887, -0.338544, 0.168320, -0.127608,
                                              -0.275773, -0.070531],
                                             [0.119964, -0.408341, 0.162399, 0.278482, -0.282498, 0.345379,
                                              0.028105, -0.324348],
                                             [-0.168310, -0.270062, 0.013893, 0.307500, 0.143563, -0.473227,
                                              0.098900, 0.231002],
                                             [-0.254349, -0.382861, -0.123849, 0.051422, 0.413136, 0.117289,
                                              -0.231302, 0.285938]], dtype=np.float32)
no_default_group_fc1_bias_asgd = np.array([0.706595, 0.042866, 0.579553, 0.811499], dtype=np.float32)
no_default_group_fc2_weight_asgd = np.array([[-0.076689, -0.092399, -0.072100, -0.054189]], dtype=np.float32)
no_default_group_fc2_bias_asgd = np.array([0.698678], dtype=np.float32)

default_fc1_weight_sgd = np.array([[0.00533873, 0.03210080, -0.03090680, -0.05646387, 0.00197765,
                                    -0.03214293, -0.04922638, -0.02556189],
                                   [-0.00658702, -0.06750072, -0.00169432, 0.01169018, -0.05299109,
                                    0.01940336, -0.01717841, -0.05781638],
                                   [-0.03723934, -0.04897130, -0.01623122, 0.01762178, -0.00128018,
                                    -0.07239634, -0.00642990, 0.00880153],
                                   [-0.04421479, -0.05903235, -0.02916817, -0.00895938, 0.03274637,
                                    -0.00136485, -0.04155754, 0.01808037]], dtype=np.float32)
default_fc2_weight_sgd = np.array([[-0.01070179, -0.00702989, -0.00210839, 0.00160410]], dtype=np.float32)

default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                       0.00000000, 0.00000000, 0.00000000],
                                      [11.18415642, 11.18415642, 11.18415642, 11.18415642, 11.18415642,
                                       11.18415642, 11.18415642, 11.18415642],
                                      [-6.70855522, -6.70855522, -6.70855522, -6.70855522, -6.70855522,
                                       -6.70855522, -6.70855522, -6.70855522],
                                      [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                       0.00000000, 0.00000000, 0.00000000]], dtype=np.float32)
default_fc1_bias_adamax = np.array([0.00000000, 0.86349380, -0.51633584, 0.00000000], dtype=np.float32)

no_default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                          0.00000000, 0.00000000, 0.00000000],
                                         [-4.02891350, -4.02891350, -4.02891350, -4.02891350, -4.02891350,
                                          -4.02891350, -4.02891350, -4.02891350],
                                         [3.10859227, 3.10859227, 3.10859227, 3.10859227, 3.10859227,
                                          3.10859227, 3.10859227, 3.10859227],
                                         [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                          0.00000000, 0.00000000, 0.00000000]], dtype=np.float32)
no_default_fc1_bias_adamax = np.array([0.00000000, -0.04809491, 0.06205747, 0.00000000], dtype=np.float32)

default_group_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                             0.00000000, 0.00000000, 0.00000000],
                                            [11.07278919, 11.07278919, 11.07278919, 11.07278919, 11.07278919,
                                             11.07278919, 11.07278919, 11.07278919],
                                            [-6.81674862, -6.81674862, -6.81674862, -6.81674862, -6.81674862,
                                             -6.81674862, -6.81674862, -6.81674862],
                                            [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                             0.00000000, 0.00000000, 0.00000000]], dtype=np.float32)
default_group_fc1_bias_adamax = np.array([0.00000000, 0.85614461, -0.52348828, 0.00000000], dtype=np.float32)

default_fc1_weight_rprop = np.array([[9.10877514, 9.10877514, 9.10877514, 9.10877514, 9.10877514,
                                      9.10877514, 9.10877514, 9.10877514],
                                     [2.68465400, 2.68465400, 2.68465400, 2.68465400, 2.68465400,
                                      2.68465400, 2.68465400, 2.68465400],
                                     [1.04377401, 1.04377401, 1.04377401, 1.04377401, 1.04377401,
                                      1.04377401, 1.04377401, 1.04377401],
                                     [-1.33468997, -1.33468997, -1.33468997, -1.33468997, -1.33468997,
                                      -1.33468997, -1.33468997, -1.33468997]], dtype=np.float32)
default_fc1_bias_rprop = np.array([0.47940922, 0.14129758, 0.05493547, -0.07024684], dtype=np.float32)

no_default_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091,
                                         8.41605091, 8.41605091],
                                        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                         0.00000000, 0.00000000],
                                        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                         0.00000000, 0.00000000],
                                        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                         0.00000000, 0.00000000]], dtype=np.float32)
no_default_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32)

default_group_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091,
                                            8.41605091, 8.41605091],
                                           [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                            0.00000000, 0.00000000],
                                           [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                            0.00000000, 0.00000000],
                                           [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
                                            0.00000000, 0.00000000]], dtype=np.float32)
default_group_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32)