• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common.api import _cell_graph_executor
22from mindspore.common.parameter import Parameter
23from mindspore.common.parameter import ParameterTuple
24from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
25from mindspore.nn.optim.momentum import Momentum
26from mindspore.ops import composite as C
27from mindspore.ops import operations as P
28from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
29from mindspore.parallel import set_algo_parameters
30from mindspore.train import Model
31from mindspore.context import ParallelMode
32from tests.dataset_mock import MindData
33from tests.ut.python.ops.test_math_ops import VirtualLoss
34
# Every case in this module runs in graph mode and starts from a freshly
# reset auto-parallel configuration.
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()


# Shared GradOperation instance (created with get_all=True) used by GradWrap.
grad_all = C.GradOperation(get_all=True)
40
41
class Dataset(MindData):
    """Mock dataset that returns the same objects a fixed number of times.

    Args:
        predict: Object returned as the first element of every item.
        label: Object returned as the second element when ``input_num`` is 2.
        length (int): Number of items produced before ``StopIteration`` is
            raised; also forwarded to ``MindData`` as ``size``. Default: 3.
        input_num (int): 2 yields ``(predict, label)``; any other value
            yields ``(predict,)``. Default: 2.
    """

    def __init__(self, predict, label, length=3, input_num=2):
        super().__init__(size=length)
        self.predict, self.label = predict, label
        self.index = 0
        self.length, self.input_num = length, input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            if self.input_num != 2:
                return (self.predict,)
            return (self.predict, self.label)
        raise StopIteration

    def reset(self):
        """Rewind the item counter so the dataset can be iterated again."""
        self.index = 0
64
65
class ReshapeNet(nn.Cell):
    """ReLU, then Reshape to (256, 25088), then MatMul with a (25088, 256) weight.

    Args:
        strategy0: Shard strategy passed to the ReLU primitive.
        strategy1: Shard strategy passed to the Reshape primitive.
        strategy2: Shard strategy passed to the MatMul primitive.
    """

    def __init__(self, strategy0, strategy1, strategy2):
        super().__init__()
        self.relu = P.ReLU().shard(strategy0)
        self.reshape = P.Reshape().shard(strategy1)
        self.matmul = P.MatMul().shard(strategy2)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(ones_weight, name="weight")

    def construct(self, x):
        activated = self.relu(x)
        flattened = self.reshape(activated, (256, 25088))
        return self.matmul(flattened, self.matmul_weight)
79
80
def reshape_net(strategy0, strategy1, strategy2):
    """Build a ``ReshapeNet`` configured with the three given shard strategies."""
    return ReshapeNet(strategy0, strategy1, strategy2)
83
84
def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss):
    """Train a ``ReshapeNet`` for two epochs on mock data with ``device_num=8``.

    Args:
        parallel_mode: Value for the ``parallel_mode`` auto-parallel setting.
        strategy0: Shard strategy for the network's ReLU.
        strategy1: Shard strategy for the network's Reshape.
        strategy2: Shard strategy for the network's MatMul.
        strategy_loss: Shard strategy applied to the loss cell's
            ``softmax_cross_entropy`` primitive.
    """
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)

    inputs = Tensor(np.ones([32, 512, 7, 7]), dtype=ms.float32)
    targets = Tensor(np.ones([32]), dtype=ms.int32)
    train_data = Dataset(inputs, targets, 2)
    network = reshape_net(strategy0, strategy1, strategy2)

    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss_fn.softmax_cross_entropy.shard(strategy_loss)
    loss_fn.one_hot.shard(((8, 1), (), ()))

    # Momentum optimizer: learning rate 0.1, momentum 0.9.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    trainer = Model(network, loss_fn, optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
103
104
def test_reshape1():
    """Run ``reshape_common`` in semi-auto parallel mode with no Reshape strategy."""
    reshape_common(
        ParallelMode.SEMI_AUTO_PARALLEL,
        ((8, 1, 1, 1),),
        None,
        ((8, 1), (1, 1)),
        ((8, 1), (8, 1)),
    )
