# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""ut for model serialize(save/load)"""
import os
import stat
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.train.serialization import export

context.set_context(mode=context.GRAPH_MODE)


def is_enable_onnxruntime():
    val = os.getenv("ENABLE_ONNXRUNTIME", "False")
    return val in ('ON', 'on', 'TRUE', 'True', 'true')


run_on_onnxruntime = pytest.mark.skipif(not is_enable_onnxruntime(), reason="Only support running on onnxruntime")
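# Note: the onnxruntime-backed test below is skipped unless ENABLE_ONNXRUNTIME is set to an
# accepted value (e.g. ON or True). A typical local invocation, assuming onnxruntime is
# installed, would be:
#   ENABLE_ONNXRUNTIME=ON pytest test_onnx.py -k test_onnx_export_load_run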


def teardown_module():
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    for filename in os.listdir(cur_dir):
        if filename.find('ms_output_') == 0 and filename.find('.pb') > 0:
            # delete temp files generated by running the UT
            os.chmod(filename, stat.S_IWRITE)
            os.remove(filename)


class BatchNormTester(nn.Cell):
    """Used to test exporting a network in training mode in ONNX format."""

    def __init__(self, num_features):
        super(BatchNormTester, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def construct(self, x):
        return self.bn(x)


def test_batchnorm_train_onnx_export():
    """Test that the ONNX export interface does not modify the training flag of a network."""
    input_ = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
    net = BatchNormTester(3)
    net.set_train()
    if not net.training:
        raise ValueError('network is not in training mode')
    onnx_file = 'batch_norm'
    export(net, input_, file_name=onnx_file, file_format='ONNX')

    if not net.training:
        raise ValueError('network is not in training mode')

    file_name = "batch_norm.onnx"
    assert os.path.exists(file_name)
    os.chmod(file_name, stat.S_IWRITE)
    os.remove(file_name)


class LeNet5(nn.Cell):
    """LeNet5 definition"""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class DefinedNet(nn.Cell):
    """Simple Net definition with MaxPoolWithArgmax."""

    def __init__(self, num_classes=10):
        super(DefinedNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=2, strides=2)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(int(56 * 56 * 64), num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x, _ = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class DepthwiseConv2dAndReLU6(nn.Cell):
    """Net for testing DepthwiseConv2d and ReLU6"""

    def __init__(self, input_channel, kernel_size):
        super(DepthwiseConv2dAndReLU6, self).__init__()
        weight_shape = [1, input_channel, kernel_size, kernel_size]
        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
        self.relu6 = nn.ReLU6()

    def construct(self, x):
        x = self.depthwise_conv(x, self.weight)
        x = self.relu6(x)
        return x


class DeepFMOpNet(nn.Cell):
    """Net definition with Gather, Tile and Square."""

    def __init__(self):
        super(DeepFMOpNet, self).__init__()
        self.gather = P.Gather()
        self.square = P.Square()
        self.tile = P.Tile()

    def construct(self, x, y):
        x = self.tile(x, (1000, 1))
        x = self.square(x)
        x = self.gather(x, y, 0)
        return x


def gen_tensor(shape, dtype=np.float32):
    return Tensor(np.ones(shape).astype(dtype))


net_cfgs = [
    ('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
    ('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
    ('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
    ('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32)))
]
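# Each entry above is (test id, network instance, example input(s)). To cover an additional
# network, a new tuple could be appended here, e.g. (MyNet is a hypothetical Cell, not part
# of this file):
#   net_cfgs.append(('my_net', MyNet(), gen_tensor([1, 3, 32, 32])))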


def get_id(cfgs):
    return [cfg[0] for cfg in cfgs]


# use `pytest test_onnx.py::test_onnx_export[name]` or `pytest test_onnx.py::test_onnx_export -k name` to run a single UT
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
def test_onnx_export(name, net, inp):
    if isinstance(inp, (tuple, list)):
        export(net, *inp, file_name=name, file_format='ONNX')
    else:
        export(net, inp, file_name=name, file_format='ONNX')

    onnx_file = name + ".onnx"
    assert os.path.exists(onnx_file)
    os.chmod(onnx_file, stat.S_IWRITE)
    os.remove(onnx_file)


@run_on_onnxruntime
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
def test_onnx_export_load_run(name, net, inp):
    export(net, inp, file_name=name, file_format='ONNX')
    onnx_file = name + ".onnx"

    import onnx
    import onnxruntime as ort

    print('--------------------- onnx load ---------------------')
    # Load the ONNX model
    model = onnx.load(onnx_file)
    # Check that the IR is well formed
    onnx.checker.check_model(model)
    # Print a human readable representation of the graph
    g = onnx.helper.printable_graph(model.graph)
    print(g)

    print('------------------ onnxruntime run ------------------')
    ort_session = ort.InferenceSession(onnx_file)
    input_map = {'x': inp.asnumpy()}
    # provide only input x to run model
    outputs = ort_session.run(None, input_map)
    print(outputs[0])
    # overwrite default weights to run model
    for item in net.trainable_params():
        default_value = item.data.asnumpy()
        input_map[item.name] = np.ones(default_value.shape, dtype=default_value.dtype)
    outputs = ort_session.run(None, input_map)
    print(outputs[0])

    assert os.path.exists(onnx_file)
    os.chmod(onnx_file, stat.S_IWRITE)
    os.remove(onnx_file)
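

# A possible extension (not part of the original checks): cross-check the onnxruntime result
# against MindSpore's own forward pass inside the test, e.g. for the single-input networks:
#   expected = net(inp).asnumpy()
#   np.testing.assert_allclose(outputs[0], expected, rtol=1e-3, atol=1e-5)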