# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
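"""Tests for training ResNet-50 on GPU: fitting a constant batch, a large
batch size, and mixed-precision training via amp (level O2)."""
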
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import amp
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.cell import Cell
from mindspore.nn.layer.basic import Flatten
from mindspore.nn.layer.conv import Conv2d
from mindspore.nn.layer.normalization import BatchNorm2d
from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations import Add

# All tests in this file run in graph mode on GPU.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def random_normal_init(shape, mean=0.0, stddev=0.01, seed=None):
    # Deterministic stand-in for a random normal initializer: the mean,
    # stddev and seed arguments are ignored and a constant 0.01 tensor is
    # returned so that test results are reproducible.
    init_value = np.ones(shape).astype(np.float32) * 0.01
    return Tensor(init_value)


def variance_scaling_raw(shape):
    # Deterministic stand-in for a variance-scaling initializer (constant 0.01).
    variance_scaling_value = np.ones(shape).astype(np.float32) * 0.01
    return Tensor(variance_scaling_value)


def weight_variable_0(shape):
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


# Note: the padding arguments below are accepted for call-site symmetry but
# are unused; pad_mode="same" determines the actual padding.
def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution"""
    weight_shape = (out_channels, in_channels, 3, 3)
    weight = variance_scaling_raw(weight_shape)
    return Conv2d(in_channels, out_channels,
                  kernel_size=3, stride=stride, weight_init=weight, has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    weight_shape = (out_channels, in_channels, 1, 1)
    weight = variance_scaling_raw(weight_shape)
    return Conv2d(in_channels, out_channels,
                  kernel_size=1, stride=stride, weight_init=weight, has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution"""
    weight_shape = (out_channels, in_channels, 7, 7)
    weight = variance_scaling_raw(weight_shape)
    return Conv2d(in_channels, out_channels,
                  kernel_size=7, stride=stride, weight_init=weight, has_bias=False, pad_mode="same")


def bn_with_initialize(out_channels):
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_1(shape)
    bn = BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                     beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    # Same as bn_with_initialize except that gamma starts at zero, so the last
    # BatchNorm zeroes the residual branch at initialization and each block
    # initially passes the identity through.
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_0(shape)
    bn = BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                     beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    weight_shape = (out_channels, input_channels)
    bias_shape = (out_channels,)
    weight = random_normal_init(weight_shape)
    bias = weight_variable_0(bias_shape)

    return Dense(input_channels, out_channels, weight, bias)


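# Bottleneck residual blocks (1x1 reduce, 3x3, 1x1 expand; expansion = 4).
# For example, ResidualBlock(256, 256) maps 256 -> 64 -> 64 -> 256 channels
# before the residual add. The down_sample flag is stored for interface
# compatibility but never consulted: ResidualBlockWithDown always projects
# the identity through a strided 1x1 convolution and BatchNorm.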
class ResidualBlock(Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResidualBlockWithDown(Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        self.conv_down_sample = conv1x1(
            in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


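# MakeLayer0..MakeLayer3 unroll the four ResNet-50 stages explicitly: each
# begins with a ResidualBlockWithDown followed by 2, 3, 5 and 2 plain blocks
# respectively (the 3-4-6-3 layout). The layer_num argument is accepted but
# unused, since the block counts are hard-coded.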
class MakeLayer0(Cell):

    def __init__(self, block, layer_num, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(
            in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(Cell):

    def __init__(self, block, layer_num, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(
            in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(Cell):

    def __init__(self, block, layer_num, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(
            in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(Cell):

    def __init__(self, block, layer_num, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(
            in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class ResNet(Cell):

    def __init__(self, block, layer_num, num_classes=100):
        super(ResNet, self).__init__()

        self.conv1 = conv7x7(3, 64, stride=2, padding=3)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = MakeLayer0(
            block, layer_num[0], in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(
            block, layer_num[1], in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(
            block, layer_num[2], in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(
            block, layer_num[3], in_channels=1024, out_channels=2048, stride=2)

        self.pool = nn.AvgPool2d(7, 1)
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)
        self.flatten = Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


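# Factory for the 3-4-6-3 bottleneck network above. A minimal usage sketch
# (mirroring the tests in this file; input is NCHW float32 at 224x224):
#   net = resnet50(num_classes=10)
#   logits = net(Tensor(np.ones([1, 3, 224, 224]).astype(np.float32)))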
def resnet50(num_classes):
    return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor(num_classes=10, epoch=8, batch_size=1):
    net = resnet50(num_classes)
    lr = 0.1
    momentum = 0.9
    optimizer = Momentum(filter(lambda x: x.requires_grad,
                                net.get_parameters()), lr, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    losses = []
    # Each "epoch" is one training step on the same constant batch; the loss
    # must drop below 1 by the final step.
    for _ in range(epoch):
        data = Tensor(
            np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss = train_network(data, label)
        losses.append(loss)
    assert losses[-1].asnumpy() < 1


@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=338):
    net = resnet50(num_classes)
    lr = 0.1
    momentum = 0.9
    optimizer = Momentum(filter(lambda x: x.requires_grad,
                                net.get_parameters()), lr, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    losses = []
    for _ in range(epoch):
        data = Tensor(
            np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss = train_network(data, label)
        losses.append(loss)
    assert losses[-1].asnumpy() < 1


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
    net = resnet50(num_classes)
    lr = 0.1
    momentum = 0.9
    optimizer = Momentum(filter(lambda x: x.requires_grad,
                                net.get_parameters()), lr, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    train_network = amp.build_train_network(
        net, optimizer, criterion, level="O2")
    train_network.set_train()
    losses = []
    for _ in range(epoch):
        data = Tensor(
            np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss = train_network(data, label)
        losses.append(loss)
    # With loss scaling enabled, each step returns (loss, overflow flag, loss scale).
    assert losses[-1][0].asnumpy() < 1    # loss converged below 1
    assert not losses[-1][1].asnumpy()    # no overflow on the final step
    assert losses[-1][2].asnumpy() > 1    # loss scale still above 1