# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ResNet example: bottleneck-block ResNet18 and ResNet9 test networks with
compile-only unit tests.
"""
import numpy as np

import mindspore.nn as nn  # pylint: disable=C0414
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.ops.operations import Add
from ...train_step_wrap import train_step_with_loss_warp


def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode)


def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode)


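# Illustrative sketch, not part of the original tests: with pad_mode='pad' the
# spatial output size is floor((in + 2*padding - kernel) / stride) + 1, so
# conv3x3 keeps the feature-map size at stride 1 and halves it at stride 2,
# while conv1x1 only changes the channel count. The 56x56 input is assumed.
def _example_conv_shapes():
    """Hedged demo of the conv helpers' shape behavior."""
    x = Tensor(np.ones([1, 64, 56, 56]).astype(np.float32))
    keep = conv3x3(64, 64, stride=1)(x)      # expected (1, 64, 56, 56)
    halve = conv3x3(64, 64, stride=2)(x)     # expected (1, 64, 28, 28)
    project = conv1x1(64, 256, stride=1)(x)  # expected (1, 256, 56, 56)
    return keep.shape, halve.shape, project.shape

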
class ResidualBlock(nn.Cell):
    """
    Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, with an
    optional projection shortcut.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        # The bottleneck width is the output channel count divided by the
        # expansion factor (4 for the classic ResNet bottleneck).
        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        # Projection shortcut, used when the identity branch must change
        # shape to match the main branch.
        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = nn.BatchNorm2d(out_channels)
        self.add = Add()

    def construct(self, x):
        """Run the bottleneck and add the (possibly projected) identity."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


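# Hedged usage sketch (not from the original file): with down_sample=True the
# projection shortcut lets input and output channel counts differ, so the
# element-wise Add sees matching shapes, e.g. 64 -> 256 channels at stride 1.
def _example_residual_block():
    """Sketch: a stride-1 bottleneck expanding 64 to 256 channels."""
    blk = ResidualBlock(64, 256, stride=1, down_sample=True)
    out = blk(Tensor(np.ones([1, 64, 56, 56]).astype(np.float32)))
    return out.shape  # expected (1, 256, 56, 56)

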
class ResNet18(nn.Cell):
    """
    ResNet-18-style test network built from bottleneck blocks,
    two blocks per stage.
    """

    def __init__(self, block, num_classes=100):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self.MakeLayer(
            block, 2, in_channels=64, out_channels=256, stride=1)
        self.layer2 = self.MakeLayer(
            block, 2, in_channels=256, out_channels=512, stride=2)
        self.layer3 = self.MakeLayer(
            block, 2, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = self.MakeLayer(
            block, 2, in_channels=1024, out_channels=2048, stride=2)

        self.avgpool = nn.AvgPool2d(7, 1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
        """
        Build one stage of `layer_num` blocks. The first block applies the
        given stride and a projection shortcut; the rest keep the shape.
        """
        layers = []
        resblk = block(in_channels, out_channels,
                       stride=stride, down_sample=True)
        layers.append(resblk)

        for _ in range(1, layer_num):
            resblk = block(out_channels, out_channels, stride=1)
            layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Stem, four stages, global average pooling, then the classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


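# Shape walk-through (a hedged sketch for a 224x224 input, not in the original
# file): conv1 halves the map to 112, maxpool to 56, and layers 2-4 halve it
# again to 28, 14, and 7, so the 7x7 average pool leaves a 1x1 map with 2048
# channels, matching the Dense layer's 512 * block.expansion input features.
def _example_resnet18_shapes():
    """Sketch: compile-check ResNet18 on a single dummy image."""
    net = ResNet18(ResidualBlock, num_classes=10)
    image = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
    _cell_graph_executor.compile(net, image)

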
class ResNet9(nn.Cell):
    """
    Smaller ResNet test network with a single bottleneck block per stage.
    """

    def __init__(self, block, num_classes=100):
        super(ResNet9, self).__init__()

        # pad_mode='pad' is required here: MindSpore rejects a non-zero
        # padding under the default pad_mode='same'.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.layer1 = self.MakeLayer(
            block, 1, in_channels=64, out_channels=256, stride=1)
        self.layer2 = self.MakeLayer(
            block, 1, in_channels=256, out_channels=512, stride=2)
        self.layer3 = self.MakeLayer(
            block, 1, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = self.MakeLayer(
            block, 1, in_channels=1024, out_channels=2048, stride=2)

        self.avgpool = nn.AvgPool2d(7, 1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
        """
        Build one stage of `layer_num` blocks. The first block applies the
        given stride and a projection shortcut; the rest keep the shape.
        """
        layers = []
        resblk = block(in_channels, out_channels,
                       stride=stride, down_sample=True)
        layers.append(resblk)

        for _ in range(1, layer_num):
            resblk = block(out_channels, out_channels, stride=1)
            layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Stem, four stages, global average pooling, then the classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def resnet18():
    """Factory: ResNet18 test network with 10 output classes."""
    return ResNet18(ResidualBlock, 10)


def resnet9():
    """Factory: ResNet9 test network with 10 output classes."""
    return ResNet9(ResidualBlock, 10)


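# Hedged consistency note (added, not from the original file): both factories
# pin num_classes=10, which is why the training tests below build labels of
# shape [1, 10]. A quick check of the classifier head:
def _example_factory_head():
    """Sketch: resnet9()'s Dense head maps 2048 features to 10 classes."""
    net = resnet9()
    return net.fc.in_channels, net.fc.out_channels  # expected (2048, 10)

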
def test_compile():
    """Compile ResNet18 on a dummy input without running it."""
    net = resnet18()
    input_data = Tensor(np.ones([1, 3, 224, 224]))
    _cell_graph_executor.compile(net, input_data)


def test_train_step():
    """Compile a loss-wrapped ResNet9 training step."""
    net = train_step_with_loss_warp(resnet9())
    input_data = Tensor(np.ones([1, 3, 224, 224]))
    label = Tensor(np.zeros([1, 10]))
    _cell_graph_executor.compile(net, input_data, label)


def test_train_step_training():
    """Compile the training step with the network set to training mode."""
    net = train_step_with_loss_warp(resnet9())
    input_data = Tensor(np.ones([1, 3, 224, 224]))
    label = Tensor(np.zeros([1, 10]))
    net.set_train()
    _cell_graph_executor.compile(net, input_data, label)
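

# Hedged convenience entry point (added; the original module is pytest-only):
# run the compile checks directly with `python` instead of pytest.
if __name__ == '__main__':
    test_compile()
    test_train_step()
    test_train_step_training()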