# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16resnet50 example
17"""
18import numpy as np
19
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23from ..ut_filter import non_graph_engine
24
25
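# All weights, biases, and batch-norm statistics below are initialized to a
# constant 0.01, which keeps the compiled network deterministic for this test.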
def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution"""
    weight = Tensor(np.ones([out_channels, in_channels, 3, 3]).astype(np.float32) * 0.01)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init=weight)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    weight = Tensor(np.ones([out_channels, in_channels, 1, 1]).astype(np.float32) * 0.01)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init=weight)


def bn_with_initialize(out_channels):
    """BatchNorm2d with all parameters and moving statistics initialized to a constant."""
    shape = (out_channels,)
    mean = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    var = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    beta = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    gamma = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    return nn.BatchNorm2d(num_features=out_channels,
                          beta_init=beta,
                          gamma_init=gamma,
                          moving_mean_init=mean,
                          moving_var_init=var)


class ResidualBlock(nn.Cell):
    """
    ResNet-50 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=1)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.Add()

    def construct(self, x):
        """
        Apply the bottleneck block to input tensor `x` and return the result.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class MakeLayer3(nn.Cell):
    """
    ResNet-50 stage built from three residual blocks.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)

        return x


class MakeLayer4(nn.Cell):
    """
    ResNet-50 stage built from four residual blocks.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer4, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)
        self.block3 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x


class MakeLayer6(nn.Cell):
    """
    ResNet-50 stage built from six residual blocks.
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer6, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)
        self.block3 = block(out_channels, out_channels, stride=1)
        self.block4 = block(out_channels, out_channels, stride=1)
        self.block5 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)

        return x


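# ResNet-50 stacks four stages of 3, 4, 6, and 3 bottleneck blocks; the final
# 3-block stage below therefore reuses MakeLayer3.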
class ResNet50(nn.Cell):
    """
    ResNet-50 network.
    """

    def __init__(self, block, num_classes=100):
        super(ResNet50, self).__init__()

        weight_conv = Tensor(np.ones([64, 3, 7, 7]).astype(np.float32) * 0.01)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, weight_init=weight_conv)
        self.bn1 = bn_with_initialize(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.layer1 = MakeLayer3(
            block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer4(
            block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer6(
            block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(
            block, in_channels=1024, out_channels=2048, stride=2)

        self.avgpool = nn.AvgPool2d(7, 1)
        self.flatten = nn.Flatten()

        weight_fc = Tensor(np.ones([num_classes, 512 * block.expansion]).astype(np.float32) * 0.01)
        bias_fc = Tensor(np.ones([num_classes]).astype(np.float32) * 0.01)
        self.fc = nn.Dense(512 * block.expansion, num_classes, weight_init=weight_fc, bias_init=bias_fc)

    def construct(self, x):
        """
        Run the full ResNet-50 forward pass on input tensor `x`.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def resnet50():
    """Build a ResNet-50 with 10 output classes."""
    return ResNet50(ResidualBlock, 10)


@non_graph_engine
def test_compile():
    """Compile and run the network once on a dummy input."""
    net = resnet50()
    input_data = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)

    output = net(input_data)
    print(output.asnumpy())
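
# Minimal run sketch (assumption: this module is collected by pytest from the
# repository's unit-test tree, as the relative `..ut_filter` import implies;
# the path below is illustrative only):
#
#     pytest path/to/this_file.py -k test_compile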