# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import time
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.optim import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

np.random.seed(1)
grad_by_list = C.GradOperation(get_by_list=True)
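# Note: grad_by_list(net, params)(*inputs) returns a tuple of gradients, one
# per parameter in params, rather than gradients w.r.t. the network inputs.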


def weight_variable():
    """Return a TruncatedNormal weight initializer with sigma 0.02."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with truncated-normal weight init, no bias, valid padding."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with truncated-normal weight and bias init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class LeNet(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of output classes. Default: 10.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet(num_class=10)
    """

    def __init__(self, num_class=10):
        super().__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # With 32x32 inputs and valid padding: 32 -> 28 -> 14 -> 10 -> 5,
        # so the flattened conv output has 16 * 5 * 5 = 400 features.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

class CrossEntropyLoss(nn.Cell):
    """Softmax cross-entropy averaged over the fixed batch of 32 samples."""

    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.sum = P.ReduceSum()
        self.div = P.RealDiv()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.num = Tensor(32.0, mstype.float32)  # batch size used for averaging

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
        # SoftmaxCrossEntropyWithLogits returns (loss, backprop); keep the loss.
        loss = self.cross_entropy(logits, label)[0]
        loss = self.div(self.sum(loss, -1), self.num)
        return loss

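# Sanity check: with 10 classes and near-uniform logits at initialization,
# CrossEntropyLoss starts near ln(10) ≈ 2.303 and should fall as training
# proceeds, which is what the loss bounds asserted below rely on.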

class GradWrap(nn.Cell):
    """Wrap a network so construct() returns gradients w.r.t. its parameters."""

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)

# Not prefixed with test_ so pytest does not collect it as a standalone case.
def train_lenet():
    """Run a short manual training loop and return the final loss tensor."""
    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)

    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()
    total_time = 0

    for epoch in range(epoch_size):
        start_time = time.time()
        # Forward pass for logging, backward pass via GradWrap, then update.
        fw_output = net(inputs)
        loss_output = criterion(fw_output, labels)
        grads = train_network(inputs, labels)
        optimizer(grads)
        cost_time = time.time() - start_time
        total_time += cost_time

        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    print("======total cost time: ", total_time)
    return loss_output

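# For reference, the explicit forward/backward/update sequence in train_lenet
# could be collapsed into MindSpore's built-in one-step wrapper; a minimal
# sketch for comparison only, not used by the tests below:
#
#   train_network = nn.TrainOneStepCell(WithLossCell(net, criterion), optimizer)
#   train_network.set_train()
#   loss = train_network(inputs, labels)
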
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet1():
    """Train LeNet in graph mode on Ascend and check the final loss."""
    os.environ['GRAPH_OP_RUN'] = '1'  # execute graphs kernel by kernel (no sink)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    loss_output = train_lenet()
    assert 0.003 < loss_output.asnumpy() < 0.004


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet2():
    """Train LeNet in PyNative mode on Ascend and check the final loss."""
    os.environ['GRAPH_OP_RUN'] = '1'  # execute graphs kernel by kernel (no sink)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    loss_output = train_lenet()
    assert 0.003 < loss_output.asnumpy() < 0.004
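# To run a single case directly (assuming a configured Ascend environment),
# select it by name, e.g.:
#   pytest -k test_ascend_lenet1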