# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Reshape operator under semi-auto and auto parallel modes."""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel import set_algo_parameters
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()


grad_all = C.GradOperation(get_all=True)


class Dataset(MindData):
    """Mock dataset that yields the same tensors ``length`` times per pass.

    Args:
        predict (Tensor): Input tensor returned on every step.
        label (Tensor): Label tensor; only returned when ``input_num == 2``.
        length (int): Number of steps before ``StopIteration``. Default: 3.
        input_num (int): ``2`` yields ``(predict, label)``; any other value
            yields ``(predict,)``. Default: 2.
    """

    def __init__(self, predict, label, length=3, input_num=2):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        # Rewind so the same instance can be iterated again (presumably called
        # by the training loop between epochs).
        self.index = 0


class ReshapeNet(nn.Cell):
    """ReLU -> Reshape -> MatMul, each primitive sharded with its own strategy.

    Args:
        strategy0: Shard strategy passed to ``P.ReLU().shard``.
        strategy1: Shard strategy passed to ``P.Reshape().shard`` (may be ``None``).
        strategy2: Shard strategy passed to ``P.MatMul().shard``.
    """

    def __init__(self, strategy0, strategy1, strategy2):
        super(ReshapeNet, self).__init__()
        self.relu = P.ReLU().shard(strategy0)
        self.reshape = P.Reshape().shard(strategy1)
        self.matmul = P.MatMul().shard(strategy2)
        # 25088 == 512 * 7 * 7, the per-sample size of the (N, 512, 7, 7) inputs used below.
        self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")

    def construct(self, x):
        x = self.relu(x)
        # Flatten to (256, 25088) so the result can be multiplied by the weight.
        x = self.reshape(x, (256, 25088))
        x = self.matmul(x, self.matmul_weight)
        return x


def reshape_net(strategy0, strategy1, strategy2):
    """Build a ``ReshapeNet`` with the given per-primitive shard strategies."""
    return ReshapeNet(strategy0=strategy0, strategy1=strategy1, strategy2=strategy2)


def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss):
    """Train ``ReshapeNet`` for two epochs on 8 devices in ``parallel_mode``.

    Args:
        parallel_mode: Value passed to ``set_auto_parallel_context(parallel_mode=...)``.
        strategy0, strategy1, strategy2: Forwarded to ``ReshapeNet``.
        strategy_loss: Shard strategy for the loss's ``softmax_cross_entropy`` primitive.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    # NOTE(review): the network reshapes to 256 rows == 32 * 8 devices; presumably each
    # device is fed a 32-sample slice of the global batch -- confirm.
    predict = Tensor(np.ones([32, 512, 7, 7]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = reshape_net(strategy0, strategy1, strategy2)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss.softmax_cross_entropy.shard(strategy_loss)
    loss.one_hot.shard(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_reshape1():
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape1_strategy_1():
    # An explicit shard strategy on Reshape itself; compilation errors are tolerated.
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = ((8, 1, 1, 1),)
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
    except (ValueError, TypeError, RuntimeError):
        pass


def test_reshape1_strategy_2():
    # Same as test_reshape1_strategy_1 but in AUTO_PARALLEL mode; errors are tolerated.
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = ((8, 1, 1, 1),)
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
    except (ValueError, TypeError, RuntimeError):
        pass


def test_reshape2():
    # NOTE(review): identical configuration to test_reshape1.
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape3():
    strategy0 = ((2, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape4():
    strategy0 = ((1, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape5():
    strategy0 = ((2, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((1, 8), (8, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape_auto():
    # No explicit strategies: AUTO_PARALLEL must search them itself.
    strategy0 = None
    strategy1 = None
    strategy2 = None
    strategy_loss = None
    reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


class NetWithLoss(nn.Cell):
    """Run ``network`` and feed its output to ``VirtualLoss``."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Apply ``grad_all`` (``GradOperation(get_all=True)``) to the wrapped network."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class ReshapeNet1(nn.Cell):
    """Reshape -> sharded MatMul -> Reshape to a 1-D tensor of 256 * 256 elements."""

    def __init__(self, strategy0):
        super(ReshapeNet1, self).__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
        self.reshape2 = P.Reshape()

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        x = self.matmul(x, self.matmul_weight)
        x = self.reshape2(x, (256 * 256,))
        return x


