# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
# CHECK_ENVS is switched off only for the duration of init() and restored right after.
# NOTE(review): presumably this lets init() run without a real distributed environment - confirm.
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True

# Shard strategy applied to every Conv2D primitive of the network: the first input is split
# 8 ways along its first axis, the second input (the weight) is not split.
_CONV_SHARD_STRATEGY = ((8, 1, 1, 1), (1, 1, 1, 1))

# Number of residual blocks in each of the four ResNet-50 stages.
_RESNET50_LAYER_NUMS = (3, 4, 6, 3)

# Parameter names (relative to their owning cell) used to build the expected all-reduce fusion dicts.
_BLOCK_PARAM_NAMES = ('conv1.weight', 'bn1.gamma', 'bn1.beta',
                      'conv2.weight', 'bn2.gamma', 'bn2.beta',
                      'conv3.weight', 'bn3.gamma', 'bn3.beta')
_DOWN_SAMPLE_PARAM_NAMES = ('conv_down_sample.weight', 'bn_down_sample.gamma', 'bn_down_sample.beta')
_STEM_PARAM_NAMES = ('conv1.weight', 'bn1.gamma', 'bn1.beta')
_HEAD_PARAM_NAMES = ('end_point.weight', 'end_point.bias')


def weight_variable():
    """Return the weight initializer shared by all layers: ``TruncatedNormal(0.02)``."""
    return TruncatedNormal(0.02)


def _conv(kernel_size, in_channels, out_channels, stride, padding, pad_mode):
    """Build an ``nn.Conv2d`` with a square kernel of ``kernel_size`` and the shared weight initializer."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode,
                     weight_init=init_value)


def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 3x3 kernel size."""
    return _conv(3, in_channels, out_channels, stride, padding, pad_mode)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 1x1 kernel size."""
    return _conv(1, in_channels, out_channels, stride, padding, pad_mode)


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 7x7 kernel size."""
    return _conv(7, in_channels, out_channels, stride, padding, pad_mode)


def _fused_bn(channels, momentum=0.9):
    """Get a fused batchnorm"""
    return nn.BatchNorm2d(channels, momentum=momentum)


def _sharded(conv):
    """Apply ``_CONV_SHARD_STRATEGY`` to the Conv2D primitive of ``conv`` and return ``conv``."""
    conv.conv2d.shard(_CONV_SHARD_STRATEGY)
    return conv


class ResidualBlock(nn.Cell):
    """
    Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by batch norm.

    The two inner convolutions use ``out_channels // expansion`` channels and ``stride`` is applied
    by the 3x3 convolution. The shortcut branch is:

    * a strided 1x1 conv + batch norm when ``in_channels != out_channels``;
    * otherwise a ``MaxPool2d(kernel_size=1, stride=2)`` when ``stride != 1`` (the pooling stride is
      fixed at 2 regardless of the value of ``stride``);
    * otherwise the unchanged input.

    Args:
        in_channels (int): Channels of the input.
        out_channels (int): Channels of the output.
        stride (int): Stride of the 3x3 convolution and of the shortcut branch. Default: 1.
        momentum (float): Momentum passed to every batch norm. Default: 0.9.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.9):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = _sharded(_conv1x1(in_channels, out_chls, stride=1))
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        self.conv2 = _sharded(_conv3x3(out_chls, out_chls, stride=stride))
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _sharded(_conv1x1(out_chls, out_channels, stride=1))
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            self.conv_down_sample = _sharded(_conv1x1(in_channels, out_channels, stride=stride))
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Bring the shortcut to the same shape as the main branch before adding.
        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet made of a 7x7 stem, four stages of ``block`` cells and a dense classifier.

    Args:
        block (Cell): Block class, called as ``block(in_channel, out_channel, stride=...)``.
        layer_nums (list): Number of blocks of each of the four stages.
        in_channels (list): Input channels of each of the four stages.
        out_channels (list): Output channels of each of the four stages.
        strides (list): Stride of each stage. Default: None, which means ``[1, 2, 2, 2]``.
        num_classes (int): Number of output classes. Default: 100.

    Raises:
        ValueError: If ``layer_nums``, ``in_channels`` and ``out_channels`` do not all have length 4.
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100):
        super(ResNet, self).__init__()
        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        self.conv1 = _sharded(_conv7x7(3, 64, stride=2))
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block, layer_nums[0], in_channel=in_channels[0],
                                       out_channel=out_channels[0], stride=strides[0])
        self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1],
                                       out_channel=out_channels[1], stride=strides[1])
        self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2],
                                       out_channel=out_channels[2], stride=strides[2])
        self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3],
                                       out_channel=out_channels[3], stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        # The classifier consumes the channels produced by the last stage (2048 for resnet50()).
        self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Build one stage as a ``SequentialCell`` of ``block`` cells.

        The first block maps ``in_channel`` to ``out_channel`` with stride 1, the middle blocks keep
        ``out_channel`` with stride 1, and ``stride`` is applied by the last block only.

        NOTE: the first and the last block are always created, so the stage holds at least two
        blocks even when ``layer_num`` is smaller than 2.
        """
        layers = [block(in_channel, out_channel, stride=1)]
        layers.extend(block(out_channel, out_channel, stride=1) for _ in range(1, layer_num - 1))
        layers.append(block(out_channel, out_channel, stride=stride))
        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Average over axes 2 and 3, then drop the size-1 axes before the classifier.
        out = self.mean(c5, (2, 3))
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10):
    """Build the ResNet-50 variant used by these tests with ``class_num`` output classes."""
    return ResNet(ResidualBlock,
                  list(_RESNET50_LAYER_NUMS),
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [2, 2, 2, 1],
                  class_num)


