# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.cell import Cell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operator returning the gradients with respect to all inputs.
grad_all = C.GradOperation(get_all=True)


# Number of devices every shard strategy below is written for.
device_num = 16
# Not referenced anywhere in this file; kept in case other modules import it.
device_id = 2


class _StrategyBase:
    """Shard strategies shared by both layouts used in this file.

    Every attribute is a value handed to ``Primitive.shard(strategy=...)``.
    Following the MindSpore convention, a strategy holds one inner tuple per
    operator input, and an empty tuple stands for a scalar input.
    Subclasses only add ``onehot_strategy``, the single entry that differs.
    """

    twod_strategy = ((1, device_num),)
    twod_strategy_m = ((device_num, 1),)
    scalar_twod_strategy = ((), (1, device_num))
    twod_scalar_strategy = ((1, device_num), ())
    scalar_strategy = ((),)
    oned_strategy = ((1,),)
    scalar_scalar_strategy = ((), ())
    twod_twod_strategy = ((1, device_num), (1, device_num))
    twod_twodbc_strategy = ((1, device_num), (1, 1))
    twodbc_twod_strategy = ((1, 1), (device_num, 1))


class StrategyModel(_StrategyBase):
    """Strategy set whose OneHot is split along its second dimension."""

    onehot_strategy = ((1, device_num), (), ())


class StrategyBatch(_StrategyBase):
    """Strategy set whose OneHot is split along its first dimension."""

    onehot_strategy = ((device_num, 1), (), ())


class Args:
    """Hyper-parameters consumed by ``SemiAutoOneHotNet``."""

    # Scalar constants turned into float32 tensors by the network.
    a = 1
    b = 2
    c = 3
    d = 4
    e = 5
    # Shape of the network weight is [num_classes, emb_size].
    num_classes = 512
    emb_size = 512