class ReshapeNet2(nn.Cell):
    """Like ``ReshapeNet1``, then ReduceSum(keep_dims=True) and Reshape to shape ``()``."""

    def __init__(self, strategy0):
        super(ReshapeNet2, self).__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        x = self.matmul(x, self.matmul_weight)
        x = self.reshape2(x, (256 * 256,))
        x = self.reduce_sum(x, -1)
        # Reshape the kept-dim reduction result to a scalar shape.
        x = self.reshape3(x, ())
        return x


class ReshapeNet3(nn.Cell):
    """Like ``ReshapeNet1``, then ReduceSum(keep_dims=False) and Reshape to ``(1, 1)``."""

    def __init__(self, strategy0):
        super(ReshapeNet3, self).__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        x = self.matmul(x, self.matmul_weight)
        x = self.reshape2(x, (256 * 256,))
        x = self.reduce_sum(x, -1)
        # Expand the reduced result back to a 2-D (1, 1) shape.
        x = self.reshape3(x, (1, 1))
        return x


class ReshapeNet4(nn.Cell):
    """Reshape applied to the MatMul weight Parameter (to its own shape) before the MatMul."""

    def __init__(self, strategy0):
        super(ReshapeNet4, self).__init__()
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        w = self.reshape2(self.matmul_weight, (25088, 256))
        x = self.matmul(x, w)
        return x


class ReshapeNet5(nn.Cell):
    """The reshaped input feeds two MatMuls: one with the weight, one with the first's output."""

    def __init__(self, strategy0):
        super(ReshapeNet5, self).__init__()
        self.reshape = P.Reshape()
        self.matmul1 = P.MatMul().shard(strategy0)
        self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
        self.matmul2 = P.MatMul().shard(strategy0)

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        matmul1_o = self.matmul1(x, self.matmul1_weight)
        # (256, 256) @ (256, 25088): the reshape output is also the second MatMul's rhs.
        matmul2_o = self.matmul2(matmul1_o, x)
        return matmul2_o


class ReshapeNet6(nn.Cell):
    """Like ``ReshapeNet5`` but with two parallel MatMuls sharing one weight, summed by Add."""

    def __init__(self, strategy0):
        super(ReshapeNet6, self).__init__()
        self.reshape = P.Reshape()
        self.matmul1_1 = P.MatMul().shard(strategy0)
        self.matmul1_2 = P.MatMul().shard(strategy0)
        self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
        self.matmul2 = P.MatMul().shard(strategy0)
        self.add = P.Add()

    def construct(self, x):
        x = self.reshape(x, (256, 25088))
        matmul1_1_o = self.matmul1_1(x, self.matmul1_weight)
        matmul1_2_o = self.matmul1_2(x, self.matmul1_weight)
        matmul1_o = self.add(matmul1_1_o, matmul1_2_o)
        matmul2_o = self.matmul2(matmul1_o, x)
        return matmul2_o


def compile_net(net, input_):
    """Enable auto-parallel and training mode on ``net`` and compile it with ``input_``."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, input_)


def reshape_net2(backbone):
    """Compile ``GradWrap(NetWithLoss(backbone))`` on 16 devices (rank 0), semi-auto parallel.

    The input is a (16 * 16, 512, 7, 7) float32 tensor filled with 0.01.
    NOTE(review): the auto-parallel context is not reset first, so settings other
    than device_num/global_rank/parallel_mode carry over from earlier tests.
    """
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * device_num, 512, 7, 7]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(backbone))

    compile_net(net, input_)


def test_reshape_net1_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))


def test_reshape_net1_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))


def test_reshape_net2_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))


def test_reshape_net2_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))


def test_reshape_net3_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))


def test_reshape_net3_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))


def test_reshape_net4_1():
    # Reshaping the weight Parameter; compilation errors are tolerated.
    try:
        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 1)))))
    except (ValueError, TypeError, RuntimeError):
        pass


def test_reshape_net4_2():
    # Reshaping the weight Parameter; compilation errors are tolerated.
    try:
        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2)))))
    except (ValueError, TypeError, RuntimeError):
        pass


def test_reshape_net5_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1)))))


def test_reshape_net5_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 2)))))


def test_reshape_net6_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 1)))))


def test_reshape_net6_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2)))))


class TrainOneStepCell(nn.Cell):
    """
    Network training package class.

    Wraps the training network with an optimizer: each call to ``construct`` runs
    the forward pass, computes gradients with respect to the trainable weights
    (seeded with ``sens``), passes them to the optimizer, and returns the loss.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> loss_net = WithLossCell(net, loss_fn)
        >>> train_net = TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.sens = sens

    def construct(self, data):
        weights = self.weights
        loss = self.network(data)
        # Sensitivity tensor with the loss's dtype and shape, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, sens)

        self.optimizer(grads)
        return loss


