# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import time
import random
from multiprocessing import Process, Queue
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
import mindspore.nn as nn
import mindspore.ops.functional as F

from mindspore import Tensor
from mindspore import context
from mindspore import ParameterTuple
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops import composite as CP
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import Callback
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.context import ParallelMode
import mindspore.communication.management as D

MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"

np.random.seed(1)
os.environ['GLOG_v'] = str(2)
os.environ['ASCEND_GLOBAL_LOG_LEVEL'] = str(3)
os.environ['ASCEND_GLOBAL_EVENT_ENABLE'] = str(0)


class MyTimeMonitor(Callback):
    """Times each epoch and step; counts the steps that finish in under 400 ms."""

    def __init__(self, data_size):
        super(MyTimeMonitor, self).__init__()
        self.data_size = data_size
        self.total = 0

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time:{0}, per step time:{1}".format(epoch_mseconds, per_step_mseconds), flush=True)

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        step_mseconds = (time.time() - self.step_time) * 1000
        if step_mseconds < 400:
            self.total = self.total + 1
        print(f"step time:{step_mseconds}", flush=True)

    def good_step(self):
        return self.total


random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)

grad_by_list = CP.GradOperation(get_by_list=True)


def weight_variable_0(shape):
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")
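# Shape sketch for the helpers above (illustrative comment, not part of the
# test): with pad_mode="same", nn.Conv2d keeps the spatial size at
# ceil(size / stride), so the `padding` argument must stay 0 and only `stride`
# changes the output shape. For example:
#
#   x = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
#   conv7x7(3, 64, stride=2)(x)   # -> (1, 64, 112, 112)
#   conv3x3(3, 64, stride=1)(x)   # -> (1, 64, 224, 224)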
def bn_with_initialize(out_channels):
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    # currently identical to bn_with_initialize; kept separate so the last BN
    # of a bottleneck block could be initialized differently
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform')


class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResidualBlockWithDown(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
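# Channel-math sketch for the bottleneck blocks above (illustrative comment):
# with expansion = 4 and out_channels = 256, the first 1x1 convolution squeezes
# to out_chls = 64 and the last expands back, so ResidualBlock maps
# (N, 256, H, W) -> (N, 256, H, W) and can add the identity directly.
# ResidualBlockWithDown instead projects the identity through a 1x1 convolution
# with the block's stride, so both branches agree in channels and spatial size
# before P.Add(); that is why it opens every layer built below.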
class MakeLayer0(nn.Cell):

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(nn.Cell):

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(nn.Cell):

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(nn.Cell):

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class ResNet(nn.Cell):

    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(kernel_size=3, strides=2, pad_mode="SAME")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)[0]

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x, (2, 3))
        x = self.squeeze(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    return ResNet(ResidualBlock, num_classes, batch_size)
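# Hypothetical smoke test for the factory above (illustrative comment; names
# and shapes assumed, not executed by this file):
#
#   net = resnet50(batch_size=32, num_classes=10)
#   dummy = Tensor(np.ones((32, 3, 224, 224)).astype(np.float32))
#   logits = net(dummy)   # expected shape: (32, 10)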
def create_dataset(repeat_num=1, training=True, batch_size=32, num_samples=1600):
    data_home = "/home/workspace/mindspore_dataset"
    data_dir = data_home + "/cifar-10-batches-bin"
    if not training:
        data_dir = data_home + "/cifar-10-verify-bin"
    data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples)

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    # interpolation default BILINEAR
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=c_trans, input_columns="image")

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=1000)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set


class CrossEntropyLoss(nn.Cell):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
        loss = self.cross_entropy(logits, label)[0]
        loss = self.mean(loss, (-1,))
        return loss


class GradWrap(Cell):
    """GradWrap definition: calling the wrapped cell returns the gradients of
    the network's trainable parameters (via grad_by_list) instead of its
    outputs. Usable together with CrossEntropyLoss for a hand-written PyNative
    training step; the Model-based test below does not use either."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())

    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)


def test_pynative_resnet50():
    batch_size = 32
    num_classes = 10
    loss_scale = 128
    total_step = 50
    net = resnet50(batch_size, num_classes)
    optimizer = Momentum(learning_rate=0.01, momentum=0.9,
                         params=filter(lambda x: x.requires_grad, net.get_parameters()))
    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size,
                              num_samples=total_step * batch_size)

    # define callbacks
    time_cb = MyTimeMonitor(data_size=data_set.get_dataset_size())
    cb = [time_cb]

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss_scale_manager = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)

    # train model
    model.train(1, data_set, callbacks=cb,
                sink_size=data_set.get_dataset_size(), dataset_sink_mode=True)

    return time_cb.good_step()
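# Hypothetical standalone single-device run (illustrative comment; assumes the
# CIFAR-10 binaries exist under the hard-coded data_home path):
#
#   context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=0)
#   fast_steps = test_pynative_resnet50()   # number of the 50 steps under 400 ms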
def test_pynative_resnet50_with_env(queue, device_id, device_num):
    # give each worker its own working directory and device before HCCL init
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
    os.environ['RANK_ID'] = str(device_id)
    os.environ['RANK_SIZE'] = str(device_num)
    D.init()
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
                                      device_num=device_num)

    good_steps = test_pynative_resnet50()
    queue.put(good_steps)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_pynative_resnet50_8p():
    device_num = 8
    process = []
    q = Queue()
    for i in range(device_num):
        device_id = i
        process.append(Process(target=test_pynative_resnet50_with_env, args=(q, device_id, device_num)))

    for i in range(device_num):
        process[i].start()

    for i in range(device_num):
        process[i].join()

    # check result: every worker must report more than 10 fast (< 400 ms) steps
    for i in range(device_num):
        assert not q.empty()
        good_steps = q.get()
        assert good_steps > 10
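# Hypothetical launch (illustrative comment; the module file name is a
# placeholder): each of the 8 worker processes binds one Ascend device, joins
# the HCCL group described by the rank table, and runs 50 data-parallel steps:
#
#   pytest -sv <this_file>.py::test_pynative_resnet50_8p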