# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet-50 building blocks and network definition (MindSpore)."""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


def weight_variable_0(shape):
    """Return an all-zeros float32 Tensor of the given shape."""
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    """Return an all-ones float32 Tensor of the given shape."""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution, Xavier-uniform weights, no bias, 'same' padding."""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution, Xavier-uniform weights, no bias, 'same' padding."""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution, uniform weight init, no bias, 'same' padding."""
    # Docstring previously said "1x1 convolution" — copy-paste error; this is the
    # 7x7 stem convolution.
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, weight_init='Uniform',
                     has_bias=False, pad_mode="same")


def bn_with_initialize(out_channels):
    """BatchNorm2d with zero-initialized beta/moving-mean and ones moving-variance.

    Gamma is drawn from a uniform distribution; the running statistics start at
    the identity transform (mean 0, variance 1).
    """
    # One statistic per channel. The original wrote `(out_channels)` — bare
    # parentheses, i.e. a plain int, which np.zeros also accepts; the trailing
    # comma makes the intended 1-D shape explicit.
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    return nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                          beta_init=beta, moving_mean_init=mean, moving_var_init=var)


def bn_with_initialize_last(out_channels):
    """BatchNorm for the final BN inside a residual block.

    Currently identical to `bn_with_initialize`; kept as a distinct entry
    point so the last-BN initialization (e.g. zero-gamma) can diverge later
    without touching call sites. Previously this was a verbatim duplicate of
    `bn_with_initialize`; it now delegates.
    """
    return bn_with_initialize(out_channels)


def fc_with_initialize(input_channels, out_channels):
    """Dense layer with Xavier-uniform weights and uniform bias init."""
    return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform')


class ResidualBlock(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with identity shortcut.

    `out_channels` is the block's output width; the two inner convolutions use
    `out_channels // expansion` channels.
    """

    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.Add()

    def construct(self, x):
        """conv-bn-relu x2, conv-bn, add identity, relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResidualBlockWithDown(nn.Cell):
    """Bottleneck residual block whose shortcut is a 1x1 conv + BN projection.

    Used as the first block of each stage, where the channel count (and
    possibly the spatial resolution) changes.
    """

    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        # NOTE(review): this flag is stored but never consulted in construct();
        # the projection shortcut is always applied. Kept for interface
        # compatibility with existing callers.
        self.downSample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        """Main path conv stack plus projected shortcut, then relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class MakeLayer0(nn.Cell):
    """Stage 1: one projection block followed by two `block`s (3 total)."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        # First block keeps stride=1 (input already downsampled by the stem
        # maxpool); `stride` is applied by the second block in this stage.
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(nn.Cell):
    """Stage 2: one strided projection block followed by three `block`s (4 total)."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(nn.Cell):
    """Stage 3: one strided projection block followed by five `block`s (6 total)."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(nn.Cell):
    """Stage 4: one strided projection block followed by two `block`s (3 total)."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class ResNet(nn.Cell):
    """ResNet with a 3-4-6-3 stage layout (ResNet-50 when `block` is
    `ResidualBlock`): 7x7 stem, maxpool, four residual stages, global
    average pool, and a final Dense classifier.
    """

    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(kernel_size=3, strides=2, pad_mode="SAME")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        # 512 * expansion == 2048, the width of layer4's output.
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)

    def construct(self, x):
        """Stem -> 4 stages -> global mean over H,W -> squeeze -> classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # MaxPoolWithArgmax returns (output, argmax); only the pooled output is used.
        x = self.maxpool(x)[0]

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool over the spatial axes (H, W).
        x = self.pool(x, (2, 3))
        x = self.squeeze(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    """Construct a ResNet-50 built from bottleneck `ResidualBlock`s."""
    return ResNet(ResidualBlock, num_classes, batch_size)