• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""
16resnet50 example
17"""
18import numpy as np
19
20import mindspore.context as context
21import mindspore.nn as nn
22from mindspore import Tensor, Model
23from mindspore.context import ParallelMode
24from mindspore.nn.optim import Momentum
25from mindspore.ops.operations import Add
26from ....dataset_mock import MindData
27
28
def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
    """Build a 3x3 ``nn.Conv2d`` layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :param padding: explicit padding, used because ``pad_mode`` defaults to 'pad'.
    :param pad_mode: MindSpore padding mode forwarded to ``nn.Conv2d``.
    :return: the configured ``nn.Conv2d`` cell.
    """
    layer = nn.Conv2d(in_channels, out_channels,
                      kernel_size=3,
                      stride=stride,
                      padding=padding,
                      pad_mode=pad_mode)
    return layer
33
34
def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
    """Build a 1x1 ``nn.Conv2d`` layer (pointwise convolution).

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :param padding: explicit padding, used because ``pad_mode`` defaults to 'pad'.
    :param pad_mode: MindSpore padding mode forwarded to ``nn.Conv2d``.
    :return: the configured ``nn.Conv2d`` cell.
    """
    layer = nn.Conv2d(in_channels, out_channels,
                      kernel_size=1,
                      stride=stride,
                      padding=padding,
                      pad_mode=pad_mode)
    return layer
39
40
class ResidualBlock(nn.Cell):
    """
    Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The inner width is ``out_channels // expansion``. When ``down_sample`` is
    true, the shortcut is projected with a strided 1x1 convolution and batch norm;
    otherwise the input is added back unchanged.

    :param in_channels: channels of the incoming feature map.
    :param out_channels: channels produced by the block.
    :param stride: stride applied by the 3x3 convolution (and the projection).
    :param down_sample: whether to project the shortcut.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super().__init__()

        width = out_channels // self.expansion

        # Main branch; attribute names/order define parameter names, keep them.
        self.conv1 = conv1x1(in_channels, width, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, out_channels, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        # Projection shortcut; built unconditionally, used only if down_sample.
        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = nn.BatchNorm2d(out_channels)
        self.add = Add()

    def construct(self, x):
        """
        Run the bottleneck branch, add the (optionally projected) shortcut, apply ReLU.

        :param x: input feature map.
        :return: activated sum of the main branch and the shortcut.
        """
        shortcut = x

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        if self.downsample:
            shortcut = self.bn_down_sample(self.conv_down_sample(shortcut))

        return self.relu(self.add(y, shortcut))
98
99
class ResNet18(nn.Cell):
    """
    ResNet-style network: a 7x7 stem followed by four stages of two ``block`` units each.

    :param block: residual block class; must expose a class attribute ``expansion``.
    :param num_classes: number of logits produced by the final dense layer.
    """

    def __init__(self, block, num_classes=100):
        super().__init__()

        # Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four stages; stages 2-4 use stride 2 in their first block.
        self.layer1 = self.MakeLayer(block, 2, in_channels=64, out_channels=256, stride=1)
        self.layer2 = self.MakeLayer(block, 2, in_channels=256, out_channels=512, stride=2)
        self.layer3 = self.MakeLayer(block, 2, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = self.MakeLayer(block, 2, in_channels=1024, out_channels=2048, stride=2)

        # Head. NOTE: a 7x7 kernel with stride 1 collapses the map to 1x1 only when
        # the last stage outputs 7x7 (e.g. 224x224 input); other sizes leave a larger map.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
        """
        Stack residual blocks into one ``nn.SequentialCell`` stage.

        The first block always projects its shortcut and applies ``stride``; the
        remaining ``layer_num - 1`` blocks map ``out_channels`` to ``out_channels``
        with stride 1. At least one block is created even if ``layer_num`` < 1.

        :param block: residual block class to instantiate.
        :param layer_num: number of blocks in the stage.
        :param in_channels: channels entering the stage.
        :param out_channels: channels leaving every block of the stage.
        :param stride: stride of the first block.
        :return: the stage as an ``nn.SequentialCell``.
        """
        head = block(in_channels, out_channels, stride=stride, down_sample=True)
        tail = [block(out_channels, out_channels, stride=1) for _ in range(1, layer_num)]
        return nn.SequentialCell([head] + tail)

    def construct(self, x):
        """
        Forward pass: stem, four stages, pooling, flatten, dense classifier.

        :param x: input image batch.
        :return: class logits.
        """
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = self.layer4(self.layer3(self.layer2(self.layer1(features))))
        logits = self.fc(self.flatten(self.avgpool(features)))
        return logits
167
168
class ResNet9(nn.Cell):
    """
    Shallow ResNet-style network: a 7x7 stem followed by four single-block stages.

    :param block: residual block class; must expose a class attribute ``expansion``.
    :param num_classes: number of logits produced by the final dense layer.
    """

    def __init__(self, block, num_classes=100):
        super().__init__()

        # Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four stages of one block each; stages 2-4 use stride 2.
        self.layer1 = self.MakeLayer(block, 1, in_channels=64, out_channels=256, stride=1)
        self.layer2 = self.MakeLayer(block, 1, in_channels=256, out_channels=512, stride=2)
        self.layer3 = self.MakeLayer(block, 1, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = self.MakeLayer(block, 1, in_channels=1024, out_channels=2048, stride=2)

        # Head. NOTE: a 7x7 kernel with stride 1 collapses the map to 1x1 only when
        # the last stage outputs 7x7 (e.g. 224x224 input); other sizes leave a larger map.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
        """
        Stack residual blocks into one ``nn.SequentialCell`` stage.

        The first block always projects its shortcut and applies ``stride``; the
        remaining ``layer_num - 1`` blocks map ``out_channels`` to ``out_channels``
        with stride 1. At least one block is created even if ``layer_num`` < 1.

        :param block: residual block class to instantiate.
        :param layer_num: number of blocks in the stage.
        :param in_channels: channels entering the stage.
        :param out_channels: channels leaving every block of the stage.
        :param stride: stride of the first block.
        :return: the stage as an ``nn.SequentialCell``.
        """
        head = block(in_channels, out_channels, stride=stride, down_sample=True)
        tail = [block(out_channels, out_channels, stride=1) for _ in range(1, layer_num)]
        return nn.SequentialCell([head] + tail)

    def construct(self, x):
        """
        Forward pass: stem, four stages, pooling, flatten, dense classifier.

        :param x: input image batch.
        :return: class logits.
        """
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = self.layer4(self.layer3(self.layer2(self.layer1(features))))
        logits = self.fc(self.flatten(self.avgpool(features)))
        return logits
236
237
def resnet9(classnum):
    """Build a ``ResNet9`` made of ``ResidualBlock`` units with ``classnum`` output logits."""
    net = ResNet9(ResidualBlock, num_classes=classnum)
    return net
240
241
class DatasetLenet(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair ``length`` times.

    :param predict: object returned as the first element of every item.
    :param label: object returned as the second element of every item.
    :param length: number of items produced before ``StopIteration``; call
        ``reset`` to iterate again.
    The remaining keyword arguments are forwarded unchanged to ``MindData``.
    """

    def __init__(self, predict, label, length=3, size=None, batch_size=None,
                 np_types=None, output_shapes=None, input_indexs=()):
        super().__init__(size=size,
                         batch_size=batch_size,
                         np_types=np_types,
                         output_shapes=output_shapes,
                         input_indexs=input_indexs)
        self.predict, self.label = predict, label
        self.index = 0  # items already produced in the current pass
        self.length = length

    def __iter__(self):
        # The object is its own iterator; the position is not rewound here.
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind so the next iteration yields ``length`` items again."""
        self.index = 0
266
267
def test_resnet_train_tensor():
    """test_resnet_train_tensor

    Train ``resnet9`` for two epochs on a mocked in-memory dataset under a
    two-device DATA_PARALLEL auto-parallel context.

    The execution mode and auto-parallel context are restored in a ``finally``
    clause so that a failure during setup or training cannot leak the
    DATA_PARALLEL configuration into tests that run afterwards.
    """
    batch_size = 1
    size = 2
    context.set_context(mode=context.GRAPH_MODE)
    context.reset_auto_parallel_context()
    try:
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, device_num=size,
                                          parameter_broadcast=True)
        one_hot_len = 10
        dataset_types = (np.float32, np.float32)
        dataset_shapes = [[batch_size, 3, 224, 224], [batch_size, one_hot_len]]
        predict = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.zeros([batch_size, one_hot_len]).astype(np.float32))
        # NOTE(review): size=2/batch_size=2 passed to the mock differ from the
        # batch of 1 used in dataset_shapes; kept as-is, confirm this is intended.
        dataset = DatasetLenet(predict, label, 2,
                               size=2, batch_size=2,
                               np_types=dataset_types,
                               output_shapes=dataset_shapes,
                               input_indexs=(0, 1))
        dataset.reset()
        network = resnet9(one_hot_len)
        network.set_train()
        loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                             learning_rate=0.1, momentum=0.9)
        model = Model(network=network, loss_fn=loss_fn, optimizer=optimizer)
        model.train(epoch=2, train_dataset=dataset, dataset_sink_mode=False)
    finally:
        # Restore global state even when training raised.
        context.set_context(mode=context.GRAPH_MODE)
        context.reset_auto_parallel_context()
295
296
# Number of label classes; get_dataset() uses it as the width of the label output shape.
class_num = 10
298
299
def get_dataset():
    """Return a ``MindData`` mock with two float32 outputs.

    The declared output shapes are ``(32, 3, 224, 224)`` and ``(32, class_num)``;
    the mock is built with ``size=2`` and ``batch_size=1``.
    NOTE(review): the 32-row shapes vs ``batch_size=1`` look inconsistent — confirm intent.
    """
    return MindData(size=2,
                    batch_size=1,
                    np_types=(np.float32, np.float32),
                    output_shapes=((32, 3, 224, 224), (32, class_num)),
                    input_indexs=(0, 1))
309