class SoftmaxCrossEntropyExpand(LossBase):
    """
    Softmax cross entropy written out with elementary operators.

    The logits are cast to float32, shifted by their maximum (``ReduceMax`` is called without an
    axis argument), exponentiated and normalised over the last axis. The loss is the negated sum
    over the last axis of ``log(softmax) * label``, reduced by a ``ReduceMean`` that carries the
    ``cross_batch`` primitive attribute.

    Args:
        sparse (bool): If True, ``label`` holds class indices and is one-hot encoded with depth
            ``logit.shape[1]`` first. Default: False.
    """

    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.mean = P.ReduceMean(keep_dims=False).add_prim_attr("cross_batch", True)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.cast1 = P.Cast()

    def construct(self, logit, label):
        logit = self.cast1(logit, mstype.float32)
        logit_max = self.max(logit)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


class DatasetLenet():
    """
    Minimal in-memory dataset that yields the same ``(predict, label)`` pair ``length`` times.

    It only implements the methods defined below; ``get_dataset_size`` always reports 32 and
    ``get_repeat_count`` always reports 1, independently of ``length``.
    """

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so the pair can be yielded ``length`` times again."""
        self.index = 0

    def get_dataset_size(self):
        return 32

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        # The dataset is its own iterator; both arguments are ignored.
        return self


def _train_resnet50(batch_size, num_classes):
    """
    Train ResNet-50 on one constant batch and return the ``Model``.

    The caller must configure the device / auto-parallel / algorithm contexts beforehand; this
    helper only resets the operator ids, seeds NumPy and runs ``model.train``.

    Args:
        batch_size (int): Number of samples in the single batch.
        num_classes (int): Number of classes of the classifier; sample ``i`` gets label
            ``i % num_classes``.

    Returns:
        Model, the model after ``model.train(5, ...)`` with ``dataset_sink_mode=False``.
    """
    resset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = (np.arange(batch_size) % num_classes).astype(np.int32)
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    return model


def _check_shard_strategies(model, dev_num, matmul_strategy, reduce_sum_strategy):
    """
    Assert the shard strategies chosen for the train network of ``model``.

    Every operator whose name contains ``Conv2D-op`` must split the first axis of its first input
    ``dev_num`` ways; ``MatMul-op`` and ``ReduceSum-op`` operators must match the given strategies
    exactly. Other operators are not checked.
    """
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    for op_name, strategy in strategies.items():
        if re.search('Conv2D-op', op_name) is not None:
            assert strategy[0][0] == dev_num
        elif re.search('MatMul-op', op_name) is not None:
            assert strategy == matmul_strategy
        elif re.search('ReduceSum-op', op_name) is not None:
            assert strategy == reduce_sum_strategy


def _expected_allreduce_fusion(first_group=()):
    """
    Build the expected ``{parameter name: value}`` all-reduce fusion dict of ResNet-50.

    Every parameter of the four stages and of ``end_point`` maps to 2; the three stem parameters
    and every name listed in ``first_group`` map to 1.
    NOTE(review): the values are presumably fusion group indices - confirm against
    ``_get_allreduce_fusion``.

    Args:
        first_group (tuple): Full parameter names that map to 1 in addition to the stem.

    Returns:
        dict, the expected result of ``_cell_graph_executor._get_allreduce_fusion``.
    """
    names = list(_HEAD_PARAM_NAMES)
    for stage, block_count in enumerate(_RESNET50_LAYER_NUMS, start=1):
        for block_index in range(block_count):
            prefix = 'layer{}.{}.'.format(stage, block_index)
            names.extend(prefix + name for name in _BLOCK_PARAM_NAMES)
            # Only the first block of a stage changes the channel count, so only it owns the
            # down-sample convolution and batch norm.
            if block_index == 0:
                names.extend(prefix + name for name in _DOWN_SAMPLE_PARAM_NAMES)

    expected = dict.fromkeys(names, 2)
    expected.update(dict.fromkeys(_STEM_PARAM_NAMES, 1))
    expected.update(dict.fromkeys(first_group, 1))
    return expected


def test_train_32k_8p(batch_size=32, num_classes=32768):
    """
    Train with 32768 classes on 8 devices (Ascend) in auto-parallel mode.

    Expects MatMul to split only the first axis of its first input.

    Returns:
        dict, the all-reduce fusion information of the train network (used by the
        ``train_32k_8p_fusion*`` functions).
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    model = _train_resnet50(batch_size, num_classes)
    _check_shard_strategies(model, dev_num,
                            matmul_strategy=[[dev_num, 1], [1, 1]],
                            reduce_sum_strategy=[[dev_num, 1]])

    allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)
    print(allreduce_fusion_dict)
    return allreduce_fusion_dict


