# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor

import mindspore.hypercomplex.dual as ops
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel


def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """Variance-scaling initializer: truncated normal with std sqrt(1 / fan_in)."""
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    # 0.87962566103423978 is the std of a standard normal truncated to (-2, 2);
    # dividing by it makes the truncated samples have the requested std.
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=mstype.float32)
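

# A quick sanity check of the initializer above (hypothetical values, for
# illustration only): the returned tensor has shape (out_channel, in_channel,
# k, k) and an empirical std close to sqrt(1 / fan_in).
#
#     w = conv_variance_scaling_initializer(64, 128, kernel_size=3)
#     assert w.shape == (128, 64, 3, 3)
#     # fan_in = 64 * 3 * 3 = 576, so std(w) ~= 576 ** -0.5 ~= 0.0417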


def calculate_gain(nonlinearity, param=None):
    """Return the recommended weight-initialization gain for a nonlinearity."""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, (int, float)):
            # True/False are instances of int, hence the explicit bool check
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res
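

# For reference, the gains produced above (standard Kaiming conventions):
#     calculate_gain('linear')      -> 1
#     calculate_gain('tanh')        -> 5 / 3 ~= 1.667
#     calculate_gain('relu')        -> sqrt(2) ~= 1.414
#     calculate_gain('leaky_relu')  -> sqrt(2 / (1 + 0.01 ** 2)) ~= 1.414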


def _calculate_fan_in_and_fan_out(tensor):
    """Compute fan-in and fan-out from a weight shape tuple."""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
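

# Worked example: for a 7x7 convolution mapping 3 -> 64 channels the weight
# shape is (64, 3, 7, 7), so
#     fan_in  = 3 * 7 * 7  = 147
#     fan_out = 64 * 7 * 7 = 3136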


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
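

# Why bound = sqrt(3) * std: a uniform distribution on (-b, b) has variance
# b ** 2 / 3, so choosing b = sqrt(3) * std gives the samples the same standard
# deviation kaiming_normal would use. A hypothetical check:
#     w = kaiming_uniform((512, 256), a=math.sqrt(5))
#     # fan_in = 256, gain = sqrt(2 / (1 + 5)) ~= 0.577, std ~= 0.0361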


def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        # The leading 2 in the shape holds the two components of the dual-number weights.
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                          padding=1, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                      padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                          padding=0, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                      padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
    if res_base:
        return ops.Conv2d(in_channel, out_channel,
                          kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
    return ops.Conv2d(in_channel, out_channel,
                      kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel, res_base=False):
    if res_base:
        return ops.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
                               gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return ops.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                           gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         ops.ReLU(), ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn(out_channel)
        if self.se_block:
            self.se_global_pool = P.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = P.Mul()
        self.relu = ops.ReLU()

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
                                                                         stride, use_se=self.use_se), _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell([ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                                                _conv1x1(in_channel, out_channel, 1,
                                                                         use_se=self.use_se), _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                     use_se=self.use_se), _bn(out_channel)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:
            out_se = out
            out = self.se_global_pool(out, (2, 3))
            out = self.se_dense_0(out)
            out = self.relu(out)
            out = self.se_dense_1(out)
            out = self.se_sigmoid(out)
            out = F.reshape(out, F.shape(out) + (1, 1))
            out = self.se_mul(out, out_se)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out
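

# Illustrative channel bookkeeping for the bottleneck above:
# ResidualBlock(256, 512, stride=2) uses an internal width of
# 512 // expansion = 128, i.e. 1x1 (256 -> 128), 3x3 stride-2 (128 -> 128),
# then 1x1 (128 -> 512), with a strided 1x1 projection on the identity path.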


class ResidualBlockBase(nn.Cell):
    """
    ResNet V1 basic residual block definition (two 3x3 convolutions).

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlockBase(3, 256, stride=2)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False,
                 se_block=False,
                 res_base=True):
        super(ResidualBlockBase, self).__init__()
        self.res_base = res_base
        self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
        self.bn1d = _bn(out_channel)
        self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
        self.bn2d = _bn(out_channel)
        self.relu = ops.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True

        self.down_sample_layer = None
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                 use_se=use_se, res_base=self.res_base),
                                                        _bn(out_channel, res_base)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1d(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2d(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in each layer.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes the training images belong to.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        ...        [3, 4, 6, 3],
        ...        [64, 256, 512, 1024],
        ...        [256, 512, 1024, 2048],
        ...        [1, 2, 2, 2],
        ...        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False,
                 res_base=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == len(strides) == 4:
            raise ValueError("the length of layer_nums, in_channels, out_channels and strides must all be 4!")
        self.use_se = use_se
        self.res_base = res_base
        self.se_block = False
        if self.use_se:
            self.se_block = True

        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       use_se=self.use_se)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       use_se=self.use_se,
                                       se_block=self.se_block)

        self.avgpool = ops.AvgPool2d(4)
        self.concat = P.Concat(1)
        self.flatten = nn.Flatten()
        self.end_point = _fc(16384, num_classes, use_se=self.use_se)

    def construct(self, x):
        # Split the input channels into two groups (the first three and the
        # rest) that form the two components of the dual-number representation.
        x = to_2channel(x[:, :3], x[:, 3:])
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:
            x_1, x_2 = get_x_and_y(x)
            x_1 = self.pad(x_1)
            x_2 = self.pad(x_2)
            x = to_2channel(x_1, x_2)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        out = self.avgpool(x)
        out_x, out_y = get_x_and_y(out)
        out = self.concat([out_x, out_y])
        out = self.flatten(out)
        out = self.end_point(out)
        return out

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block (bool): Use se block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)


def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(ResidualBlockBase,
                  [2, 2, 2, 2],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(ResidualBlockBase,
                  [3, 4, 6, 3],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_se=True)


def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
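

if __name__ == "__main__":
    # Hypothetical smoke test, not part of the original network definition.
    # Assumptions: the backbone consumes two stacked 3-channel images as one
    # 6-channel input, and ops.AvgPool2d(4) reduces the spatial size by 4x, so
    # a 256x256 input reaches the head as 2 components x 2048 channels x 2 x 2
    # = 16384 features, matching the hard-coded _fc(16384, ...) above.
    net = resnet50(class_num=10)
    dummy = Tensor(np.random.randn(1, 6, 256, 256), dtype=mstype.float32)
    logits = net(dummy)
    print(logits.shape)  # expected: (1, 10)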