111
112
def test_reshape1_strategy_1():
    """Run ``reshape_common`` in semi-auto parallel mode with an explicit Reshape strategy.

    A ``ValueError``, ``TypeError`` or ``RuntimeError`` raised by the run is
    swallowed; the test also passes when nothing is raised.
    """
    relu_strategy = ((8, 1, 1, 1),)
    reshape_strategy = ((8, 1, 1, 1),)
    matmul_strategy = ((8, 1), (1, 1))
    loss_strategy = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, relu_strategy, reshape_strategy,
                       matmul_strategy, loss_strategy)
    except (ValueError, TypeError, RuntimeError):
        pass
126
127
def test_reshape1_strategy_2():
    """Run ``reshape_common`` in auto parallel mode with an explicit Reshape strategy.

    A ``ValueError``, ``TypeError`` or ``RuntimeError`` raised by the run is
    swallowed; the test also passes when nothing is raised.
    """
    relu_strategy = ((8, 1, 1, 1),)
    reshape_strategy = ((8, 1, 1, 1),)
    matmul_strategy = ((8, 1), (1, 1))
    loss_strategy = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.AUTO_PARALLEL, relu_strategy, reshape_strategy,
                       matmul_strategy, loss_strategy)
    except (ValueError, TypeError, RuntimeError):
        pass
141
142
def test_reshape2():
    """Run ``reshape_common`` in semi-auto parallel mode with ReLU strategy ((8, 1, 1, 1),)."""
    reshape_common(
        ParallelMode.SEMI_AUTO_PARALLEL,
        ((8, 1, 1, 1),),
        None,
        ((8, 1), (1, 1)),
        ((8, 1), (8, 1)),
    )
149
150
def test_reshape3():
    """Run ``reshape_common`` in semi-auto parallel mode with ReLU strategy ((2, 1, 1, 1),)."""
    reshape_common(
        ParallelMode.SEMI_AUTO_PARALLEL,
        ((2, 1, 1, 1),),
        None,
        ((8, 1), (1, 1)),
        ((8, 1), (8, 1)),
    )
157
158
def test_reshape4():
    """Run ``reshape_common`` in semi-auto parallel mode with ReLU strategy ((1, 1, 1, 1),)."""
    reshape_common(
        ParallelMode.SEMI_AUTO_PARALLEL,
        ((1, 1, 1, 1),),
        None,
        ((8, 1), (1, 1)),
        ((8, 1), (8, 1)),
    )
165
166
def test_reshape5():
    """Run ``reshape_common`` in semi-auto parallel mode with MatMul strategy ((1, 8), (8, 1))."""
    reshape_common(
        ParallelMode.SEMI_AUTO_PARALLEL,
        ((2, 1, 1, 1),),
        None,
        ((1, 8), (8, 1)),
        ((8, 1), (8, 1)),
    )
173
174
def test_reshape_auto():
    """Run ``reshape_common`` in auto parallel mode with every strategy left as None."""
    reshape_common(ParallelMode.AUTO_PARALLEL, None, None, None, None)
181
182
class NetWithLoss(nn.Cell):
    """Feed the output of a single-input network into ``VirtualLoss``.

    Args:
        network (Cell): Network whose output is passed to the loss.
    """

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        return self.loss(self.network(x))
192
193
class GradWrap(nn.Cell):
    """Cell that returns the result of ``grad_all`` applied to the wrapped network.

    Args:
        network (Cell): Single-input network to differentiate.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x):
        grad_fn = grad_all(self.network)
        return grad_fn(x)
201
202
class ReshapeNet1(nn.Cell):
    """Reshape to (256, 25088), MatMul with a (25088, 256) weight, then Reshape to (256 * 256,).

    Args:
        strategy0: Shard strategy passed to the MatMul primitive.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(ones_weight, name="weight")
        self.reshape2 = P.Reshape()

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        product = self.matmul(flattened, self.matmul_weight)
        return self.reshape2(product, (256 * 256,))
