# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


def weight_variable(shape):
    """Constant initializer: every element set to 0.01."""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones * 0.01)


def weight_variable_0(shape):
    """Constant initializer: every element set to 0."""
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    """Constant initializer: every element set to 1."""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution"""
    weight_shape = (out_channels, in_channels, 3, 3)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    weight_shape = (out_channels, in_channels, 1, 1)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution"""
    weight_shape = (out_channels, in_channels, 7, 7)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def bn_with_initialize(out_channels):
    """BatchNorm2d with constant initialization (gamma=1, beta=0)."""
    shape = (out_channels)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_1(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    """BatchNorm2d for the last layer of a block; gamma is zero-initialized."""
    shape = (out_channels)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with constant weights and zero-initialized bias."""
    weight_shape = (out_channels, input_channels)
    bias_shape = (out_channels)
    weight = weight_variable(weight_shape)
    bias = weight_variable_0(bias_shape)

    return nn.Dense(input_channels, out_channels, weight, bias)


class ResidualBlock(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with identity shortcut."""
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResidualBlockWithDown(nn.Cell):
    """Bottleneck residual block whose shortcut is a 1x1 convolution + BatchNorm.

    Note: the projection shortcut is always applied in construct; the
    down_sample flag is stored but not consulted.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class MakeLayer0(nn.Cell):
    """ResNet stage 1: a projection block followed by two bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(nn.Cell):
    """ResNet stage 2: a projection block followed by three bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(nn.Cell):
    """ResNet stage 3: a projection block followed by five bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(nn.Cell):
    """ResNet stage 4: a projection block followed by two bottleneck blocks."""

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class ResNet(nn.Cell):
    """ResNet-50: 7x7 stem and max pool, four bottleneck stages,
    global average pooling, and a fully connected classifier."""

    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="SAME")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x, (-2, -1))
        x = self.flatten(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    """Build a ResNet-50 network from bottleneck residual blocks."""
    return ResNet(ResidualBlock, num_classes, batch_size)
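

# A minimal usage sketch, assuming a working MindSpore install. The 224x224
# RGB input shape and the random dummy batch below are illustrative
# assumptions, not requirements of this file.
if __name__ == "__main__":
    net = resnet50(batch_size=2, num_classes=100)
    dummy_images = Tensor(np.random.rand(2, 3, 224, 224).astype(np.float32))
    logits = net(dummy_images)
    print(logits.shape)  # expected: (2, 100)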