# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
from mindspore.compression.quant import create_quant_config

_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True
_quant_config = create_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with BatchNorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Convolution kernel size. Default: 3.
        stride (int): Stride size for the convolutional layer. Default: 1.
        groups (int): Number of channel groups; 1 for standard convolution,
            in_planes for depthwise convolution. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
                                 group=groups, fake=_fake, quant_config=_quant_config)
        layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output
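
# A minimal shape sanity check for ConvBNReLU (a sketch, not part of the
# model; assumes a configured MindSpore context and the quant cells imported
# above are supported on your backend):
#
#     block = ConvBNReLU(3, 16, kernel_size=3, stride=2)
#     x = Tensor(np.ones((1, 3, 224, 224), np.float32))
#     print(block(x).shape)  # (1, 16, 112, 112): 'pad' mode with padding=1 halves H and W
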
class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
        self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
                                                          quant_config=_quant_config,
                                                          kernel_size=1, stride=1, pad_mode='same', padding=0),
                                        FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
                                        ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
                                                                           quant_config=_quant_config,
                                                                           kernel_size=1, stride=1,
                                                                           pad_mode='same', padding=0)

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
                                                                          quant_config=_quant_config,
                                                                          kernel_size=1, stride=stride,
                                                                          pad_mode='same', padding=0),
                                                        FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
                                                                                    symmetric=False)
                                                        ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
                                                                                           fake=_fake,
                                                                                           quant_config=_quant_config,
                                                                                           kernel_size=1,
                                                                                           stride=stride,
                                                                                           pad_mode='same',
                                                                                           padding=0)
        self.add = nn.TensorAddQuant()
        self.relu = P.ReLU()

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
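
# Bottleneck arithmetic for reference (a sketch; the numbers follow from
# expansion = 4). ResidualBlock(256, 512, stride=2) builds:
#     1x1 conv: 256 -> 128   (512 // 4 internal channels)
#     3x3 conv: 128 -> 128, stride 2
#     1x1 conv: 128 -> 512
# plus a 1x1 stride-2 downsample branch 256 -> 512 so the quantized
# residual add (TensorAddQuant) sees matching shapes.
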
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in the different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes the training images belong to.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        ...        [3, 4, 6, 3],
        ...        [64, 256, 512, 1024],
        ...        [256, 512, 1024, 2048],
        ...        [1, 2, 2, 2],
        ...        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == len(strides) == 4:
            raise ValueError("the lengths of layer_nums, in_channels, out_channels and strides must all be 4!")

        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
        self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)

        # init weights
        self._initialize_weights()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)
        out = self.output_fake(out)
        return out

    def _initialize_weights(self):
        """Initialize quantized conv and dense weights with a normal distribution."""
        self.init_parameters_data()
        # Seed once, before the loop: reseeding on every iteration would make
        # all same-shape layers draw identical random values.
        np.random.seed(1)
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2dBnFoldQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
            elif isinstance(m, nn.DenseQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
            elif isinstance(m, nn.Conv2dBnWithoutFoldQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
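
# Spatial-size walkthrough for a 224x224 input (a sketch; assumes the standard
# ResNet-50 configuration built by resnet50_quant below):
#     conv1 (7x7, s2)   : 224 -> 112
#     maxpool (3x3, s2) : 112 -> 56
#     layer1 (s1)       : 56  -> 56
#     layer2 (s2)       : 56  -> 28
#     layer3 (s2)       : 28  -> 14
#     layer4 (s2)       : 14  -> 7
#     ReduceMean over (2, 3), then DenseQuant -> (batch, num_classes) logits
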
def resnet50_quant(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50_quant(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def resnet101_quant(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101_quant(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
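

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). Assumes
    # a MindSpore build whose backend supports the quantization-aware cells
    # used above; device_target may need to be "Ascend" or "GPU" on your setup.
    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE)
    net = resnet50_quant(class_num=10)
    dummy = Tensor(np.ones((1, 3, 224, 224), np.float32))
    logits = net(dummy)
    print(logits.shape)  # expected: (1, 10)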