216
217
class ReshapeNet2(nn.Cell):
    """Reshape, MatMul, Reshape to 1-D, ReduceSum(keep_dims=True) on axis -1, Reshape to ().

    Args:
        strategy0: Shard strategy passed to the MatMul primitive.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(ones_weight, name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        product = self.matmul(flattened, self.matmul_weight)
        vector = self.reshape2(product, (256 * 256,))
        total = self.reduce_sum(vector, -1)
        return self.reshape3(total, ())
235
236
class ReshapeNet3(nn.Cell):
    """Reshape, MatMul, Reshape to 1-D, ReduceSum(keep_dims=False) on axis -1, Reshape to (1, 1).

    Args:
        strategy0: Shard strategy passed to the MatMul primitive.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(ones_weight, name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        product = self.matmul(flattened, self.matmul_weight)
        vector = self.reshape2(product, (256 * 256,))
        total = self.reduce_sum(vector, -1)
        return self.reshape3(total, (1, 1))
254
255
class ReshapeNet4(nn.Cell):
    """MatMul whose two operands both pass through a Reshape first.

    The input is reshaped to (256, 25088); the (25088, 256) weight parameter
    is reshaped to (25088, 256), i.e. the shape it was created with.

    Args:
        strategy0: Shard strategy passed to the MatMul primitive.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()
        self.matmul = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(ones_weight, name="weight")

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        weight = self.reshape2(self.matmul_weight, (25088, 256))
        return self.matmul(flattened, weight)
269
270
class ReshapeNet5(nn.Cell):
    """Reshape output consumed by two MatMuls.

    The reshaped input is the left operand of the first MatMul and the right
    operand of the second one.

    Args:
        strategy0: Shard strategy passed to both MatMul primitives.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul1 = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul1_weight = Parameter(ones_weight, name="weight")
        self.matmul2 = P.MatMul().shard(strategy0)

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        first = self.matmul1(flattened, self.matmul1_weight)
        return self.matmul2(first, flattened)
284
285
class ReshapeNet6(nn.Cell):
    """Reshape output consumed by three MatMuls.

    Two MatMuls multiply the reshaped input by the same weight, their results
    are added, and the sum is multiplied by the reshaped input again.

    Args:
        strategy0: Shard strategy passed to all three MatMul primitives.
    """

    def __init__(self, strategy0):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul1_1 = P.MatMul().shard(strategy0)
        self.matmul1_2 = P.MatMul().shard(strategy0)
        ones_weight = Tensor(np.ones([25088, 256]), dtype=ms.float32)
        self.matmul1_weight = Parameter(ones_weight, name="weight")
        self.matmul2 = P.MatMul().shard(strategy0)
        self.add = P.Add()

    def construct(self, x):
        flattened = self.reshape(x, (256, 25088))
        branch_a = self.matmul1_1(flattened, self.matmul1_weight)
        branch_b = self.matmul1_2(flattened, self.matmul1_weight)
        summed = self.add(branch_a, branch_b)
        return self.matmul2(summed, flattened)
303
304
def compile_net(net, input_):
    """Compile ``net`` for ``input_`` without executing it.

    Calls ``set_auto_parallel()`` and ``set_train()`` on the cell, then hands
    it to ``_cell_graph_executor.compile``.

    Args:
        net (Cell): Cell to compile.
        input_: Input forwarded to the graph executor's ``compile`` call.
    """
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, input_)
309
310
def reshape_net2(backbone):
    """Compile the gradient graph of ``backbone`` in semi-auto parallel mode.

    The backbone is wrapped as ``GradWrap(NetWithLoss(backbone))`` and compiled
    for a float32 input of shape (16 * 16, 512, 7, 7) with ``device_num=16``
    and ``global_rank=0``.

    Args:
        backbone (Cell): Single-input network under test.
    """
    batch_size = 16
    device_num = 16
    # Reset first so that settings left behind by a previously executed test
    # cannot leak in; the other helpers in this file do the same.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * device_num, 512, 7, 7]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(backbone))

    compile_net(net, input_)
