• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import os
15import numpy as np
16import pytest
17
18import mindspore.nn as nn
19from mindspore import context
20from mindspore.common.tensor import Tensor
21from mindspore.common.initializer import TruncatedNormal
22from mindspore.common.parameter import ParameterTuple
23from mindspore.ops import operations as P
24from mindspore.ops import composite as C
25from mindspore.train.serialization import export, load
26
27
def weight_variable():
    """Create the weight initializer used by ``conv`` and ``fc_with_initialize``.

    Each call returns a fresh ``TruncatedNormal`` initializer constructed with
    0.02 as its (sigma) argument.
    """
    sigma = 0.02
    initializer = TruncatedNormal(sigma)
    return initializer
30
31
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free, valid-padded ``nn.Conv2d`` with truncated-normal weights.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (filters).
        kernel_size: convolution kernel size.
        stride: convolution stride (default 1).
        padding: padding amount passed through to ``nn.Conv2d`` (default 0).

    Returns:
        The configured ``nn.Conv2d`` cell.
    """
    conv_options = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight_variable(),
        has_bias=False,
        pad_mode="valid",
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)
37
38
def fc_with_initialize(input_channels, out_channels):
    """Build an ``nn.Dense`` layer whose weight and bias both use ``weight_variable()``.

    Args:
        input_channels: number of input features.
        out_channels: number of output features.

    Returns:
        The configured ``nn.Dense`` cell.
    """
    # Two separate initializer instances, created weight-first, as before.
    return nn.Dense(input_channels, out_channels,
                    weight_init=weight_variable(),
                    bias_init=weight_variable())
43
44
class LeNet5(nn.Cell):
    """LeNet-5 style classifier used by the MindIR export/load tests.

    Two valid-padded 5x5 convolutions, each followed by ReLU and 2x2/stride-2
    max pooling, feed three fully connected layers that produce 10 outputs.

    The ``16 * 5 * 5`` input size of ``fc1`` only matches a single-channel
    32x32 input: 32 -> 28 (conv) -> 14 (pool) -> 10 (conv) -> 5 (pool).
    The tests in this file therefore feed tensors of shape (32, 1, 32, 32).

    Args:
        batch_size: leading dimension used when flattening the conv features
            before ``fc1``. Every input batch must contain exactly this many
            samples. Defaults to 32, the value that used to be hard-coded.
    """

    def __init__(self, batch_size=32):
        super(LeNet5, self).__init__()
        # construct() flattens with this value rather than reading the batch
        # dimension from the input, so it must match the real batch size.
        self.batch_size = batch_size
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        # Feature extractor: (conv -> ReLU -> pool) twice.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Flatten to (batch_size, 16 * 5 * 5) for the dense head.
        x = self.reshape(x, (self.batch_size, -1))
        # Classifier head: fc -> ReLU -> fc -> ReLU -> fc (raw logits).
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
72
73
class WithLossCell(nn.Cell):
    """Wrap ``network`` so that calling the cell returns its training loss.

    The loss is ``nn.SoftmaxCrossEntropyWithLogits`` applied to the network's
    output and ``label``. ``auto_prefix=False`` is passed to ``nn.Cell`` so
    this wrapper does not add its own prefix to the wrapped parameters' names.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        # Forward pass and loss computed in a single expression.
        return self.loss(self.network(x), label)
83
84
class TrainOneStepCell(nn.Cell):
    """Perform one optimisation step on ``network``.

    The gradients of ``network``'s output are taken with respect to its
    trainable parameters and handed to an ``nn.Momentum`` optimizer with
    learning rate 0.1 and momentum 0.9. The cell returns whatever that
    optimizer call returns.

    Args:
        network: loss-producing cell (e.g. ``WithLossCell``) to train; it is
            switched to training mode on construction.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
        self.hyper_map = C.HyperMap()  # not referenced by construct()
        # get_by_list=True: differentiate w.r.t. the parameters in self.weights.
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, x, label):
        grads = self.grad(self.network, self.weights)(x, label)
        return self.optimizer(grads)
99
100
class SingleIfNet(nn.Cell):
    """Tiny cell with one value-dependent ``if``/``else``, used by test_single_if.

    Computes ``x1 = x + 1`` and then returns ``y + x1 + 5`` when ``x1 < y``,
    or ``y - x1 + 5`` otherwise.
    """
    def construct(self, x, y):
        x += 1
        # The branch taken depends on the runtime values of x and y, so the
        # exported graph has to carry the conditional itself.
        if x < y:
            y += x
        else:
            y -= x
        # Common tail executed after either branch.
        y += 5
        return y
110
111
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_lenet_grad_mindir():
    """Export a full LeNet5 training step (loss + gradients + Momentum) to MindIR.

    A ``lenet_grad.mindir`` left over from an earlier run is deleted first, so
    the existence check can only pass if *this* export wrote the file. The
    file is deleted again afterwards, so the test leaves no artifacts behind.
    """
    context.set_context(mode=context.GRAPH_MODE)
    file_name = "lenet_grad"
    verify_name = file_name + ".mindir"
    if os.path.exists(verify_name):
        os.remove(verify_name)

    network = LeNet5()
    network.set_train()
    predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([32, 10]).astype(np.float32))
    net = TrainOneStepCell(WithLossCell(network))
    try:
        export(net, predict, label, file_name=file_name, file_format='MINDIR')
        assert os.path.exists(verify_name)
    finally:
        if os.path.exists(verify_name):
            os.remove(verify_name)
126
127
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
    """Round-trip LeNet5 through MindIR and check the loaded graph's outputs.

    The reference outputs come from the in-memory network. The network is
    then exported (with zero-valued example inputs), loaded back as a
    ``GraphCell``, and run on the same inputs as the reference.
    Any stale ``test_lenet_load.mindir`` is removed before exporting, and the
    file is removed again when the test finishes.
    """
    context.set_context(mode=context.GRAPH_MODE)
    file_name = "test_lenet_load"
    mindir_name = file_name + ".mindir"
    if os.path.exists(mindir_name):
        os.remove(mindir_name)

    network = LeNet5()
    network.set_train()

    inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    outputs0 = network(inputs0)

    inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
    try:
        export(network, inputs, file_name=file_name, file_format='MINDIR')
        assert os.path.exists(mindir_name)

        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(inputs0)
        assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
    finally:
        if os.path.exists(mindir_name):
            os.remove(mindir_name)
149
150
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """Check that a graph with an ``if`` branch survives a MindIR export/load.

    The loaded graph's output is compared numerically with the original
    network's output via ``np.allclose`` on host arrays. This avoids relying
    on the truthiness of an element-wise Tensor ``==`` result. A stale
    ``if_net.mindir`` is removed before exporting, and the file is cleaned up
    afterwards.
    """
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleIfNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)

    file_name = "if_net"
    mindir_name = file_name + ".mindir"
    if os.path.exists(mindir_name):
        os.remove(mindir_name)

    try:
        export(network, x, y, file_name=file_name, file_format='MINDIR')
        assert os.path.exists(mindir_name)

        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(x, y)
        assert np.allclose(origin_out.asnumpy(), outputs_after_load.asnumpy())
    finally:
        if os.path.exists(mindir_name):
            os.remove(mindir_name)
172