class SemiAutoOneHotNet(Cell):
    """Loss-style network built from individually sharded primitives.

    Args:
        args: object exposing ``a``..``e``, ``num_classes`` and ``emb_size``
            (see ``Args``).
        strategy: object exposing the ``*_strategy`` attributes applied to
            each primitive (see ``StrategyModel`` / ``StrategyBatch``).

    NOTE: several attribute names do not match the primitive they hold
    (``mul_const2`` is an ``Add``, ``mul4`` a ``Sub``, ``mul5`` a ``RealDiv``
    and so on). The names are left untouched so the attributes stay
    addressable under their existing identifiers.
    """

    def __init__(self, args, strategy):
        super().__init__()
        self.a = args.a
        self.b = args.b
        self.c = args.c
        self.d = args.d
        self.e = args.e

        # Cast primitives. ``cast1`` is configured but not used by construct.
        self.cast = P.Cast()
        self.cast.shard(strategy=strategy.twod_strategy)
        self.cast1 = P.Cast()
        self.cast1.shard(strategy=strategy.twod_strategy)
        self.cast2 = P.Cast()
        self.cast2.shard(strategy=strategy.twod_strategy)
        self.cast3 = P.Cast()
        self.cast3.shard(strategy=strategy.scalar_strategy)
        self.cast4 = P.Cast()
        self.cast4.shard(strategy=strategy.scalar_strategy)

        # Scalar constants as float32 tensors. ``m_const_zero`` and
        # ``a_const_one`` are not used by construct.
        self.a_const = Tensor(self.a, dtype=mstype.float32)
        self.b_const = Tensor(self.b, dtype=mstype.float32)
        self.c_const = Tensor(self.c, dtype=mstype.float32)
        self.d_const = Tensor(self.d, dtype=mstype.float32)
        self.e_const = Tensor(self.e, dtype=mstype.float32)
        self.m_const_zero = Tensor(0, dtype=mstype.float32)
        self.a_const_one = Tensor(1, dtype=mstype.float32)

        # OneHot is the only primitive whose strategy differs between the
        # two strategy classes.
        self.onehot = P.OneHot()
        self.onehot.shard(strategy=strategy.onehot_strategy)

        self.exp = P.Exp()
        self.exp.shard(strategy=strategy.twod_strategy)
        self.exp2 = P.Exp()
        self.exp2.shard(strategy=strategy.twod_strategy)
        self.exp3 = P.Exp()
        self.exp3.shard(strategy=strategy.twod_strategy)

        # Element-wise primitives combining a tensor with a scalar constant.
        self.mul_const = P.Mul()
        self.mul_const.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const2 = P.Add()
        self.mul_const2.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const3 = P.Sub()
        self.mul_const3.shard(strategy=strategy.twod_scalar_strategy)
        self.mul_const4 = P.Sub()
        self.mul_const4.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const5 = P.Mul()
        self.mul_const5.shard(strategy=strategy.twod_scalar_strategy)

        # Element-wise primitives combining two tensors (or tensor/scalar).
        self.mul = P.Mul()
        self.mul.shard(strategy=strategy.twod_twod_strategy)
        self.mul2 = P.Mul()
        self.mul2.shard(strategy=strategy.twod_twod_strategy)
        self.mul3 = P.Add()
        self.mul3.shard(strategy=strategy.twod_twod_strategy)
        self.mul4 = P.Sub()
        self.mul4.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul5 = P.RealDiv()
        self.mul5.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul6 = P.Mul()
        self.mul6.shard(strategy=strategy.twod_twod_strategy)
        self.mul7 = P.Mul()
        self.mul7.shard(strategy=strategy.twod_scalar_strategy)
        self.mul8 = P.RealDiv()
        self.mul8.shard(strategy=strategy.scalar_scalar_strategy)
        self.mul9 = P.Add()
        self.mul9.shard(strategy=strategy.twod_scalar_strategy)

        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.reduce_max.shard(strategy=strategy.twod_strategy)

        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reduce_sum.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_2 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_2.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_3 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_3.shard(strategy=strategy.oned_strategy)

        # Reshape is deliberately left without an explicit shard strategy.
        self.reshape = P.Reshape()
        self.log = P.Log()
        self.log.shard(strategy=strategy.twod_strategy)

        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.normalize = P.L2Normalize(axis=1)
        self.normalize.shard(strategy=strategy.twod_strategy_m)
        self.normalize2 = P.L2Normalize(axis=1)
        self.normalize2.shard(strategy=strategy.twod_strategy_m)
        self.fc = P.MatMul(transpose_b=True)
        self.fc.shard(strategy=strategy.twodbc_twod_strategy)

        # Zero-initialised trainable weight of shape [num_classes, emb_size].
        weight_shape = [args.num_classes, args.emb_size]
        weight_np = np.zeros(weight_shape, np.float32)
        self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight')

    def construct(self, input_, label):
        """Return the loss computed from `input_` and integer `label`."""
        # L2-normalise both the input and the weight, then project.
        input_n = self.normalize(input_)
        w = self.normalize2(self.weight)
        fc_o = self.fc(input_n, w)
        fc_o_shape = F.shape(fc_o)
        # One-hot depth is the second dimension of the projection output.
        one_hot_float = self.onehot(label, fc_o_shape[1], self.on_value, self.off_value)
        local_label = self.cast(one_hot_float, mstype.int32)

        exp_o = self.exp(fc_o)
        mul_const_o = self.mul_const(self.a_const, exp_o)
        mul_const2_o = self.mul_const2(self.b_const, mul_const_o)
        exp2_o = self.exp2(mul_const2_o)
        mul_const3_o = self.mul_const3(exp2_o, self.c_const)
        # 1 - one_hot: mask selecting every position except the label.
        mul_const4_o = self.mul_const4(F.scalar_to_array(1), local_label)
        mul6_o = self.mul6(self.mul(mul_const3_o, one_hot_float),
                           self.mul2(fc_o, self.cast2(mul_const4_o, mstype.float32)))
        mul_const5_o = self.mul_const5(mul6_o, self.d_const)

        # Subtract the row maximum, exponentiate and divide by the row sum.
        max_o = self.reduce_max(mul_const5_o, -1)
        mul4_o = self.mul4(mul_const5_o, max_o)
        exp3_o = self.exp3(mul4_o)
        sum_o = self.reduce_sum(exp3_o, -1)
        # reduce_sum drops the reduced axis; restore it for the division.
        reshape_o = self.reshape(sum_o, (F.shape(sum_o)[0], 1))
        mul5_o = self.mul5(exp3_o, reshape_o)
        log_o = self.log(self.mul9(mul5_o, self.e_const))
        mul3_o = self.mul3(log_o, one_hot_float)
        mul7_o = self.mul7(mul3_o, self.cast3(F.scalar_to_array(-1), mstype.float32))
        sum2_o = self.reduce_sum_2(mul7_o, -1)
        # Divide the total by the size of the first dimension.
        loss = self.mul8(self.reduce_sum_3(sum2_o, -1),
                         self.cast4(F.scalar_to_array(F.shape(mul_const5_o)[0]), mstype.float32))
        return loss