321
322
def test_reshape_net1_1():
    """Compile ReshapeNet1 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``."""
    backbone = ReshapeNet1(((1, 8), (8, 1)))
    reshape_net2(_VirtualDatasetCell(backbone))
325
326
def test_reshape_net1_2():
    """Compile ReshapeNet1 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``."""
    backbone = ReshapeNet1(((1, 8), (8, 2)))
    reshape_net2(_VirtualDatasetCell(backbone))
329
330
def test_reshape_net2_1():
    """Compile ReshapeNet2 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``."""
    backbone = ReshapeNet2(((1, 8), (8, 1)))
    reshape_net2(_VirtualDatasetCell(backbone))
333
334
def test_reshape_net2_2():
    """Compile ReshapeNet2 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``."""
    backbone = ReshapeNet2(((1, 8), (8, 2)))
    reshape_net2(_VirtualDatasetCell(backbone))
337
338
def test_reshape_net3_1():
    """Compile ReshapeNet3 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``."""
    backbone = ReshapeNet3(((1, 8), (8, 1)))
    reshape_net2(_VirtualDatasetCell(backbone))
341
342
def test_reshape_net3_2():
    """Compile ReshapeNet3 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``."""
    backbone = ReshapeNet3(((1, 8), (8, 2)))
    reshape_net2(_VirtualDatasetCell(backbone))
345
346
def test_reshape_net4_1():
    """Compile ReshapeNet4 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``.

    A ``ValueError``, ``TypeError`` or ``RuntimeError`` raised while building
    or compiling is swallowed; the test also passes when nothing is raised.
    """
    try:
        backbone = ReshapeNet4(((1, 8), (8, 1)))
        reshape_net2(_VirtualDatasetCell(backbone))
    except (ValueError, TypeError, RuntimeError):
        pass
356
357
def test_reshape_net4_2():
    """Compile ReshapeNet4 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``.

    A ``ValueError``, ``TypeError`` or ``RuntimeError`` raised while building
    or compiling is swallowed; the test also passes when nothing is raised.
    """
    try:
        backbone = ReshapeNet4(((1, 8), (8, 2)))
        reshape_net2(_VirtualDatasetCell(backbone))
    except (ValueError, TypeError, RuntimeError):
        pass
367
368
def test_reshape_net5_1():
    """Compile ReshapeNet5 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``."""
    backbone = ReshapeNet5(((1, 8), (8, 1)))
    reshape_net2(_VirtualDatasetCell(backbone))
371
372
def test_reshape_net5_2():
    """Compile ReshapeNet5 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``."""
    backbone = ReshapeNet5(((1, 8), (8, 2)))
    reshape_net2(_VirtualDatasetCell(backbone))
375
376
def test_reshape_net6_1():
    """Compile ReshapeNet6 with MatMul strategy ((1, 8), (8, 1)) via ``reshape_net2``."""
    backbone = ReshapeNet6(((1, 8), (8, 1)))
    reshape_net2(_VirtualDatasetCell(backbone))
379
380
def test_reshape_net6_2():
    """Compile ReshapeNet6 with MatMul strategy ((1, 8), (8, 2)) via ``reshape_net2``."""
    backbone = ReshapeNet6(((1, 8), (8, 2)))
    reshape_net2(_VirtualDatasetCell(backbone))