def train_32k_8p_fusion1(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
    """
    Run ``test_train_32k_8p`` with all-reduce fusion algorithm 1 and check the fusion result.

    The cost model context is reset afterwards even when the run or the check fails.
    """
    try:
        cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
        allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
        expect_dict = _expected_allreduce_fusion()

        assert allreduce_fusion_dict == expect_dict
    finally:
        cost_model_context.reset_cost_model_context()


def train_32k_8p_fusion2(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
    """
    Run ``test_train_32k_8p`` with all-reduce fusion algorithm 2 and check the fusion result.

    Compared with ``train_32k_8p_fusion1``, the first convolution and batch norm of ``layer1.0``
    and its second convolution are expected to map to 1 as well. The cost model context is reset
    afterwards even when the run or the check fails.
    """
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
        allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
        expect_dict = _expected_allreduce_fusion(first_group=('layer1.0.conv2.weight',
                                                              'layer1.0.bn1.beta',
                                                              'layer1.0.bn1.gamma',
                                                              'layer1.0.conv1.weight'))

        assert allreduce_fusion_dict == expect_dict
    finally:
        cost_model_context.reset_cost_model_context()


def test_train_64k_8p(batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
    """
    Train with 65536 classes on 8 devices in auto-parallel mode with a tuned cost model.

    Expects MatMul to split only the first axis of its second input.
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    model = _train_resnet50(batch_size, num_classes)
    _check_shard_strategies(model, dev_num,
                            matmul_strategy=[[1, 1], [dev_num, 1]],
                            reduce_sum_strategy=[[1, dev_num]])


def test_train_8k_8p_gpu(batch_size=32, num_classes=8192):
    """
    Train with 8192 classes on 8 devices in auto-parallel mode with the GPU device target.

    Expects MatMul to split only the first axis of its second input.
    """
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    model = _train_resnet50(batch_size, num_classes)
    _check_shard_strategies(model, dev_num,
                            matmul_strategy=[[1, 1], [dev_num, 1]],
                            reduce_sum_strategy=[[1, dev_num]])


def test_train_8k_8p_gpu_approxi(batch_size=32, num_classes=8192):
    """
    Same as ``test_train_8k_8p_gpu`` but with ``enable_algo_approxi=True`` instead of
    ``elementwise_op_strategy_follow=True``; the expected strategies are the same.
    """
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(enable_algo_approxi=True)
    model = _train_resnet50(batch_size, num_classes)
    _check_shard_strategies(model, dev_num,
                            matmul_strategy=[[1, 1], [dev_num, 1]],
                            reduce_sum_strategy=[[1, dev_num]])


def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
    """
    Train with 4096 classes on 8 devices in auto-parallel mode with the GPU device target.

    Expects MatMul to split only the first axis of its first input.
    """
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    model = _train_resnet50(batch_size, num_classes)
    _check_shard_strategies(model, dev_num,
                            matmul_strategy=[[dev_num, 1], [1, 1]],
                            reduce_sum_strategy=[[dev_num, 1]])