# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


def weight_variable_0(shape):
    """All-zeros float32 Tensor of the given shape."""
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    """All-ones float32 Tensor of the given shape."""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, weight_init='Uniform',
                     has_bias=False, pad_mode="same")

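# Note on the conv helpers above (explanatory comment, not from the original
# file): with pad_mode="same", nn.Conv2d chooses its own padding so that the
# output spatial size depends only on the stride, out = ceil(in / stride);
# the padding parameter is therefore effectively unused here (every call site
# passes 0). Worked example for the stem convolution used later:
#
#   conv7x7(3, 64, stride=2) on a (N, 3, 224, 224) input
#   -> ceil(224 / 2) = 112, giving an output of shape (N, 64, 112, 112).
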

def bn_with_initialize(out_channels):
    """BatchNorm2d with zero-initialized beta/moving mean and one-initialized moving variance."""
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    """Same initialization as bn_with_initialize; kept as a separate hook for
    the last BatchNorm in each residual block."""
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with Xavier-uniform weights and uniform bias."""
    return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform')


class ResidualBlock(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with an identity shortcut."""
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out

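# Channel arithmetic for the bottleneck (illustrative): with out_channels=256
# and expansion=4, out_chls = 256 // 4 = 64, so the main path runs
# 256 -> 64 -> 64 -> 256 and the un-projected identity shortcut lines up with
# the output for the elementwise add.
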

class ResidualBlockWithDown(nn.Cell):
    """Bottleneck residual block with a projected (1x1 conv + BatchNorm) shortcut.

    Note: the down_sample flag is stored but never consulted in construct();
    the projection shortcut is always applied, so this class is only used as
    the first block of each stage, where the channel count (and possibly the
    spatial size) changes.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out

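# Worked shape example (illustrative): the first block of layer2 receives
# (N, 256, 56, 56) and is built with out_channels=512, stride=2. The main
# path runs 256 -> 128 -> 128 -> 512 with the stride-2 conv1x1 up front, and
# the stride-2 1x1 projection maps the shortcut to (N, 512, 28, 28), so the
# elementwise add is shape-compatible.
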

class MakeLayer0(nn.Cell):
    """Stage 1 of ResNet-50: three bottleneck blocks.

    Unlike the later stages, the projected block runs at stride 1 and the
    stride argument is applied in the second block instead.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(nn.Cell):
    """Stage 2 of ResNet-50: four bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(nn.Cell):
    """Stage 3 of ResNet-50: six bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(nn.Cell):
    """Stage 4 of ResNet-50: three bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x

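# Depth check (illustrative): the four stages hold 3 + 4 + 6 + 3 = 16
# bottleneck blocks of 3 convolutions each, i.e. 48 conv layers; adding the
# 7x7 stem convolution and the final fully connected layer gives the 50
# weight layers behind the name ResNet-50.
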

class ResNet(nn.Cell):
    """ResNet backbone: 7x7 stem, max pool, four stages, global average pool, classifier."""

    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(kernel_size=3, strides=2, pad_mode="SAME")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # MaxPoolWithArgmax returns (output, argmax); keep only the pooled values.
        x = self.maxpool(x)[0]

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling: mean over the spatial axes, then drop them.
        x = self.pool(x, (2, 3))
        x = self.squeeze(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    """Build a ResNet-50 from bottleneck residual blocks."""
    return ResNet(ResidualBlock, num_classes, batch_size)
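
# A minimal smoke test (a sketch, not part of the original file; assumes a
# working MindSpore install with enough memory for batch size 2): build the
# network and check the logits shape on random input.
if __name__ == "__main__":
    net = resnet50(batch_size=2, num_classes=10)
    images = Tensor(np.random.randn(2, 3, 224, 224).astype(np.float32))
    logits = net(images)
    print(logits.shape)  # expected: (2, 10)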