383
384
class TrainOneStepCell(nn.Cell):
    """
    Single-input training step wrapper.

    Each call runs the wrapped network, computes the gradients of the network
    with respect to its trainable parameters, passes them to the optimizer and
    returns the network output (the loss).

    Args:
        network (Cell): The training network; called with one input.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): Value used to fill the sensitivity tensor handed to the
            gradient operation. Default: 1.0.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> loss_net = WithLossCell(net, loss_fn)
        >>> train_net = TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data):
        loss = self.network(data)
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(data, sens)

        self.optimizer(grads)
        return loss
423
424
def reshape_common2(parallel_mode, net):
    """Train ``net`` for two epochs through ``TrainOneStepCell`` with ``device_num=16``.

    The mock dataset yields one input per item, a float32 tensor of shape
    (16, 512, 7, 7).

    Args:
        parallel_mode: Value for the ``parallel_mode`` auto-parallel setting.
        net (Cell): Already constructed single-input network to train.
    """
    batch_size = 16

    inputs = Tensor(np.ones([batch_size, 512, 7, 7]), dtype=ms.float32)
    targets = Tensor(np.ones([batch_size]), dtype=ms.int32)
    # input_num=1: the label is stored in the dataset but never yielded.
    train_data = Dataset(inputs, targets, 2, input_num=1)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=16)

    # Momentum optimizer: learning rate 0.1, momentum 0.9.
    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    step_cell = TrainOneStepCell(net, optimizer).set_train()
    trainer = Model(step_cell)
    trainer.train(2, train_data, dataset_sink_mode=False)
441
442
def test_reshape_common2_0():
    """Train ReshapeNet1 with MatMul strategy ((1, 8), (8, 1)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
445
446
def test_reshape_common2_1():
    """Train ReshapeNet1 with MatMul strategy ((1, 8), (8, 2)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
449
450
def test_reshape_common2_2():
    """Train ReshapeNet2 with MatMul strategy ((1, 8), (8, 1)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
453
454
def test_reshape_common2_3():
    """Train ReshapeNet2 with MatMul strategy ((1, 8), (8, 2)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
457
458
def test_reshape_common2_4():
    """Train ReshapeNet3 with MatMul strategy ((1, 8), (8, 1)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
461
462
def test_reshape_common2_5():
    """Train ReshapeNet3 with MatMul strategy ((1, 8), (8, 2)) in semi-auto parallel mode."""
    net = _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2))))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, net)
465
466
class BatchNormReshapeNet(nn.Cell):
    """BatchNorm1d(512, affine=False), then Reshape to (512, 256), then PReLU(channel=256)."""

    def __init__(self):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(512, affine=False)
        self.reshape = P.Reshape()
        self.prelu = nn.PReLU(channel=256)

    def construct(self, x):
        normalized = self.batch_norm(x)
        reshaped = self.reshape(normalized, (512, 256))
        return self.prelu(reshaped)
479
480
def test_batchnorm_reshape_train():
    """Compile the gradient graph of ``BatchNormReshapeNet`` in semi-auto parallel mode.

    Uses ``device_num=16``, ``global_rank=0`` and a float32 input of shape
    (16 * 16, 512); the network is wrapped in ``_VirtualDatasetCell``,
    ``NetWithLoss`` and ``GradWrap`` before compilation.
    """
    batch_size = 16
    device_num = 16
    # Reset first so that settings left behind by a previously executed test
    # cannot leak in; the training helpers in this file do the same.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet())))

    compile_net(net, input_)
491
492
def bn_with_initialize(out_channels):
    """Return a BatchNorm2d (momentum=0.3, eps=1e-5) with ``fp32=True`` flagged recursively.

    Args:
        out_channels (int): First positional argument of ``nn.BatchNorm2d``.
    """
    return nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
496
497
def fc_with_initialize(input_channels, out_channels):
    """Return an ``nn.Dense`` layer with ``fp16=True`` flagged recursively.

    Args:
        input_channels (int): First positional argument of ``nn.Dense``.
        out_channels (int): Second positional argument of ``nn.Dense``.
    """
    dense = nn.Dense(input_channels, out_channels)
    return dense.add_flags_recursive(fp16=True)
