• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""ResNet."""
16import numpy as np
17import mindspore.nn as nn
18import mindspore.common.initializer as weight_init
19from mindspore.ops import operations as P
20from mindspore import Tensor
21from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
22from mindspore.compression.quant import create_quant_config
23
# Decay rate handed to the FakeQuantWithMinMaxObserver layers in this file as `ema_decay`.
_ema_decay = 0.999
# First entry of the `symmetric` tuple given to create_quant_config below.
_symmetric = True
# Passed as `fake=` to Conv2dBnFoldQuant; when true, the blocks in this file also append
# ActQuant / FakeQuantWithMinMaxObserver layers after their convolutions.
_fake = True
# First entry of the `per_channel` tuple given to create_quant_config below.
_per_channel = True
# Quantization config shared by the Conv2dBnFoldQuant / DenseQuant layers in this file.
# NOTE(review): each tuple is presumably (weights, activations) - confirm against the
# create_quant_config documentation of the installed MindSpore version.
_quant_config = create_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
29
30
def _weight_variable(shape, factor=0.01):
    """
    Create a randomly initialised float32 weight tensor.

    Args:
        shape (tuple): Shape of the weight; unpacked into `np.random.randn`.
        factor (float): Multiplier applied to the sampled values. Default: 0.01.

    Returns:
        Tensor, `np.random.randn` samples cast to float32 and scaled by `factor`.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
34
35
def _conv3x3(in_channel, out_channel, stride=1):
    """
    Build a 3x3 `nn.Conv2d` with 'same' padding and randomly initialised weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the configured `nn.Conv2d` layer.
    """
    kernel = _weight_variable((out_channel, in_channel, 3, 3))
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=3,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=kernel)
41
42
def _conv1x1(in_channel, out_channel, stride=1):
    """
    Build a 1x1 `nn.Conv2d` with 'same' padding and randomly initialised weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the configured `nn.Conv2d` layer.
    """
    kernel = _weight_variable((out_channel, in_channel, 1, 1))
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=1,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=kernel)
48
49
def _conv7x7(in_channel, out_channel, stride=1):
    """
    Build a 7x7 `nn.Conv2d` with 'same' padding and randomly initialised weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the configured `nn.Conv2d` layer.
    """
    kernel = _weight_variable((out_channel, in_channel, 7, 7))
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=7,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=kernel)
55
56
def _bn(channel):
    """
    Build a `nn.BatchNorm2d` layer whose gamma is initialised to 1.

    Args:
        channel (int): Number of channels, passed as the first argument of `nn.BatchNorm2d`.

    Returns:
        Cell, the configured `nn.BatchNorm2d` layer.
    """
    bn_settings = dict(eps=1e-4,
                       momentum=0.9,
                       gamma_init=1,
                       beta_init=0,
                       moving_mean_init=0,
                       moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_settings)
60
61
def _bn_last(channel):
    """
    Build a `nn.BatchNorm2d` layer whose gamma is initialised to 0.

    Args:
        channel (int): Number of channels, passed as the first argument of `nn.BatchNorm2d`.

    Returns:
        Cell, the configured `nn.BatchNorm2d` layer.
    """
    bn_settings = dict(eps=1e-4,
                       momentum=0.9,
                       gamma_init=0,
                       beta_init=0,
                       moving_mean_init=0,
                       moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_settings)
65
66
def _fc(in_channel, out_channel):
    """
    Build a biased `nn.Dense` layer with randomly initialised weights and zero bias.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.

    Returns:
        Cell, the configured `nn.Dense` layer.
    """
    dense_weight = _weight_variable((out_channel, in_channel))
    return nn.Dense(in_channel,
                    out_channel,
                    has_bias=True,
                    weight_init=dense_weight,
                    bias_init=0)
