• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""ResNet."""
16import numpy as np
17import mindspore.nn as nn
18from mindspore.ops import operations as P
19from mindspore.common.tensor import Tensor
20
21
def _weight_variable(shape, factor=0.01):
    """
    Create a randomly initialized float32 weight tensor.

    Args:
        shape (tuple): Shape of the tensor to create.
        factor (float): Scale applied to the standard-normal samples. Default: 0.01.

    Returns:
        Tensor, values drawn from a standard normal distribution and multiplied by `factor`.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
25
26
def _conv3x3(in_channel, out_channel, stride=1):
    """
    Build a 3x3 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    kernel = 3
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel,
        stride=stride,
        padding=0,
        pad_mode='same',
        weight_init=init_weight,
    )
32
33
def _conv1x1(in_channel, out_channel, stride=1):
    """
    Build a 1x1 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    kernel = 1
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel,
        stride=stride,
        padding=0,
        pad_mode='same',
        weight_init=init_weight,
    )
39
40
def _conv7x7(in_channel, out_channel, stride=1):
    """
    Build a 7x7 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    kernel = 7
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel,
        stride=stride,
        padding=0,
        pad_mode='same',
        weight_init=init_weight,
    )
46
47
def _bn(channel):
    """
    Build a 2D batch normalization cell with gamma initialized to 1.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        nn.BatchNorm2d, configured with eps=1e-4 and momentum=0.9.
    """
    return nn.BatchNorm2d(
        channel,
        eps=1e-4,
        momentum=0.9,
        gamma_init=1,
        beta_init=0,
        moving_mean_init=0,
        moving_var_init=1,
    )
51
52
def _bn_last(channel):
    """
    Build a 2D batch normalization cell with gamma initialized to 0.

    Identical to `_bn` except for the zero `gamma_init`; in this file it is used
    for the final normalization of a residual block's main branch.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        nn.BatchNorm2d, configured with eps=1e-4 and momentum=0.9.
    """
    return nn.BatchNorm2d(
        channel,
        eps=1e-4,
        momentum=0.9,
        gamma_init=0,
        beta_init=0,
        moving_mean_init=0,
        moving_var_init=1,
    )
56
57
def _fc(in_channel, out_channel):
    """
    Build a fully connected layer with random weights and a zero-initialized bias.

    Args:
        in_channel (int): Number of input features.
        out_channel (int): Number of output features.

    Returns:
        nn.Dense, the configured dense cell.
    """
    init_weight = _weight_variable((out_channel, in_channel))
    return nn.Dense(
        in_channel,
        out_channel,
        has_bias=True,
        weight_init=init_weight,
        bias_init=0,
    )
62
63
class ResidualBlock(nn.Cell):
    """
    Bottleneck residual block for ResNet V1.

    The main branch is a 1x1 convolution, a 3x3 convolution and another 1x1
    convolution, each followed by batch normalization. Its result is added to
    the shortcut branch and passed through ReLU.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Stride of the 3x3 convolution and, when a projection is
            needed, of the shortcut convolution. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    # Ratio between the block's output channels and its inner (bottleneck) channels.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1):
        super(ResidualBlock, self).__init__()

        mid_channel = out_channel // self.expansion

        # Main branch: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand.
        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1)
        self.bn1 = _bn(mid_channel)
        self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride)
        self.bn2 = _bn(mid_channel)
        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        # The shortcut needs a projection whenever the main branch changes the
        # spatial size (stride != 1) or the channel count.
        self.down_sample = stride != 1 or in_channel != out_channel
        self.down_sample_layer = None
        if self.down_sample:
            shortcut_conv = _conv1x1(in_channel, out_channel, stride)
            shortcut_bn = _bn(out_channel)
            self.down_sample_layer = nn.SequentialCell([shortcut_conv, shortcut_bn])

        self.add = P.Add()

    def construct(self, x):
        shortcut = x

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        if self.down_sample:
            shortcut = self.down_sample_layer(shortcut)

        return self.relu(self.add(y, shortcut))
131
132
class ResNet(nn.Cell):
    """
    ResNet architecture.

    The network is a fixed stem (7x7 convolution with 64 output channels, batch
    normalization, ReLU, 3x3 max pooling) followed by four stages of residual
    blocks, a mean over axes 2 and 3, and a dense classification layer.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers. Must have length 4.
        in_channels (list): Input channel in each layer. Must have length 4.
            The stem always outputs 64 channels, so in_channels[0] is expected to be 64.
        out_channels (list): Output channel in each layer. Must have length 4.
        strides (list):  Stride size in each layer. Must have length 4.
        num_classes (int): The number of classes that the training images are belonging to.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If `layer_nums`, `in_channels`, `out_channels` or `strides`
            does not have exactly 4 elements.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        # strides is indexed 0..3 below just like the other lists; validate it up
        # front so a malformed list fails with a clear error instead of an
        # IndexError after part of the network has already been built.
        if len(strides) != 4:
            raise ValueError("the length of strides list must be 4!")

        # Stem: 7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max pool (stride 2).
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # Four residual stages, one entry of each configuration list per stage.
        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        # Head: mean over axes (2, 3) -> flatten -> dense classifier.
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        The first block maps `in_channel` to `out_channel` using `stride`; the
        remaining `layer_num - 1` blocks keep `out_channel` channels and use stride 1.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Number of blocks in this stage.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride passed to the first block of the stage.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        # Residual stages.
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Classification head.
        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out
242
243
def resnet50(class_num=10):
    """
    Build a ResNet50 network.

    Args:
        class_num (int): Number of output classes. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    blocks_per_stage = [3, 4, 6, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock,
                  blocks_per_stage,
                  stage_in_channels,
                  stage_out_channels,
                  stage_strides,
                  class_num)
263
def resnet101(class_num=1001):
    """
    Build a ResNet101 network.

    Args:
        class_num (int): Number of output classes. Default: 1001.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    blocks_per_stage = [3, 4, 23, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock,
                  blocks_per_stage,
                  stage_in_channels,
                  stage_out_channels,
                  stage_strides,
                  class_num)
283