500
501
class BNReshapeDenseBNNet(nn.Cell):
    """BatchNorm2d, Reshape to (16, 2 * 32 * 32), Dense to 512 features, BatchNorm1d.

    The ``cast`` primitive is created in ``__init__`` but not used by ``construct``.
    """

    def __init__(self):
        super().__init__()
        self.batch_norm = bn_with_initialize(2)
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
        self.fc = fc_with_initialize(2 * 32 * 32, 512)

    def construct(self, x):
        normalized = self.batch_norm(x)
        flattened = self.reshape(normalized, (16, 2 * 32 * 32))
        projected = self.fc(flattened)
        return self.batch_norm2(projected)
517
518
def test_bn_reshape_dense_bn_train():
    """Compile the gradient graph of ``BNReshapeDenseBNNet`` in semi-auto parallel mode.

    Uses ``device_num=16``, ``global_rank=0`` and a float32 input of shape
    (16, 2, 32, 32); the network is wrapped in ``NetWithLoss`` and ``GradWrap``
    before compilation.
    """
    batch_size = 16
    device_num = 16
    # Reset first so that settings left behind by a previously executed test
    # cannot leak in; the training helpers in this file do the same.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))

    compile_net(net, input_)
529
530
class ParallelReduceMeanNet(nn.Cell):
    """1x1 Conv2d, then ReduceMean, then Flatten.

    The conv primitive is always sharded with ((8, 1, 1, 1), (1, 1, 1, 1)).

    Args:
        conv_in_channel (int): ``in_channels`` of the convolution.
        conv_out_channel (int): ``out_channels`` of the convolution.
        reducemean_keep_dims (bool): ``keep_dims`` of the ReduceMean primitive.
            Default: False.
        reducemean_axis: Axis argument passed to ReduceMean in ``construct``.
            Default: -1.
        strategy: Shard strategy for ReduceMean; nothing is applied when None.
            Default: None.
    """

    def __init__(self, conv_in_channel, conv_out_channel,
                 reducemean_keep_dims=False, reducemean_axis=-1, strategy=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=conv_in_channel,
            out_channels=conv_out_channel,
            kernel_size=1,
            stride=1,
            pad_mode='valid',
            has_bias=True,
            weight_init='ones',
            bias_init='ones',
        )
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
        self.flat = nn.Flatten()
        self.reducemean_axis = reducemean_axis
        if strategy is not None:
            self.reduce_mean.shard(strategy)

    def construct(self, inputs):
        features = self.conv(inputs)
        reduced = self.reduce_mean(features, self.reducemean_axis)
        return self.flat(reduced)
550
551
class CrossEntropyLoss(nn.Cell):
    """SoftmaxCrossEntropyWithLogits followed by an optional ReduceMean over axis -1.

    Args:
        reduction (str): When equal to 'mean', the loss is passed through
            ``ReduceMean`` with axis (-1,); any other value returns the
            cross-entropy output unchanged. Default: 'mean'.
    """

    def __init__(self, reduction='mean'):
        super().__init__()

        self.reduce_mean = P.ReduceMean()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits()
        self.reduction = reduction

    def construct(self, logits, label):
        raw_loss = self.cross_entropy(logits, label)
        if self.reduction != 'mean':
            return raw_loss
        return self.reduce_mean(raw_loss, (-1,))
565
566
def test_flatten_reshape(parallel_mode="auto_parallel"):
    """Train ``ParallelReduceMeanNet`` (axes (2, 3), ReduceMean strategy ((4, 2, 1, 1),)) with ``device_num=8``.

    Args:
        parallel_mode (str): Auto-parallel mode to configure. Default: "auto_parallel".
    """
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)

    network = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                    strategy=((4, 2, 1, 1),))
    loss_fn = CrossEntropyLoss()
    inputs = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    targets = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
    train_data = Dataset(inputs, targets, 2, input_num=2)

    # Momentum optimizer: learning rate 0.1, momentum 0.9; train for two epochs.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    trainer = Model(network, loss_fn=loss_fn, optimizer=optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
