# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from tests.dataset_mock import MindData


def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    # 0.87962566103423978 is the stddev of a standard normal truncated to
    # [-2, 2]; dividing by it rescales the truncated samples to stddev sqrt(scale).
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=mstype.float32)


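# Illustrative use of the initializer above (a sketch, not executed on import):
# for a 3x3 convolution from 64 to 128 channels, the returned Tensor has shape
# (128, 64, 3, 3) and stddev roughly sqrt(1 / (64 * 3 * 3)) / 0.8796 ~= 0.047.
#   w = conv_variance_scaling_initializer(64, 128, kernel_size=3)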
def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            neg_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            neg_slope = param
        else:
            raise ValueError("neg_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + neg_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


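# Reference values for the helper above (illustrative, not executed):
#   calculate_gain('linear')           == 1
#   calculate_gain('tanh')             == 5.0 / 3
#   calculate_gain('relu')             == math.sqrt(2.0)  ~= 1.414
#   calculate_gain('leaky_relu', 0.01) == sqrt(2 / (1 + 0.01**2)) ~= 1.414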
def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    # `tensor` is a shape tuple such as (out_channel, in_channel, k, k),
    # not an actual array.
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Unsupported mode {}, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


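# Worked example for the initializers above (illustrative, not executed):
# kaiming_uniform((256, 128), a=math.sqrt(5)) uses fan_in = 128 and
# gain = sqrt(2 / (1 + 5)) ~= 0.577, so std ~= 0.051 and the samples are
# drawn uniformly from roughly [-0.088, 0.088].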
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                         padding=1, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                         padding=0, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel,
                         kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


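# Note on the two padding paths above (illustrative): with res_base=True the
# convolutions use explicit 'pad' padding (1 for 3x3, 3 for 7x7), while the
# default path lets 'same' padding be computed by the framework. For even
# input sizes both preserve spatial shape at stride 1 and halve it at
# stride 2, e.g. _conv3x3(64, 64, stride=2) maps 56x56 features to 28x28.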
def _bn(channel, res_base=False):
    if res_base:
        return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
                              gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    # gamma_init=0 zero-initializes this BatchNorm, so a residual branch
    # ending in it starts near zero and the block initially acts like identity.
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn(out_channel)
        if self.se_block:
            self.se_global_pool = ops.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = ops.Mul()
        self.relu = nn.ReLU()

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
                                                                         stride, use_se=self.use_se), _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                                                _conv1x1(in_channel, out_channel, 1,
                                                                         use_se=self.use_se), _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                     use_se=self.use_se), _bn(out_channel)])

    def construct(self, x):
        identity = x  # <IGNORE>

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:  # <IGNORE>
            out = self.e2(out)  # <IGNORE>
        else:  # <IGNORE>
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:  # <IGNORE>
            out_se = out  # <IGNORE>
            out = self.se_global_pool(out, (2, 3))  # <IGNORE>
            out = self.se_dense_0(out)  # <IGNORE>
            out = self.relu(out)  # <IGNORE>
            out = self.se_dense_1(out)  # <IGNORE>
            out = self.se_sigmoid(out)  # <IGNORE>
            out = ops.reshape(out, ops.shape(out) + (1, 1))  # <IGNORE>
            out = self.se_mul(out, out_se)  # <IGNORE>

        if self.down_sample:  # <IGNORE>
            identity = self.down_sample_layer(identity)  # <IGNORE>

        out = out + identity
        out = self.relu(out)

        return out  # <IGNORE>


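# Shape sketch for the bottleneck block above (illustrative, not executed):
# ResidualBlock(64, 256, stride=1) squeezes to 256 // 4 = 64 channels inside
# and maps an (N, 64, 56, 56) input to (N, 256, 56, 56), with a 1x1 projection
# on the identity branch because in_channel != out_channel.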
class ResidualBlockBase(nn.Cell):
    """
    ResNet V1 basic residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net. Default: False.
        res_base (bool): Use the parameter settings of ResNet18. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlockBase(3, 256, stride=2)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False,
                 se_block=False,
                 res_base=True):
        super(ResidualBlockBase, self).__init__()
        self.res_base = res_base
        self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
        self.bn1d = _bn(out_channel)
        self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
        self.bn2d = _bn(out_channel)
        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True

        self.down_sample_layer = None
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                 use_se=use_se, res_base=self.res_base),
                                                        _bn(out_channel, res_base)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1d(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2d(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in each layer.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes the training images belong to.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
        res_base (bool): Use the parameter settings of ResNet18. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        ...        [3, 4, 6, 3],
        ...        [64, 256, 512, 1024],
        ...        [256, 512, 1024, 2048],
        ...        [1, 2, 2, 2],
        ...        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False,
                 res_base=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_nums, in_channels and out_channels must be 4!")
        self.use_se = use_se
        self.res_base = res_base
        self.se_block = False
        if self.use_se:
            self.se_block = True

        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       use_se=self.use_se)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       use_se=self.use_se,
                                       se_block=self.se_block)

        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

    def construct(self, x):
        if self.use_se:  # <IGNORE>
            x = self.conv1_0(x)  # <IGNORE>
            x = self.bn1_0(x)  # <IGNORE>
            x = self.relu(x)  # <IGNORE>
            x = self.conv1_1(x)  # <IGNORE>
            x = self.bn1_1(x)  # <IGNORE>
            x = self.relu(x)  # <IGNORE>
            x = self.conv1_2(x)  # <IGNORE>
        else:  # <IGNORE>
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:  # <IGNORE>
            x = self.pad(x)  # <IGNORE>
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out  # <IGNORE>

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block (bool): Use se block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)


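# Stage layout produced by _make_layer above (illustrative): only the first
# block carries the stage stride and channel change; when se_block is set, the
# SE variant is applied to the last block of the stage. For example,
# _make_layer(ResidualBlock, 3, 64, 256, stride=1) stacks
# block(64->256, s=1), block(256->256, s=1), block(256->256, s=1).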
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


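# Illustrative forward pass (a sketch, not executed on import), assuming a
# standard 224x224 RGB batch; the tensor below is for shape-checking only:
#   net = resnet50(class_num=10)
#   x = Tensor(np.ones((1, 3, 224, 224)), mstype.float32)
#   out = net(x)  # shape (1, 10)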
class DatasetResNet(MindData):
    """Mock dataset that repeatedly yields the same (predict, label) pair."""

    def __init__(self, predict, label, length=3):
        super(DatasetResNet, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0

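# Illustrative use of the mock dataset above (a sketch, not executed on import):
#   predict = Tensor(np.ones((32, 3, 224, 224)), mstype.float32)
#   label = Tensor(np.ones((32,)), mstype.int32)
#   for data, lbl in DatasetResNet(predict, label, length=2):
#       ...  # yields the same pair twice, then raises StopIteration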