71
72
class ConvBNReLU(nn.Cell):
    """
    `Conv2dBnFoldQuant` convolution followed by a ReLU activation.

    The convolution uses explicit padding of `(kernel_size - 1) // 2`. When the module-level
    `_fake` flag is true the ReLU is wrapped in `nn.ActQuant`; otherwise a plain `nn.ReLU`
    is used.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Kernel size of the convolution. Default: 3.
        stride (int): Stride of the convolution. Default: 1.
        groups (int): Channel group, forwarded as `group` to the convolution. Use 1 for a regular
            convolution or the input channel count for a depthwise one. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        half_kernel = (kernel_size - 1) // 2
        quant_conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride,
                                       pad_mode='pad', padding=half_kernel, group=groups,
                                       fake=_fake, quant_config=_quant_config)
        activation = nn.ReLU()
        if _fake:
            # Quantization-aware mode: observe/quantize the activation output as well.
            activation = nn.ActQuant(activation)
        self.features = nn.SequentialCell([quant_conv, activation])

    def construct(self, x):
        return self.features(x)
102
103
class ResidualBlock(nn.Cell):
    """
    Quantization-aware ResNet V1 bottleneck residual block.

    The main branch is a 1x1 -> 3x3 -> 1x1 stack of convolutions; the first two are `ConvBNReLU`
    cells and the last one is a `Conv2dBnFoldQuant` without activation. The shortcut branch is
    the unmodified input, or a 1x1 `Conv2dBnFoldQuant` projection when `stride != 1` or
    `in_channel != out_channel`. Both branches are summed with `nn.TensorAddQuant` and passed
    through ReLU.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel. The two inner convolutions use
            `out_channel // expansion` channels.
        stride (int): Stride of the 3x3 convolution and of the shortcut projection. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
        self.conv3 = self._quant_conv1x1(channel, out_channel, stride=1)

        # A projection shortcut is needed whenever the main branch changes the tensor shape.
        self.down_sample = stride != 1 or in_channel != out_channel
        self.down_sample_layer = None
        if self.down_sample:
            self.down_sample_layer = self._quant_conv1x1(in_channel, out_channel, stride=stride)

        self.add = nn.TensorAddQuant()
        self.relu = P.ReLU()

    @staticmethod
    def _quant_conv1x1(in_channel, out_channel, stride):
        """
        Build the activation-free 1x1 quantized convolution shared by both branches.

        `fake=_fake` is always passed explicitly so every convolution in the block follows the
        module-level switch instead of the library default.

        Args:
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride of the convolution.

        Returns:
            Cell, the bare `Conv2dBnFoldQuant` when `_fake` is false, otherwise a `SequentialCell`
            of the convolution followed by a `FakeQuantWithMinMaxObserver`.
        """
        conv = Conv2dBnFoldQuant(in_channel, out_channel, fake=_fake,
                                 quant_config=_quant_config,
                                 kernel_size=1, stride=stride, pad_mode='same', padding=0)
        if not _fake:
            return conv
        observer = FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
        return nn.SequentialCell([conv, observer])

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
175
176
class ResNet(nn.Cell):
    """
    Quantization-aware ResNet architecture.

    The network is a 7x7 stride-2 `ConvBNReLU` stem and a 3x3 stride-2 max pool, followed by
    four stages of residual blocks, a mean over axes (2, 3), a `nn.DenseQuant` classifier and a
    final `nn.FakeQuantWithMinMaxObserver` on the output.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers. Must have 4 entries.
        in_channels (list): Input channel in each layer. Must have 4 entries; the stem outputs
            64 channels, so `in_channels[0]` should be 64.
        out_channels (list): Output channel in each layer. Must have 4 entries.
        strides (list): Stride size in each layer. Must have 4 entries.
        num_classes (int): The number of classes that the training images are belonging to.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If `layer_nums`, `in_channels`, `out_channels` or `strides` does not have
            exactly 4 entries.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        # All four per-stage lists are indexed 0..3 below, so validate every one of them up
        # front instead of failing later with an IndexError on a short `strides`.
        if not len(layer_nums) == len(in_channels) == len(out_channels) == len(strides) == 4:
            raise ValueError("the length of layer_nums, in_channels, out_channels, strides list must be 4!")

        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
        self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)

        # init weights
        self._initialize_weights()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Only the first block receives `in_channel` and `stride`; the remaining
        `layer_num - 1` blocks map `out_channel` to `out_channel` with stride 1.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size passed to the first block of the stage.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = [block(in_channel, out_channel, stride=stride)]
        layers.extend(block(out_channel, out_channel, stride=1) for _ in range(1, layer_num))
        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Average over axes 2 and 3, then flatten for the classifier.
        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)
        out = self.output_fake(out)
        return out

    def _initialize_weights(self):
        """
        Re-initialise the weights of all quantized conv/dense cells with `Normal()`.

        Every `nn.Conv2dBnFoldQuant`, `nn.DenseQuant` and `nn.Conv2dBnWithoutFoldQuant` cell
        found by `cells_and_names()` gets a fresh weight of the same shape and dtype.
        """
        self.init_parameters_data()
        quant_cell_types = (nn.Conv2dBnFoldQuant, nn.DenseQuant, nn.Conv2dBnWithoutFoldQuant)
        for _, m in self.cells_and_names():
            # The NumPy seed is reset for every visited cell, exactly as before; hoisting it
            # out of the loop could change the generated weights, so it stays here.
            np.random.seed(1)

            if isinstance(m, quant_cell_types):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
305
306
def resnet50_quant(class_num=10):
    """
    Get the quantization-aware ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50_quant(10)
    """
    stage_depths = [3, 4, 6, 3]
    stage_inputs = [64, 256, 512, 1024]
    stage_outputs = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, stage_depths, stage_inputs, stage_outputs, stage_strides, class_num)
326
327
def resnet101_quant(class_num=1001):
    """
    Get the quantization-aware ResNet101 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101_quant(1001)
    """
    stage_depths = [3, 4, 23, 3]
    stage_inputs = [64, 256, 512, 1024]
    stage_outputs = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, stage_depths, stage_inputs, stage_outputs, stage_strides, class_num)
347