def reshape_common2(parallel_mode, net):
    """Train ``net`` (input-only, no labels) for two epochs on 16 devices in ``parallel_mode``."""
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    predict = Tensor(np.ones([batch_size, 512, 7, 7]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)
    # input_num=1: the dataset yields only `predict`; `label` is never returned.
    dataset = Dataset(predict, label, 2, input_num=1)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=16)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    train_net = TrainOneStepCell(net, opt).set_train()
    model = Model(train_net)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_reshape_common2_0():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))


def test_reshape_common2_1():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))


def test_reshape_common2_2():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))


def test_reshape_common2_3():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))


def test_reshape_common2_4():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))


def test_reshape_common2_5():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))


class BatchNormReshapeNet(nn.Cell):
    """BatchNorm1d(512) -> Reshape to (512, 256) -> PReLU(channel=256)."""

    def __init__(self):
        super(BatchNormReshapeNet, self).__init__()
        self.batch_norm = nn.BatchNorm1d(512, affine=False)
        self.reshape = P.Reshape()
        self.prelu = nn.PReLU(channel=256)

    def construct(self, x):
        x = self.batch_norm(x)
        x = self.reshape(x, (512, 256))
        x = self.prelu(x)
        return x


def test_batchnorm_reshape_train():
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # (256, 512) input == (512, 256) elements, matching the reshape target.
    input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet())))

    compile_net(net, input_)


def bn_with_initialize(out_channels):
    """Return a BatchNorm2d (momentum 0.3, eps 1e-5) flagged recursively with ``fp32=True``."""
    bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
    return bn


def fc_with_initialize(input_channels, out_channels):
    """Return a Dense layer flagged recursively with ``fp16=True``."""
    return nn.Dense(input_channels, out_channels).add_flags_recursive(fp16=True)


class BNReshapeDenseBNNet(nn.Cell):
    """BatchNorm2d(2) -> Reshape to (16, 2048) -> Dense(2048, 512) -> BatchNorm1d(512)."""

    def __init__(self):
        super(BNReshapeDenseBNNet, self).__init__()
        self.batch_norm = bn_with_initialize(2)
        self.reshape = P.Reshape()
        # NOTE(review): `cast` is not used in construct.
        self.cast = P.Cast()
        self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
        self.fc = fc_with_initialize(2 * 32 * 32, 512)

    def construct(self, x):
        x = self.batch_norm(x)
        x = self.reshape(x, (16, 2 * 32 * 32))
        x = self.fc(x)
        x = self.batch_norm2(x)
        return x


def test_bn_reshape_dense_bn_train():
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # Unlike the other compile tests, the batch is not multiplied by device_num
    # and no _VirtualDatasetCell wraps the network.
    input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))

    compile_net(net, input_)


