• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore.nn as nn
18from mindspore import Tensor
19from mindspore.common import dtype as mstype
20from mindspore.common.initializer import initializer
21from mindspore.ops import operations as P
22
23
def weight_variable(shape):
    """Return a float32 ``'XavierUniform'`` initializer for ``shape``.

    Args:
        shape: target parameter shape, forwarded unchanged to ``initializer``.
    """
    scheme = 'XavierUniform'
    return initializer(scheme, shape=shape, dtype=mstype.float32)
26
27
def weight_variable_uniform(shape):
    """Return a float32 ``'Uniform'`` initializer for ``shape``.

    Args:
        shape: target parameter shape, forwarded unchanged to ``initializer``.
    """
    return initializer(
        'Uniform',
        shape=shape,
        dtype=mstype.float32,
    )
30
31
def weight_variable_0(shape):
    """Return a float32 ``Tensor`` filled with zeros.

    Args:
        shape: shape passed to ``np.zeros`` (an int or a tuple of ints).
    """
    return Tensor(np.zeros(shape, dtype=np.float32))
35
36
def weight_variable_1(shape):
    """Return a float32 ``Tensor`` filled with ones.

    Args:
        shape: shape passed to ``np.ones`` (an int or a tuple of ints).
    """
    return Tensor(np.ones(shape, dtype=np.float32))
40
41
def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """Build a bias-free 3x3 ``nn.Conv2d`` with Xavier-uniform weights.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: convolution stride.
        padding: forwarded to ``nn.Conv2d``; presumably must stay 0 because
            ``pad_mode="same"`` is used — confirm against MindSpore docs.

    Returns:
        nn.Conv2d configured with ``pad_mode="same"`` and ``has_bias=False``.
    """
    kernel = 3
    init_weight = weight_variable((out_channels, in_channels, kernel, kernel))
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
        weight_init=init_weight,
        has_bias=False,
        pad_mode="same",
    )
48
49
def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d`` with Xavier-uniform weights.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: convolution stride.
        padding: forwarded to ``nn.Conv2d``; presumably must stay 0 because
            ``pad_mode="same"`` is used — confirm against MindSpore docs.

    Returns:
        nn.Conv2d configured with ``pad_mode="same"`` and ``has_bias=False``.
    """
    conv_kwargs = dict(
        kernel_size=1,
        stride=stride,
        padding=padding,
        weight_init=weight_variable((out_channels, in_channels, 1, 1)),
        has_bias=False,
        pad_mode="same",
    )
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
56
57
def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution.

    Builds a bias-free 7x7 ``nn.Conv2d`` with Xavier-uniform weights.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: convolution stride.
        padding: forwarded to ``nn.Conv2d``; presumably must stay 0 because
            ``pad_mode="same"`` is used — confirm against MindSpore docs.

    Returns:
        nn.Conv2d configured with ``pad_mode="same"`` and ``has_bias=False``.
    """
    # Weight layout is (out_channels, in_channels, kernel_h, kernel_w).
    weight_shape = (out_channels, in_channels, 7, 7)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
