# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.train.callback import Callback
from mindspore.train.model import Model
from mindspore.context import ParallelMode

# Distributed setup: bind this process to its Ascend device, initialize the
# communication group, and run the network in auto-parallel mode with gradient
# averaging across devices.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
np.random.seed(10)


def weight_variable():
    return TruncatedNormal(0.01)


def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _fused_bn(channels, momentum=0.9):
    return nn.BatchNorm2d(channels, momentum=momentum)


class BasicBlock(nn.Cell):
    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.1):
        super(BasicBlock, self).__init__()

        self.conv1 = _conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = _fused_bn(out_channels, momentum=momentum)
        self.conv2 = _conv3x3(out_channels, out_channels)
        self.bn2 = _fused_bn(out_channels, momentum=momentum)
        self.relu = P.ReLU()
        self.down_sample_layer = None
        self.downsample = (in_channels != out_channels)
        if self.downsample:
            # Project the identity branch with a 1x1 conv + BN when the channel count changes.
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channels,
                                                                 out_channels,
                                                                 stride=stride,
                                                                 padding=0),
                                                        _fused_bn(out_channels,
                                                                  momentum=momentum)])
        self.add = P.Add()

    def construct(self, x):
        identity = x

        x = self.conv1(x)
        x = self.relu(x)

        x = self.conv2(x)

        if self.downsample:
            identity = self.down_sample_layer(identity)

        out = self.add(x, identity)
        out = self.relu(out)

        return out
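

# Shape sketch (illustration only; resnet50() below builds its stages from
# ResidualBlock, so BasicBlock itself is not exercised by this test):
# BasicBlock(64, 128, stride=2) maps an (N, 64, H, W) input to (N, 128, H/2, W/2),
# with the identity branch routed through the 1x1 projection defined above.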


class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)

        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.relu(out)

        out = self.conv3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100):
        super(ResNet, self).__init__()

        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of "
                             "layer_nums, in_channels and out_channels must all be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        # The striding block is placed at the end of the stage; the first and
        # middle blocks always use stride 1.
        layers = []
        resblk = block(in_channel, out_channel, stride=1)
        layers.append(resblk)

        for _ in range(1, layer_num - 1):
            resblk = block(out_channel, out_channel, stride=1)
            layers.append(resblk)

        resblk = block(out_channel, out_channel, stride=stride)
        layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10):
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [2, 2, 2, 1],
                  class_num)
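

# Forward-shape sketch (assuming the (32 * 8, 3, 224, 224) input generated in
# test_train_feed below): conv1 and maxpool reduce 224 -> 112 -> 56, the four
# stages with strides [2, 2, 2, 1] reduce 56 -> 28 -> 14 -> 7 -> 7 at 2048
# channels, and the global mean plus the fp16 Dense head yield
# (batch, class_num) logits.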


class SoftmaxCrossEntropyExpand(LossBase):
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.mean = P.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.eps = Tensor(1e-24, mstype.float32)

    def construct(self, logit, label):
        # Numerically stable softmax: subtract the per-row max before exponentiating.
        logit = self.cast(logit, mstype.float32)
        logit_max = self.max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result + self.eps)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


rank_id = int(os.environ["RANK_ID"])
device_num = int(os.environ["RANK_SIZE"])


class DataGenerator():
    def get_parallel_blocks(self, input_, strategy):
        # Split input_ along each axis into strategy[axis] pieces; the caller
        # picks the block that belongs to this rank.
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while blocks:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks

    def generate_data(self, shape):
        data = np.arange(np.prod(shape)).reshape(shape)
        return data

    def input_data(self, shape):
        data = (self.generate_data(shape)).astype(np.float32)
        stra = [1] * len(shape)
        stra[0] = device_num
        data_parallel = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(data_parallel[rank_id])

    def label_data(self, shape):
        data = (self.generate_data(shape) * 1000 / np.prod(shape)).astype(np.int32)
        stra = [1] * len(shape)
        stra[0] = device_num
        data_parallel = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(data_parallel[rank_id])


class Dataset():
    def __init__(self, predict, label, length=1, input_num=2, repeat_count=1):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num
        self.repeat_count = repeat_count

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return self.length

    def get_repeat_count(self):
        return self.repeat_count


class ModelCallback(Callback):
    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())
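

# How the pieces fit together in test_train_feed below (illustrative numbers,
# assuming RANK_SIZE == 8): DataGenerator.input_data((32 * 8, 3, 224, 224)) splits
# the batch axis into 8 blocks of shape (32, 3, 224, 224) and hands this rank its
# own block, Dataset replays that single batch once per epoch, and ModelCallback
# records the mean loss at every epoch_end.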


def test_train_feed(num_classes=65536):
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 8,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
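

# Minimal entry-point sketch (assumption: DEVICE_ID, RANK_ID and RANK_SIZE are
# exported by the distributed launcher before this module is imported, since the
# context setup at the top of the file reads them at import time).
if __name__ == "__main__":
    test_train_feed()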