584
585
def test_flatten_reshape2(parallel_mode="auto_parallel"):
    """Train ``ParallelReduceMeanNet`` (axes (2, 3), ReduceMean strategy ((4, 1, 1, 1),)) with ``device_num=8``.

    ``set_algo_parameters(fully_use_devices=False)`` is applied before the
    network is built.

    Args:
        parallel_mode (str): Auto-parallel mode to configure. Default: "auto_parallel".
    """
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)

    network = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                    strategy=((4, 1, 1, 1),))
    loss_fn = CrossEntropyLoss()
    inputs = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    targets = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
    train_data = Dataset(inputs, targets, 2, input_num=2)

    # Momentum optimizer: learning rate 0.1, momentum 0.9; train for two epochs.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    trainer = Model(network, loss_fn=loss_fn, optimizer=optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
604
605
class ParallelReshapeNet(nn.Cell):
    """Flatten, then Dense, then Reshape to a fixed target shape.

    Args:
        dense_in_channel (int): ``in_channels`` of the Dense layer.
        dense_out_channel (int): ``out_channels`` of the Dense layer.
        shape (tuple): Target shape passed to Reshape in ``construct``.
        strategy: Shard strategy passed to the Reshape primitive; it is passed
            to ``shard`` even when None. Default: None.
    """

    def __init__(self, dense_in_channel, dense_out_channel, shape, strategy=None):
        super().__init__()
        self.flat = nn.Flatten()
        self.dense = nn.Dense(
            in_channels=dense_in_channel,
            out_channels=dense_out_channel,
            weight_init='ones',
            bias_init='ones',
            has_bias=True,
        )
        self.reshape = P.Reshape()
        self.shape = shape
        self.reshape.shard(strategy)

    def construct(self, inputs):
        flattened = self.flat(inputs)
        projected = self.dense(flattened)
        return self.reshape(projected, self.shape)
624
625
626# the shape of input and output of reshape is the same
627# reshape is optimized before step_parallel
def test_flatten_reshape3(parallel_mode="auto_parallel"):
    """Train ``ParallelReshapeNet`` (target shape (128, 1000), Reshape strategy ((16, 1),)) with ``device_num=8``.

    ``set_algo_parameters(fully_use_devices=False)`` is applied before the
    network is built.

    Args:
        parallel_mode (str): Auto-parallel mode to configure. Default: "auto_parallel".
    """
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)

    network = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000),
                                 strategy=((16, 1),))
    loss_fn = CrossEntropyLoss()
    inputs = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
    targets = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
    train_data = Dataset(inputs, targets, 2, input_num=2)

    # Momentum optimizer: learning rate 0.1, momentum 0.9; train for two epochs.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    trainer = Model(network, loss_fn=loss_fn, optimizer=optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
645
646
class CrossEntropyLoss2(nn.Cell):
    """Thin wrapper around ``SoftmaxCrossEntropyWithLogits``.

    Args:
        reduction (str): Forwarded as the ``reduction`` argument of
            ``SoftmaxCrossEntropyWithLogits``. Default: 'mean'.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logits, label):
        return self.cross_entropy(logits, label)
655
656
def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
    """Train ``ParallelReduceMeanNet`` (keep_dims=True, ReduceMean strategy ((4, 1, 1, 1),)) with ``device_num=8``.

    ``set_algo_parameters(fully_use_devices=False)`` is applied before the
    network is built; the loss is ``CrossEntropyLoss2``.

    Args:
        parallel_mode (str): Auto-parallel mode to configure. Default: "semi_auto_parallel".
    """
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)

    network = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True,
                                    strategy=((4, 1, 1, 1),))
    loss_fn = CrossEntropyLoss2()
    inputs = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    targets = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
    train_data = Dataset(inputs, targets, 2, input_num=2)

    # Momentum optimizer: learning rate 0.1, momentum 0.9; train for two epochs.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    trainer = Model(network, loss_fn=loss_fn, optimizer=optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
675