class ParallelReduceMeanNet(nn.Cell):
    """1x1 Conv2d (batch-sharded 8 ways) -> ReduceMean -> Flatten.

    Args:
        conv_in_channel (int): Conv2d input channels.
        conv_out_channel (int): Conv2d output channels.
        reducemean_keep_dims (bool): ``keep_dims`` for ReduceMean. Default: False.
        reducemean_axis: Axis (or axes) reduced by ReduceMean. Default: -1.
        strategy: Optional shard strategy for ReduceMean; not applied when ``None``.
    """

    def __init__(self, conv_in_channel, conv_out_channel,
                 reducemean_keep_dims=False, reducemean_axis=-1, strategy=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=conv_in_channel, out_channels=conv_out_channel,
                              kernel_size=1, stride=1, pad_mode='valid', has_bias=True,
                              weight_init='ones', bias_init='ones')
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
        self.flat = nn.Flatten()
        self.reducemean_axis = reducemean_axis
        if strategy is not None:
            self.reduce_mean.shard(strategy)

    def construct(self, inputs):
        x = self.conv(inputs)
        x = self.reduce_mean(x, self.reducemean_axis)
        x = self.flat(x)
        return x


class CrossEntropyLoss(nn.Cell):
    """SoftmaxCrossEntropyWithLogits followed by ReduceMean over the last axis when ``reduction == 'mean'``."""

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()

        self.reduce_mean = P.ReduceMean()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits()
        self.reduction = reduction

    def construct(self, logits, label):
        loss = self.cross_entropy(logits, label)
        if self.reduction == 'mean':
            loss = self.reduce_mean(loss, (-1,))
        return loss


def test_flatten_reshape(parallel_mode="auto_parallel"):
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    # Reducing axes (2, 3) of the (N, 64, 32, 32) conv output flattens to (N, 64).
    net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                strategy=((4, 2, 1, 1),))
    loss = CrossEntropyLoss()
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_flatten_reshape2(parallel_mode="auto_parallel"):
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    try:
        net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                    strategy=((4, 1, 1, 1),))
        loss = CrossEntropyLoss()
        predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
        label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
        dataset = Dataset(predict, label, 2, input_num=2)

        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, loss_fn=loss, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Restore the documented default so the override does not leak into later tests.
        set_algo_parameters(fully_use_devices=True)


class ParallelReshapeNet(nn.Cell):
    """Flatten -> Dense -> Reshape to ``shape``, with ``strategy`` passed to the Reshape's ``shard``.

    Args:
        dense_in_channel (int): Dense input size.
        dense_out_channel (int): Dense output size.
        shape (tuple): Target shape for the Reshape.
        strategy: Shard strategy passed to ``P.Reshape().shard`` (may be ``None``).
    """

    def __init__(self, dense_in_channel, dense_out_channel, shape, strategy=None):
        super().__init__()
        self.flat = nn.Flatten()
        self.dense = nn.Dense(in_channels=dense_in_channel,
                              out_channels=dense_out_channel,
                              weight_init='ones',
                              bias_init='ones',
                              has_bias=True)
        self.reshape = P.Reshape()
        self.shape = shape
        self.reshape.shard(strategy)

    def construct(self, inputs):
        x = self.flat(inputs)
        x = self.dense(x)
        x = self.reshape(x, self.shape)
        return x


# The shape of input and output of reshape is the same,
# so reshape is optimized away before step_parallel.
def test_flatten_reshape3(parallel_mode="auto_parallel"):
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    try:
        net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),))
        loss = CrossEntropyLoss()
        predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
        label = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
        dataset = Dataset(predict, label, 2, input_num=2)

        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, loss_fn=loss, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Restore the documented default so the override does not leak into later tests.
        set_algo_parameters(fully_use_devices=True)


class CrossEntropyLoss2(nn.Cell):
    """Thin wrapper around ``SoftmaxCrossEntropyWithLogits(reduction=reduction)``."""

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss2, self).__init__()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logits, label):
        loss = self.cross_entropy(logits, label)
        return loss


def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    try:
        # keep_dims=True on the last axis: (N, 64, 32, 32) -> (N, 64, 32, 1), flattened to (N, 2048).
        net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True,
                                    strategy=((4, 1, 1, 1),))
        loss = CrossEntropyLoss2()
        predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
        label = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
        dataset = Dataset(predict, label, 2, input_num=2)

        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, loss_fn=loss, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Restore the documented default so the override does not leak into later tests.
        set_algo_parameters(fully_use_devices=True)