64
65
def bn_with_initialize(out_channels):
    """Build an ``nn.BatchNorm2d`` with explicit parameter initialisation.

    The moving mean and beta start at zero and the moving variance at one.
    Gamma comes from the ``'Uniform'`` initializer.

    Args:
        out_channels: number of channels to normalise.

    Returns:
        nn.BatchNorm2d with ``momentum=0.99`` and ``eps=1e-5``.
    """
    # A genuine 1-tuple; the former ``(out_channels)`` was just a parenthesised int.
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_uniform(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn
75
76
def bn_with_initialize_last(out_channels):
    """Build the BatchNorm that closes a residual block's main path.

    The configuration was a line-for-line copy of ``bn_with_initialize``, so
    this now delegates to it rather than duplicating the setup. The separate
    entry point is kept so callers stay unchanged and the last BN of a block
    can be given a different initialisation later.

    Args:
        out_channels: number of channels to normalise.

    Returns:
        nn.BatchNorm2d, configured exactly as by ``bn_with_initialize``.
    """
    # NOTE(review): some ResNet recipes initialise this BN's gamma differently
    # (e.g. zeros); this keeps the existing behaviour — confirm that is intended.
    return bn_with_initialize(out_channels)
86
87
def fc_with_initialize(input_channels, out_channels):
    """Build an ``nn.Dense`` layer with explicit weight and bias initialisers.

    The weight uses the Xavier-uniform initializer and the bias uses the
    ``'Uniform'`` initializer.

    Args:
        input_channels: input feature count.
        out_channels: output feature count.

    Returns:
        nn.Dense mapping ``input_channels`` -> ``out_channels``.
    """
    weight = weight_variable((out_channels, input_channels))
    # A genuine 1-tuple; the former ``(out_channels)`` was just a parenthesised int.
    bias = weight_variable_uniform((out_channels,))
    # Pass the initialisers by keyword so they cannot be mis-bound positionally.
    return nn.Dense(input_channels, out_channels, weight_init=weight, bias_init=bias)
94
95
class ResidualBlock(nn.Cell):
    """Bottleneck residual block with an identity shortcut.

    The main path is 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    1x1 conv -> BN. Its output is added to the unmodified input, and the sum
    goes through a final ReLU.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block. The inner convolutions
            use ``out_channels // expansion`` channels.
        stride: stride of the first 1x1 convolution.

    NOTE(review): the shortcut is the raw input, so the add presumably only
    lines up when ``in_channels == out_channels`` and ``stride == 1`` —
    confirm callers respect this.
    """

    # Ratio between the block's output width and its bottleneck width.
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super().__init__()

        mid_channels = out_channels // self.expansion

        # Attribute names and assignment order are kept unchanged; MindSpore
        # presumably derives parameter names from them (checkpoint keys).
        self.conv1 = conv1x1(in_channels, mid_channels, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride=1, padding=0)
        self.bn2 = bn_with_initialize(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.Add()

    def construct(self, x):
        """Run the bottleneck path on ``x`` and add ``x`` back as the shortcut."""
        shortcut = x

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu(self.add(y, shortcut))
136
137
class ResidualBlockWithDown(nn.Cell):
    """Bottleneck residual block whose shortcut is a 1x1 conv + BN projection.

    The main path matches ``ResidualBlock``. The shortcut is
    ``bn_down_sample(conv_down_sample(x))``, using the same ``stride`` as the
    first convolution. This lets the block change both channel count and
    spatial size.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block. The inner convolutions
            use ``out_channels // expansion`` channels.
        stride: stride of the first 1x1 conv and of the shortcut projection.
        down_sample: stored as ``self.downSample``.

    NOTE(review): ``down_sample`` is never consulted; the projection shortcut
    is always built and applied regardless of its value.
    """

    # Ratio between the block's output width and its bottleneck width.
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super().__init__()

        mid_channels = out_channels // self.expansion

        # Attribute names and assignment order are kept unchanged; MindSpore
        # presumably derives parameter names from them (checkpoint keys).
        self.conv1 = conv1x1(in_channels, mid_channels, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride=1, padding=0)
        self.bn2 = bn_with_initialize(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        # Projection shortcut: matches channels and stride of the main path.
        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        """Run the bottleneck path and add the projected shortcut of ``x``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = self.bn_down_sample(self.conv_down_sample(x))

        return self.relu(self.add(y, shortcut))
186
187
class MakeLayer0(nn.Cell):
    """Three-block stage: one projection block followed by two ``block``s.

    ``ResNet`` uses it as ``layer1``.

    Args:
        block: residual block class for blocks ``b`` and ``c``.
        in_channels: channels entering the stage.
        out_channels: channels produced by every block in the stage.
        stride: stride applied by the stage, placed on the projection block.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        # The stride belongs on the projection block, whose shortcut is strided
        # as well (same layout as MakeLayer1-3). An identity-shortcut ``block``
        # cannot absorb a spatial reduction, so putting the stride on ``b``
        # breaks the residual add whenever stride != 1.
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        """Run blocks ``a``, ``b`` and ``c`` in sequence."""
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x
202
203
class MakeLayer1(nn.Cell):
    """Four-block stage: a strided projection block followed by three ``block``s.

    ``ResNet`` uses it as ``layer2``.

    Args:
        block: residual block class for blocks ``b`` .. ``d``.
        in_channels: channels entering the stage.
        out_channels: channels produced by every block in the stage.
        stride: stride of the projection block ``a``.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super().__init__()
        width = out_channels
        self.a = ResidualBlockWithDown(in_channels, width, stride=stride, down_sample=True)
        self.b = block(width, width, stride=1)
        self.c = block(width, width, stride=1)
        self.d = block(width, width, stride=1)

    def construct(self, x):
        """Apply ``a``, ``b``, ``c`` and ``d`` in that order."""
        return self.d(self.c(self.b(self.a(x))))
220
221
class MakeLayer2(nn.Cell):
    """Six-block stage: a strided projection block followed by five ``block``s.

    ``ResNet`` uses it as ``layer3``.

    Args:
        block: residual block class for blocks ``b`` .. ``f``.
        in_channels: channels entering the stage.
        out_channels: channels produced by every block in the stage.
        stride: stride of the projection block ``a``.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super().__init__()
        width = out_channels
        self.a = ResidualBlockWithDown(in_channels, width, stride=stride, down_sample=True)
        self.b = block(width, width, stride=1)
        self.c = block(width, width, stride=1)
        self.d = block(width, width, stride=1)
        self.e = block(width, width, stride=1)
        self.f = block(width, width, stride=1)

    def construct(self, x):
        """Apply blocks ``a`` through ``f`` in sequence."""
        h = self.a(x)
        h = self.b(h)
        h = self.c(h)
        h = self.d(h)
        h = self.e(h)
        return self.f(h)
242
243
class MakeLayer3(nn.Cell):
    """Three-block stage: a strided projection block followed by two ``block``s.

    ``ResNet`` uses it as ``layer4``.

    Args:
        block: residual block class for blocks ``b`` and ``c``.
        in_channels: channels entering the stage.
        out_channels: channels produced by every block in the stage.
        stride: stride of the projection block ``a``.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super().__init__()
        width = out_channels
        self.a = ResidualBlockWithDown(in_channels, width, stride=stride, down_sample=True)
        self.b = block(width, width, stride=1)
        self.c = block(width, width, stride=1)

    def construct(self, x):
        """Apply ``a``, ``b`` and ``c`` in that order."""
        return self.c(self.b(self.a(x)))
258
259
class ResNet(nn.Cell):
    """ResNet-50-style classifier built from four bottleneck stages.

    The stem is a 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool. It is followed
    by four stages (``MakeLayer0`` .. ``MakeLayer3``) producing 256, 512, 1024
    and 2048 channels, then a mean over axes 2 and 3 (presumably the spatial
    axes of NCHW input — confirm), and finally a dense classifier.

    Args:
        block: residual block class used for the non-projection blocks of
            each stage.
        num_classes: output width of the classifier.
        batch_size: stored as ``self.batch_size``; not otherwise used here.
    """

    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        # Channel width produced by the last stage; the classifier input must match it.
        final_channels = 2048

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=final_channels, stride=2)

        # Mean over axes 2 and 3 with dims kept, then squeeze them away.
        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        # Size the classifier from the real last-stage width. The stage widths
        # above are fixed, so ``512 * block.expansion`` only matched them when
        # ``block.expansion == 4``.
        self.fc = fc_with_initialize(final_channels, num_classes)

    def construct(self, x):
        """Map an input batch to class logits of width ``num_classes``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x, (2, 3))
        x = self.squeeze(x)
        x = self.fc(x)
        return x
297
298
def resnet50(batch_size, num_classes):
    """Build a ``ResNet`` whose stages use ``ResidualBlock`` bottlenecks.

    Args:
        batch_size: stored on the returned network as ``batch_size``.
        num_classes: output width of the final dense layer.

    Returns:
        ResNet instance.
    """
    return ResNet(block=ResidualBlock, num_classes=num_classes, batch_size=batch_size)
301