# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
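"""Auto-parallel system test: ResNet-50 training on Ascend.

Builds a ResNet-50 backbone, an operator-level expanded softmax
cross-entropy loss, and a synthetic in-memory dataset, then trains for
five epochs under AUTO_PARALLEL mode and checks the per-epoch losses
against recorded golden values (the data shapes assume an 8-device run).
"""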

import os
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.train.callback import Callback
from mindspore.train.model import Model
from mindspore.context import ParallelMode

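# The setup below assumes a distributed Ascend environment: DEVICE_ID,
# RANK_ID and RANK_SIZE are expected to be exported per process by the
# launcher before this module is imported.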
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# Default to device 0 if the launcher did not export DEVICE_ID.
context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
np.random.seed(10)


def weight_variable():
    return TruncatedNormal(0.01)


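# Convolution factories: fixed kernel sizes (3x3, 1x1, 7x7), 'same' padding
# by default, and TruncatedNormal weight initialization throughout.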
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _fused_bn(channels, momentum=0.9):
    return nn.BatchNorm2d(channels, momentum=momentum)


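# Basic two-conv residual block (the ResNet-18/34 variant). Note that
# bn1/bn2 are constructed but not applied in construct(), matching the
# original test graph; resnet50() below only uses ResidualBlock, so this
# block is defined for completeness.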
class BasicBlock(nn.Cell):
    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.1):
        super(BasicBlock, self).__init__()

        self.conv1 = _conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = _fused_bn(out_channels, momentum=momentum)
        self.conv2 = _conv3x3(out_channels, out_channels)
        self.bn2 = _fused_bn(out_channels, momentum=momentum)
        self.relu = P.ReLU()
        self.down_sample_layer = None
        self.downsample = (in_channels != out_channels)
        if self.downsample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channels,
                                                                 out_channels,
                                                                 stride=stride,
                                                                 padding=0),
                                                        _fused_bn(out_channels,
                                                                  momentum=momentum)])
        self.add = P.Add()

    def construct(self, x):
        identity = x

        x = self.conv1(x)
        x = self.relu(x)

        x = self.conv2(x)

        if self.downsample:
            identity = self.down_sample_layer(identity)

        out = self.add(x, identity)
        out = self.relu(out)

        return out


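# Bottleneck residual block (1x1 -> 3x3 -> 1x1). The identity branch is
# projected with a 1x1 convolution when the channel count changes, or
# downsampled with a stride-2 max-pool when only the spatial size shrinks.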
class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)

        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.relu(out)

        out = self.conv3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


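# ResNet backbone. Two deliberate departures from the canonical ResNet,
# preserved from the original test graph: bn1 is constructed but not
# applied in construct(), and _make_layer puts the stage stride on the
# *last* block of each stage rather than the first. The Dense head is
# flagged to run in float16.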
class ResNet(nn.Cell):
    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100):
        super(ResNet, self).__init__()

        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the lengths of layer_nums, in_channels and "
                             "out_channels must all be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        layers = []
        resblk = block(in_channel, out_channel, stride=1)
        layers.append(resblk)

        for _ in range(1, layer_num - 1):
            resblk = block(out_channel, out_channel, stride=1)
            layers.append(resblk)

        # The stage stride goes on the last block (see the note above).
        resblk = block(out_channel, out_channel, stride=stride)
        layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


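# ResNet-50 configuration: [3, 4, 6, 3] bottleneck blocks per stage.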
def resnet50(class_num=10):
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [2, 2, 2, 1],
                  class_num)


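# Softmax cross-entropy written out as primitive ops (max-shift for
# numerical stability, then exp, sum, div, log) instead of a single fused
# operator, which exercises the auto-parallel strategy search on a chain
# of small operators.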
class SoftmaxCrossEntropyExpand(LossBase):
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.mean = P.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.eps = Tensor(1e-24, mstype.float32)

    def construct(self, logit, label):
        logit = self.cast(logit, mstype.float32)
        logit_max = self.max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result + self.eps)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


# Per-process identity exported by the distributed launcher.
rank_id = int(os.environ["RANK_ID"])
device_num = int(os.environ["RANK_SIZE"])


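# Generates deterministic synthetic data and slices it into per-device
# shards: get_parallel_blocks recursively splits the array along each axis
# according to the given strategy, so data_parallel[rank_id] is this
# process's shard of the full batch.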
class DataGenerator():
    def get_parallel_blocks(self, input_, strategy):
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while blocks:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks

    def generate_data(self, shape):
        data = np.arange(np.prod(shape)).reshape(shape)
        return data

    def input_data(self, shape):
        data = (self.generate_data(shape)).astype(np.float32)
        stra = [1] * len(shape)
        stra[0] = device_num
        data_parallel = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(data_parallel[rank_id])

    def label_data(self, shape):
        data = (self.generate_data(shape) * 1000 / np.prod(shape)).astype(np.int32)
        stra = [1] * len(shape)
        stra[0] = device_num
        data_parallel = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(data_parallel[rank_id])


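# Minimal in-memory dataset for feed (non-sink) mode: yields the same
# (predict, label) pair `length` times per epoch and implements just the
# subset of the dataset interface that Model.train touches.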
class Dataset():
    def __init__(self, predict, label, length=1, input_num=2, repeat_count=1):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num
        self.repeat_count = repeat_count

    def __iter__(self):
        # Rewind so the dataset can be iterated once per epoch.
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return self.length

    def get_repeat_count(self):
        return self.repeat_count


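# Records the mean loss at the end of every epoch so the test can compare
# the whole training trajectory against golden values.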
class ModelCallback(Callback):
    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())


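# End-to-end check: train for five epochs on one synthetic batch and
# require the per-epoch losses to match the golden values within 1e-4.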
def test_train_feed(num_classes=65536):
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 8,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, rtol=0.0001, atol=0.0001)
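# How this test is launched is environment-specific and not part of this
# file. A typical (hypothetical) setup runs one process per device with
# the identity variables exported by the launcher, e.g.:
#
#   DEVICE_ID=$i RANK_ID=$i RANK_SIZE=8 pytest -k test_train_feed
#
# for each device index i in 0..7.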