# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
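
# Graph-mode setup targeting a single Ascend device. GlobalComm.CHECK_ENVS is
# toggled off around init() so that communication initialization does not
# insist on the usual distributed environment variables inside this unit test
# (an inference from the flag's name); the flag is restored immediately after.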
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True


def weight_variable():
    """Return a TruncatedNormal(0.02) initializer for conv and dense weights."""
    return TruncatedNormal(0.02)


def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 3x3 kernel size."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 1x1 kernel size."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 7x7 kernel size."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _fused_bn(channels, momentum=0.9):
    """Get a fused batchnorm"""
    return nn.BatchNorm2d(channels, momentum=momentum)
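

# A note on the shard specs used throughout the blocks below: on a Conv2D,
# shard(((8, 1, 1, 1), (1, 1, 1, 1))) splits the input's batch (N) dimension
# across 8 devices and replicates the weight, i.e. per-operator data parallelism.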
class ResidualBlock(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 conv) with a shortcut branch."""
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.9):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
        self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
        self.conv2.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
        self.conv3.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            # Channel counts differ: project the shortcut with a strided 1x1 conv.
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
            self.conv_down_sample.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            # Same channels but strided: downsample the shortcut by max pooling.
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """ResNet backbone: 7x7 stem, max pooling, four residual stages and an
    fp16 Dense classifier head."""

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100):
        super(ResNet, self).__init__()
        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the lengths of layer_nums, in_channels and "
                             "out_channels must all be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        layers = []

        resblk = block(in_channel, out_channel, stride=1)
        layers.append(resblk)

        for _ in range(1, layer_num - 1):
            resblk = block(out_channel, out_channel, stride=1)
            layers.append(resblk)

        resblk = block(out_channel, out_channel, stride=stride)
        layers.append(resblk)

        return nn.SequentialCell(layers)
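
    # Note that _make_layer applies the stage stride to the *last* block of a
    # stage, whereas reference ResNet implementations typically stride the first.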

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        out = self.mean(c5, (2, 3))  # global average pooling over H and W
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10):
    """Build a ResNet-50 (stage depths [3, 4, 6, 3]) with class_num output classes."""
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [2, 2, 2, 1],
                  class_num)
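
# Hypothetical usage sketch (not executed here), assuming an [N, 3, 224, 224]
# float32 input:
#
#   net = resnet50(class_num=10)
#   logits = net(Tensor(np.ones([32, 3, 224, 224]).astype(np.float32)))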


class SoftmaxCrossEntropyExpand(LossBase):
    """Softmax cross-entropy written out with primitive ops (exp, sum, div, log)
    instead of the fused loss operator."""

    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.mean = P.ReduceMean(keep_dims=False).add_prim_attr("cross_batch", True)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.cast1 = P.Cast()

    def construct(self, logit, label):
        logit = self.cast1(logit, mstype.float32)
        # Subtract the per-row max before exp() for numerical stability.
        logit_max = self.max(logit)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            # Expand sparse integer labels into one-hot vectors.
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss
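

# In effect, construct() above computes the numerically stabilized cross entropy
#   loss = mean_i(-sum_j(y_ij * log(softmax(x_i)_j)))
# with the row max subtracted from the logits before exponentiation.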


class DatasetLenet:
    """Minimal in-memory dataset stub exposing the iterator protocol and the
    metadata hooks that Model.train expects."""

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return 32

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        return self
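

# Trains ResNet-50 under AUTO_PARALLEL on 8 devices, checks the shard strategies
# chosen for Conv2D/MatMul/ReduceSum, and returns the allreduce fusion dict
# (parameter name -> fusion group index) consumed by the fusion tests below.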
def test_train_32k_8p(batch_size=32, num_classes=32768):
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]

    allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)
    print(allreduce_fusion_dict)
    return allreduce_fusion_dict
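

# Fusion algorithm 1: as the parameter names suggest, the backward allreduces are
# split into a fixed number of fused groups (costmodel_allreduce_fusion_times=2),
# with a tail portion sized by costmodel_allreduce_fusion_tail_percent.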
def train_32k_8p_fusion1(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
    allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
    expect_dict = {'end_point.bias': 2,
                   'end_point.weight': 2,
                   'layer4.2.bn3.beta': 2,
                   'layer4.2.bn3.gamma': 2,
                   'layer4.2.conv3.weight': 2,
                   'layer4.2.bn2.beta': 2,
                   'layer4.2.bn2.gamma': 2,
                   'layer4.2.conv2.weight': 2,
                   'layer4.2.bn1.beta': 2,
                   'layer4.2.bn1.gamma': 2,
                   'layer4.2.conv1.weight': 2,
                   'layer4.1.bn3.beta': 2,
                   'layer4.1.bn3.gamma': 2,
                   'layer4.1.conv3.weight': 2,
                   'layer4.1.bn2.beta': 2,
                   'layer4.1.bn2.gamma': 2,
                   'layer4.1.conv2.weight': 2,
                   'layer4.1.bn1.beta': 2,
                   'layer4.1.bn1.gamma': 2,
                   'layer4.1.conv1.weight': 2,
                   'layer4.0.bn_down_sample.beta': 2,
                   'layer4.0.bn_down_sample.gamma': 2,
                   'layer4.0.conv_down_sample.weight': 2,
                   'layer4.0.bn3.beta': 2,
                   'layer4.0.bn3.gamma': 2,
                   'layer4.0.conv3.weight': 2,
                   'layer4.0.bn2.beta': 2,
                   'layer4.0.bn2.gamma': 2,
                   'layer4.0.conv2.weight': 2,
                   'layer4.0.bn1.beta': 2,
                   'layer4.0.bn1.gamma': 2,
                   'layer4.0.conv1.weight': 2,
                   'layer3.5.bn3.beta': 2,
                   'layer3.5.bn3.gamma': 2,
                   'layer3.5.conv3.weight': 2,
                   'layer3.5.bn2.beta': 2,
                   'layer3.5.bn2.gamma': 2,
                   'layer3.5.conv2.weight': 2,
                   'layer3.5.bn1.beta': 2,
                   'layer3.5.bn1.gamma': 2,
                   'layer3.5.conv1.weight': 2,
                   'layer3.4.bn3.beta': 2,
                   'layer3.4.bn3.gamma': 2,
                   'layer3.4.conv3.weight': 2,
                   'layer3.4.bn2.beta': 2,
                   'layer3.4.bn2.gamma': 2,
                   'layer3.4.conv2.weight': 2,
                   'layer3.4.bn1.beta': 2,
                   'layer3.4.bn1.gamma': 2,
                   'layer3.4.conv1.weight': 2,
                   'layer3.3.bn3.beta': 2,
                   'layer3.3.bn3.gamma': 2,
                   'layer3.3.conv3.weight': 2,
                   'layer3.3.bn2.beta': 2,
                   'layer3.3.bn2.gamma': 2,
                   'layer3.3.conv2.weight': 2,
                   'layer3.3.bn1.beta': 2,
                   'layer3.3.bn1.gamma': 2,
                   'layer3.3.conv1.weight': 2,
                   'layer3.2.bn3.beta': 2,
                   'layer3.2.bn3.gamma': 2,
                   'layer3.2.conv3.weight': 2,
                   'layer3.2.bn2.beta': 2,
                   'layer3.2.bn2.gamma': 2,
                   'layer3.2.conv2.weight': 2,
                   'layer3.2.bn1.beta': 2,
                   'layer3.2.bn1.gamma': 2,
                   'layer3.2.conv1.weight': 2,
                   'layer3.1.bn3.beta': 2,
                   'layer3.1.bn3.gamma': 2,
                   'layer3.1.conv3.weight': 2,
                   'layer3.1.bn2.beta': 2,
                   'layer3.1.bn2.gamma': 2,
                   'layer3.1.conv2.weight': 2,
                   'layer3.1.bn1.beta': 2,
                   'layer3.1.bn1.gamma': 2,
                   'layer3.1.conv1.weight': 2,
                   'layer3.0.bn_down_sample.beta': 2,
                   'layer3.0.bn_down_sample.gamma': 2,
                   'layer3.0.conv_down_sample.weight': 2,
                   'layer3.0.bn3.beta': 2,
                   'layer3.0.bn3.gamma': 2,
                   'layer3.0.conv3.weight': 2,
                   'layer3.0.bn2.beta': 2,
                   'layer3.0.bn2.gamma': 2,
                   'layer3.0.conv2.weight': 2,
                   'layer3.0.bn1.beta': 2,
                   'layer3.0.bn1.gamma': 2,
                   'layer3.0.conv1.weight': 2,
                   'layer2.3.bn3.beta': 2,
                   'layer2.3.bn3.gamma': 2,
                   'layer2.3.conv3.weight': 2,
                   'layer2.3.bn2.beta': 2,
                   'layer2.3.bn2.gamma': 2,
                   'layer2.3.conv2.weight': 2,
                   'layer2.3.bn1.beta': 2,
                   'layer2.3.bn1.gamma': 2,
                   'layer2.3.conv1.weight': 2,
                   'layer2.2.bn3.beta': 2,
                   'layer2.2.bn3.gamma': 2,
                   'layer2.2.conv3.weight': 2,
                   'layer2.2.bn2.beta': 2,
                   'layer2.2.bn2.gamma': 2,
                   'layer2.2.conv2.weight': 2,
                   'layer2.2.bn1.beta': 2,
                   'layer2.2.bn1.gamma': 2,
                   'layer2.2.conv1.weight': 2,
                   'layer2.1.bn3.beta': 2,
                   'layer2.1.bn3.gamma': 2,
                   'layer2.1.conv3.weight': 2,
                   'layer2.1.bn2.beta': 2,
                   'layer2.1.bn2.gamma': 2,
                   'layer2.1.conv2.weight': 2,
                   'layer2.1.bn1.beta': 2,
                   'layer2.1.bn1.gamma': 2,
                   'layer2.1.conv1.weight': 2,
                   'layer2.0.bn_down_sample.beta': 2,
                   'layer2.0.bn_down_sample.gamma': 2,
                   'layer2.0.conv_down_sample.weight': 2,
                   'layer2.0.bn3.beta': 2,
                   'layer2.0.bn3.gamma': 2,
                   'layer2.0.conv3.weight': 2,
                   'layer2.0.bn2.beta': 2,
                   'layer2.0.bn2.gamma': 2,
                   'layer2.0.conv2.weight': 2,
                   'layer2.0.bn1.beta': 2,
                   'layer2.0.bn1.gamma': 2,
                   'layer2.0.conv1.weight': 2,
                   'layer1.2.bn3.beta': 2,
                   'layer1.2.bn3.gamma': 2,
                   'layer1.2.conv3.weight': 2,
                   'layer1.2.bn2.beta': 2,
                   'layer1.2.bn2.gamma': 2,
                   'layer1.2.conv2.weight': 2,
                   'layer1.2.bn1.beta': 2,
                   'layer1.2.bn1.gamma': 2,
                   'layer1.2.conv1.weight': 2,
                   'layer1.1.bn3.beta': 2,
                   'layer1.1.bn3.gamma': 2,
                   'layer1.1.conv3.weight': 2,
                   'layer1.1.bn2.beta': 2,
                   'layer1.1.bn2.gamma': 2,
                   'layer1.1.conv2.weight': 2,
                   'layer1.1.bn1.beta': 2,
                   'layer1.1.bn1.gamma': 2,
                   'layer1.1.conv1.weight': 2,
                   'layer1.0.bn_down_sample.beta': 2,
                   'layer1.0.bn_down_sample.gamma': 2,
                   'layer1.0.conv_down_sample.weight': 2,
                   'layer1.0.bn3.beta': 2,
                   'layer1.0.bn3.gamma': 2,
                   'layer1.0.conv3.weight': 2,
                   'layer1.0.bn2.beta': 2,
                   'layer1.0.bn2.gamma': 2,
                   'layer1.0.conv2.weight': 2,
                   'layer1.0.bn1.beta': 2,
                   'layer1.0.bn1.gamma': 2,
                   'layer1.0.conv1.weight': 2,
                   'bn1.beta': 1,
                   'bn1.gamma': 1,
                   'conv1.weight': 1}

    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()
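

# Fusion algorithm 2: as the parameter names suggest, fusion groups are sized
# from a simple time model with a per-allreduce inherent time, a bandwidth term,
# and a computation-time coefficient.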
def train_32k_8p_fusion2(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
    allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
    expect_dict = {'end_point.bias': 2,
                   'end_point.weight': 2,
                   'layer4.2.bn3.beta': 2,
                   'layer4.2.bn3.gamma': 2,
                   'layer4.2.conv3.weight': 2,
                   'layer4.2.bn2.beta': 2,
                   'layer4.2.bn2.gamma': 2,
                   'layer4.2.conv2.weight': 2,
                   'layer4.2.bn1.beta': 2,
                   'layer4.2.bn1.gamma': 2,
                   'layer4.2.conv1.weight': 2,
                   'layer4.1.bn3.beta': 2,
                   'layer4.1.bn3.gamma': 2,
                   'layer4.1.conv3.weight': 2,
                   'layer4.1.bn2.beta': 2,
                   'layer4.1.bn2.gamma': 2,
                   'layer4.1.conv2.weight': 2,
                   'layer4.1.bn1.beta': 2,
                   'layer4.1.bn1.gamma': 2,
                   'layer4.1.conv1.weight': 2,
                   'layer4.0.bn_down_sample.beta': 2,
                   'layer4.0.bn_down_sample.gamma': 2,
                   'layer4.0.conv_down_sample.weight': 2,
                   'layer4.0.bn3.beta': 2,
                   'layer4.0.bn3.gamma': 2,
                   'layer4.0.conv3.weight': 2,
                   'layer4.0.bn2.beta': 2,
                   'layer4.0.bn2.gamma': 2,
                   'layer4.0.conv2.weight': 2,
                   'layer4.0.bn1.beta': 2,
                   'layer4.0.bn1.gamma': 2,
                   'layer4.0.conv1.weight': 2,
                   'layer3.5.bn3.beta': 2,
                   'layer3.5.bn3.gamma': 2,
                   'layer3.5.conv3.weight': 2,
                   'layer3.5.bn2.beta': 2,
                   'layer3.5.bn2.gamma': 2,
                   'layer3.5.conv2.weight': 2,
                   'layer3.5.bn1.beta': 2,
                   'layer3.5.bn1.gamma': 2,
                   'layer3.5.conv1.weight': 2,
                   'layer3.4.bn3.beta': 2,
                   'layer3.4.bn3.gamma': 2,
                   'layer3.4.conv3.weight': 2,
                   'layer3.4.bn2.beta': 2,
                   'layer3.4.bn2.gamma': 2,
                   'layer3.4.conv2.weight': 2,
                   'layer3.4.bn1.beta': 2,
                   'layer3.4.bn1.gamma': 2,
                   'layer3.4.conv1.weight': 2,
                   'layer3.3.bn3.beta': 2,
                   'layer3.3.bn3.gamma': 2,
                   'layer3.3.conv3.weight': 2,
                   'layer3.3.bn2.beta': 2,
                   'layer3.3.bn2.gamma': 2,
                   'layer3.3.conv2.weight': 2,
                   'layer3.3.bn1.beta': 2,
                   'layer3.3.bn1.gamma': 2,
                   'layer3.3.conv1.weight': 2,
                   'layer3.2.bn3.beta': 2,
                   'layer3.2.bn3.gamma': 2,
                   'layer3.2.conv3.weight': 2,
                   'layer3.2.bn2.beta': 2,
                   'layer3.2.bn2.gamma': 2,
                   'layer3.2.conv2.weight': 2,
                   'layer3.2.bn1.beta': 2,
                   'layer3.2.bn1.gamma': 2,
                   'layer3.2.conv1.weight': 2,
                   'layer3.1.bn3.beta': 2,
                   'layer3.1.bn3.gamma': 2,
                   'layer3.1.conv3.weight': 2,
                   'layer3.1.bn2.beta': 2,
                   'layer3.1.bn2.gamma': 2,
                   'layer3.1.conv2.weight': 2,
                   'layer3.1.bn1.beta': 2,
                   'layer3.1.bn1.gamma': 2,
                   'layer3.1.conv1.weight': 2,
                   'layer3.0.bn_down_sample.beta': 2,
                   'layer3.0.bn_down_sample.gamma': 2,
                   'layer3.0.conv_down_sample.weight': 2,
                   'layer3.0.bn3.beta': 2,
                   'layer3.0.bn3.gamma': 2,
                   'layer3.0.conv3.weight': 2,
                   'layer3.0.bn2.beta': 2,
                   'layer3.0.bn2.gamma': 2,
                   'layer3.0.conv2.weight': 2,
                   'layer3.0.bn1.beta': 2,
                   'layer3.0.bn1.gamma': 2,
                   'layer3.0.conv1.weight': 2,
                   'layer2.3.bn3.beta': 2,
                   'layer2.3.bn3.gamma': 2,
                   'layer2.3.conv3.weight': 2,
                   'layer2.3.bn2.beta': 2,
                   'layer2.3.bn2.gamma': 2,
                   'layer2.3.conv2.weight': 2,
                   'layer2.3.bn1.beta': 2,
                   'layer2.3.bn1.gamma': 2,
                   'layer2.3.conv1.weight': 2,
                   'layer2.2.bn3.beta': 2,
                   'layer2.2.bn3.gamma': 2,
                   'layer2.2.conv3.weight': 2,
                   'layer2.2.bn2.beta': 2,
                   'layer2.2.bn2.gamma': 2,
                   'layer2.2.conv2.weight': 2,
                   'layer2.2.bn1.beta': 2,
                   'layer2.2.bn1.gamma': 2,
                   'layer2.2.conv1.weight': 2,
                   'layer2.1.bn3.beta': 2,
                   'layer2.1.bn3.gamma': 2,
                   'layer2.1.conv3.weight': 2,
                   'layer2.1.bn2.beta': 2,
                   'layer2.1.bn2.gamma': 2,
                   'layer2.1.conv2.weight': 2,
                   'layer2.1.bn1.beta': 2,
                   'layer2.1.bn1.gamma': 2,
                   'layer2.1.conv1.weight': 2,
                   'layer2.0.bn_down_sample.beta': 2,
                   'layer2.0.bn_down_sample.gamma': 2,
                   'layer2.0.conv_down_sample.weight': 2,
                   'layer2.0.bn3.beta': 2,
                   'layer2.0.bn3.gamma': 2,
                   'layer2.0.conv3.weight': 2,
                   'layer2.0.bn2.beta': 2,
                   'layer2.0.bn2.gamma': 2,
                   'layer2.0.conv2.weight': 2,
                   'layer2.0.bn1.beta': 2,
                   'layer2.0.bn1.gamma': 2,
                   'layer2.0.conv1.weight': 2,
                   'layer1.2.bn3.beta': 2,
                   'layer1.2.bn3.gamma': 2,
                   'layer1.2.conv3.weight': 2,
                   'layer1.2.bn2.beta': 2,
                   'layer1.2.bn2.gamma': 2,
                   'layer1.2.conv2.weight': 2,
                   'layer1.2.bn1.beta': 2,
                   'layer1.2.bn1.gamma': 2,
                   'layer1.2.conv1.weight': 2,
                   'layer1.1.bn3.beta': 2,
                   'layer1.1.bn3.gamma': 2,
                   'layer1.1.conv3.weight': 2,
                   'layer1.1.bn2.beta': 2,
                   'layer1.1.bn2.gamma': 2,
                   'layer1.1.conv2.weight': 2,
                   'layer1.1.bn1.beta': 2,
                   'layer1.1.bn1.gamma': 2,
                   'layer1.1.conv1.weight': 2,
                   'layer1.0.bn_down_sample.beta': 2,
                   'layer1.0.bn_down_sample.gamma': 2,
                   'layer1.0.conv_down_sample.weight': 2,
                   'layer1.0.bn3.beta': 2,
                   'layer1.0.bn3.gamma': 2,
                   'layer1.0.conv3.weight': 2,
                   'layer1.0.bn2.beta': 2,
                   'layer1.0.bn2.gamma': 2,
                   'layer1.0.conv2.weight': 1,
                   'layer1.0.bn1.beta': 1,
                   'layer1.0.bn1.gamma': 1,
                   'layer1.0.conv1.weight': 1,
                   'bn1.beta': 1,
                   'bn1.gamma': 1,
                   'conv1.weight': 1}

    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()
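

# With 65536 classes the classifier weight dominates, and the expected MatMul
# strategy flips to [[1, 1], [dev_num, 1]]: the Dense weight is split across
# devices (model parallelism) rather than the batch dimension.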
def test_train_64k_8p(batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
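

# The same check on GPU with 8192 classes; the searched strategy still shards
# the classifier weight across devices.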
def test_train_8k_8p_gpu(batch_size=32, num_classes=8192):
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # set_algo_parameters(enable_algo_approxi=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
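

# Variant of the 8k GPU test with the approximate strategy-search algorithm
# enabled; presumably a faster search, and the expected strategies match the
# exact search above.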
def test_train_8k_8p_gpu_approxi(batch_size=32, num_classes=8192):
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(enable_algo_approxi=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
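

# With only 4096 classes the cost model goes back to plain data parallelism:
# MatMul splits the batch dimension ([[dev_num, 1], [1, 1]]) rather than the weight.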
def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]