class Dataset(MindData):
    """Mock dataset that yields the same sample `length` times.

    Args:
        predict: first element of every yielded tuple.
        label: second element of every yielded tuple (when `input_num` is 2).
        length: number of items produced before ``StopIteration``.
        input_num: 2 to yield ``(predict, label)``; any other value yields
            ``(predict,)``.
    """

    def __init__(self, predict, label, length=3, input_num=2):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        """Rewind the iterator so the dataset can be consumed again."""
        self.index = 0


class NetWithLoss(nn.Cell):
    """Feed the output of `network` through ``VirtualLoss``."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, b):
        predict = self.network(x, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Return the gradients of `network` with respect to all of its inputs."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, b):
        return grad_all(self.network)(x, b)


def bn_with_initialize(out_channels):
    """Build a ``BatchNorm2d`` (momentum 0.3, eps 1e-5) flagged as fp32."""
    bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
    return bn


def fc_with_initialize(input_channels, out_channels):
    """Build an ``nn.Dense`` layer with default initialisation."""
    return nn.Dense(input_channels, out_channels)


class BNReshapeDenseBNNet(nn.Cell):
    """BatchNorm2d -> Reshape -> Dense -> BatchNorm1d -> SemiAutoOneHotNet."""

    def __init__(self):
        super().__init__()
        self.batch_norm = bn_with_initialize(2)
        self.reshape = P.Reshape()
        self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
        self.fc = fc_with_initialize(2 * 32 * 32, 512)
        self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())

    def construct(self, x, label):
        x = self.batch_norm(x)
        # The target shape is fixed: 16 rows of 2 * 32 * 32 features.
        x = self.reshape(x, (16, 2 * 32 * 32))
        x = self.fc(x)
        x = self.batch_norm2(x)
        loss = self.loss(x, label)
        return loss


def test_bn_reshape_dense_bn_train_loss():
    """Compile the gradient graph of BNReshapeDenseBNNet in semi-auto mode."""
    batch_size = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
    net.set_auto_parallel()

    net.set_train()
    _cell_graph_executor.compile(net, input_, label)


def test_semi_one_hot_net_batch():
    """Compile the gradient graph of SemiAutoOneHotNet with StrategyBatch."""
    batch_size = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    input_ = Tensor(np.ones([batch_size * 1, 512]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)

    net = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
    net = GradWrap(NetWithLoss(net))
    net.set_auto_parallel()

    net.set_train()
    _cell_graph_executor.compile(net, input_, label)


def test_semi_one_hot_net_model():
    """Train SemiAutoOneHotNet with StrategyModel through the Model API."""
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    predict = Tensor(np.ones([batch_size, 512]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2, input_num=2)

    context.reset_auto_parallel_context()
    # Use the module-level device count the shard strategies were built for.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      device_num=device_num)
    context.set_context(mode=context.GRAPH_MODE)
    net = SemiAutoOneHotNet(args=Args(), strategy=StrategyModel())
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)