# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor

import mindspore.hypercomplex.dual as ops
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel


def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=mstype.float32)


def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, (int, float)):
            # True/False are instances of int, hence the explicit bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:  # Conv: multiply channel counts by the receptive field size
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out
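

# A quick illustration of the fan computation above (illustrative sketch, not
# part of the original network): for a hypothetical 3x3 conv weight of shape
# (64, 3, 3, 3), fan_in = 3 * 3 * 3 = 27 and fan_out = 64 * 3 * 3 = 576, so
# 'fan_out' mode scales the He-normal std below by sqrt(2 / 576) for ReLU.
def _fan_example():
    shape = (64, 3, 3, 3)  # hypothetical (out, in, k, k) conv weight shape
    fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
    assert (fan_in, fan_out) == (27, 576)
    return _calculate_correct_fan(shape, 'fan_out')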


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                          padding=1, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                      padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                          padding=0, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                      padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel,
                          kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel,
                      kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel, res_base=False):
    if res_base:
        return ops.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
                               gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return ops.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                           gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
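

# Minimal sketch of how the builders above compose (illustrative only, not
# used by the network below). The hypercomplex ops.Conv2d/ops.BatchNorm2d
# cells operate on dual-number tensors, which is why the He-normal weights
# above carry a leading dimension of 2 (one slice per hypercomplex component).
def _conv_bn_relu_example(in_channel=3, out_channel=64):
    return nn.SequentialCell([_conv3x3(in_channel, out_channel, stride=1),
                              _bn(out_channel),
                              ops.ReLU()])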


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False,
                 se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         ops.ReLU(), ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn(out_channel)
        if self.se_block:
            self.se_global_pool = P.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = P.Mul()
        self.relu = ops.ReLU()

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
                                                                         stride, use_se=self.use_se),
                                                                _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell([ops.MaxPool2d(kernel_size=2, stride=2,
                                                                              pad_mode='same'),
                                                                _conv1x1(in_channel, out_channel, 1,
                                                                         use_se=self.use_se),
                                                                _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                     use_se=self.use_se),
                                                            _bn(out_channel)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:
            out_se = out
            out = self.se_global_pool(out, (2, 3))
            out = self.se_dense_0(out)
            out = self.relu(out)
            out = self.se_dense_1(out)
            out = self.se_sigmoid(out)
            out = F.reshape(out, F.shape(out) + (1, 1))
            out = self.se_mul(out, out_se)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out
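

# Illustrative check of the bottleneck arithmetic above (assumed shapes): the
# block narrows to out_channel // expansion internally, so
# ResidualBlock(256, 512, stride=2) runs its 3x3 conv on 512 // 4 = 128
# channels, and the stride/channel mismatch forces a projection shortcut.
def _bottleneck_example():
    block = ResidualBlock(256, 512, stride=2)
    assert block.down_sample  # projection shortcut is created
    return block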


class ResidualBlockBase(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlockBase(3, 256, stride=2)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False,
                 se_block=False,
                 res_base=True):
        super(ResidualBlockBase, self).__init__()
        self.res_base = res_base
        self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
        self.bn1d = _bn(out_channel)
        self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
        self.bn2d = _bn(out_channel)
        self.relu = ops.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True

        self.down_sample_layer = None
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                 use_se=use_se, res_base=self.res_base),
                                                        _bn(out_channel, res_base)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1d(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2d(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out
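

# Illustrative counterpart for the basic block (assumed shapes): with equal
# in/out channels and stride 1 the identity is added directly, with no
# projection shortcut on the skip path.
def _basic_block_example():
    block = ResidualBlockBase(64, 64, stride=1)
    assert not block.down_sample  # identity shortcut only
    return block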


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images belong to.
        use_se (bool): Enable SE-ResNet50 net; also enables se blocks in layer 3 and layer 4. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        ...        [3, 4, 6, 3],
        ...        [64, 256, 512, 1024],
        ...        [256, 512, 1024, 2048],
        ...        [1, 2, 2, 2],
        ...        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False,
                 res_base=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("The lengths of layer_nums, in_channels and out_channels must all be 4!")
        self.use_se = use_se
        self.res_base = res_base
        self.se_block = False
        if self.use_se:
            self.se_block = True

        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       use_se=self.use_se)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       use_se=self.use_se,
                                       se_block=self.se_block)

        self.avgpool = ops.AvgPool2d(4)
        self.concat = P.Concat(1)
        self.flatten = nn.Flatten()
        self.end_point = _fc(16384, num_classes, use_se=self.use_se)

    def construct(self, x):
        # Split the channel axis into real and dual parts and pack them into
        # a single 2-channel hypercomplex tensor.
        x = to_2channel(x[:, :3], x[:, 3:])
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:
            # Pad the real and dual parts separately, then re-pack them.
            x_1, x_2 = get_x_and_y(x)
            x_1 = self.pad(x_1)
            x_2 = self.pad(x_2)
            x = to_2channel(x_1, x_2)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        out = self.avgpool(x)
        # Concatenate the real and dual parts along the channel axis before
        # the ordinary dense classifier head.
        out_x, out_y = get_x_and_y(out)
        out = self.concat([out_x, out_y])
        out = self.flatten(out)
        out = self.end_point(out)
        return out

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block (bool): Use se block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)
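

# Sketch of a direct ResNet instantiation (the factory functions below cover
# the common configurations). The lists mirror the resnet50 layout; the input
# is assumed to stack real and dual parts along the channel axis, since
# construct() splits it with x[:, :3] and x[:, 3:].
def _resnet_example(num_classes=10):
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  num_classes)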


def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(ResidualBlockBase,
                  [2, 2, 2, 2],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(ResidualBlockBase,
                  [3, 4, 6, 3],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_se=True)


def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
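

# Minimal smoke test, under stated assumptions: the input stacks 3 real and
# 3 dual channels along axis 1 (construct() splits them via x[:, :3] and
# x[:, 3:]), and the spatial size must be chosen so the flattened dual
# features match the hard-coded _fc(16384, num_classes) head. 160x160 is one
# such assumption for resnet50 (2 components * 2048 channels * 2 * 2 = 16384).
if __name__ == "__main__":
    net = resnet50(class_num=10)
    dummy = Tensor(np.random.randn(1, 6, 160, 160).astype(np.float32))
    logits = net(dummy)
    print(logits.shape)  # expected: (1, 10)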