• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import os
16import numpy as np
17import pytest
18
19import mindspore.context as context
20from mindspore import Tensor, nn
21from mindspore.common import dtype as mstype
22from mindspore.train.serialization import export, load
23
24
class CaseNet(nn.Cell):
    """Network that chooses which layers to run from two tuples of cells.

    ``construct`` indexes ``layers1`` with ``index1`` and ``layers2`` with
    ``index2`` and applies the two selected cells in sequence. The tests in
    this file pass both indices as int32 tensors.
    """

    def __init__(self):
        super().__init__()
        # Sub-cells are registered in the same order as before; the tuples
        # below only reference them for index-based selection in construct.
        self.conv = nn.Conv2d(1, 1, 3)
        self.relu = nn.ReLU()
        self.relu1 = nn.ReLU()
        self.softmax = nn.Softmax()
        self.layers1 = (self.relu, self.softmax)
        self.layers2 = (self.conv, self.relu1)

    def construct(self, x, index1, index2):
        hidden = self.layers1[index1](x)
        return self.layers2[index2](hidden)
39
40
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_switch_layer():
    """Round-trip CaseNet through MINDIR export/load and check its output.

    With index1=0 and index2=-1 the net applies ``layers1[0]`` (``relu``) and
    then ``layers2[-1]`` (``relu1``). ReLU applied twice equals ReLU applied
    once, so the loaded graph must produce ReLU(data).

    The model is exported into a fresh temporary directory. A stale
    ``switch_layer_net.mindir`` in the working directory (e.g. from an earlier
    run) therefore cannot satisfy the existence check, and the test leaves no
    files behind.
    """
    import tempfile

    context.set_context(mode=context.GRAPH_MODE)
    net = CaseNet()
    data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
    idx = Tensor(0, mstype.int32)
    idx2 = Tensor(-1, mstype.int32)

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "switch_layer_net")
        mindir_name = file_name + ".mindir"
        export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
        assert os.path.exists(mindir_name), "export did not create " + mindir_name

        # Load while the temporary directory still exists.
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(data, idx, idx2)

    relu = nn.ReLU()
    true_value = relu(data)
    ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
    assert ret
64
65
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_export():
    """Export CaseNet to ``switch_layer_net.mindir`` in the working directory.

    The exported file is deliberately left in place, because
    ``test_mindir_load`` reads it from the working directory. Any file left
    over from a previous run is deleted first, so the existence assertion
    proves that this export call actually wrote it.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = CaseNet()
    data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
    idx = Tensor(0, mstype.int32)
    idx2 = Tensor(-1, mstype.int32)

    file_name = "switch_layer_net"
    mindir_name = file_name + ".mindir"
    try:
        os.remove(mindir_name)
    except FileNotFoundError:
        pass  # nothing stale to clear

    export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
    assert os.path.exists(mindir_name), "export did not create " + mindir_name
82
83
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_load():
    """Load ``switch_layer_net.mindir`` from the working directory and verify it.

    The file is expected to have been written beforehand (see the skip
    reason). The loaded graph is run with indices 0 and -1, and its output is
    compared against ReLU applied to the same input.
    """
    context.set_context(mode=context.GRAPH_MODE)
    ones = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
    first_idx = Tensor(0, mstype.int32)
    last_idx = Tensor(-1, mstype.int32)

    base_name = "switch_layer_net"
    mindir_path = base_name + ".mindir"
    restored_net = nn.GraphCell(load(mindir_path))
    actual = restored_net(ones, first_idx, last_idx)
    expected = nn.ReLU()(ones)
    assert np.allclose(actual.asnumpy(), expected.asnumpy())
104