# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ut for exporting networks in ONNX format (and optionally running them with onnxruntime)"""
import os
import stat

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.train.serialization import export

context.set_context(mode=context.GRAPH_MODE)


def is_enable_onnxruntime():
    """Return True when the ENABLE_ONNXRUNTIME environment variable is switched on.

    Only the exact spellings 'ON', 'on', 'TRUE', 'True' and 'true' enable it; an unset
    variable defaults to "False" and therefore disables the onnxruntime tests.
    """
    val = os.getenv("ENABLE_ONNXRUNTIME", "False")
    return val in ('ON', 'on', 'TRUE', 'True', 'true')


# Marker that skips a test unless onnxruntime tests are enabled via ENABLE_ONNXRUNTIME.
run_on_onnxruntime = pytest.mark.skipif(not is_enable_onnxruntime(), reason="Only support running on onnxruntime")


def teardown_module():
    """Delete temporary `ms_output_*.pb` files left in this test module's directory.

    pytest calls this hook once after every test in the module has finished.
    """
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    for filename in os.listdir(cur_dir):
        if filename.startswith('ms_output_') and '.pb' in filename:
            # os.listdir returns bare names; join them with cur_dir so the cleanup
            # works no matter which working directory pytest was started from.
            path = os.path.join(cur_dir, filename)
            # delete temp files generated by run ut
            os.chmod(path, stat.S_IWRITE)
            os.remove(path)


class BatchNormTester(nn.Cell):
    """used to test exporting network in training mode in onnx format"""

    def __init__(self, num_features):
        super(BatchNormTester, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def construct(self, x):
        return self.bn(x)


def test_batchnorm_train_onnx_export():
    """test onnx export interface does not modify trainable flag of a network"""
    input_ = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
    net = BatchNormTester(3)
    net.set_train()
    assert net.training, 'network is not in training mode'

    onnx_file = 'batch_norm'
    export(net, input_, file_name=onnx_file, file_format='ONNX')
    # The export is expected to produce "<file_name>.onnx".
    file_name = onnx_file + ".onnx"
    try:
        # Exporting must leave the network in training mode.
        assert net.training, 'network is not in training mode'
        assert os.path.exists(file_name)
    finally:
        # Always clean up, even when an assertion above fails.
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)


class LeNet5(nn.Cell):
    """LeNet5 definition"""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class DefinedNet(nn.Cell):
    """simple Net definition with maxpoolwithargmax."""

    def __init__(self, num_classes=10):
        super(DefinedNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=2, strides=2)
        self.flatten = nn.Flatten()
        # NOTE(review): presumably a 224x224 input halved by conv1 (stride 2) and by the
        # pool (stride 2) gives 56x56x64 features -- confirm against the conv padding mode.
        self.fc = nn.Dense(56 * 56 * 64, num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # MaxPoolWithArgmax returns (output, argmax); only the pooled output is used.
        x, _ = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
class DepthwiseConv2dAndReLU6(nn.Cell):
    """Net for testing DepthwiseConv2d and ReLU6"""

    def __init__(self, input_channel, kernel_size):
        super(DepthwiseConv2dAndReLU6, self).__init__()
        from mindspore.common.initializer import initializer
        weight_shape = [1, input_channel, kernel_size, kernel_size]
        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
        self.relu6 = nn.ReLU6()

    def construct(self, x):
        x = self.depthwise_conv(x, self.weight)
        x = self.relu6(x)
        return x


class DeepFMOpNet(nn.Cell):
    """Net definition with Gatherv2 and Tile and Square."""

    def __init__(self):
        super(DeepFMOpNet, self).__init__()
        self.gather = P.Gather()
        self.square = P.Square()
        self.tile = P.Tile()

    def construct(self, x, y):
        x = self.tile(x, (1000, 1))
        x = self.square(x)
        x = self.gather(x, y, 0)
        return x


def gen_tensor(shape, dtype=np.float32):
    """Return a Tensor filled with ones of the given shape and numpy dtype."""
    return Tensor(np.ones(shape).astype(dtype))


def _as_input_tuple(inp):
    """Normalize a net_cfgs input (one Tensor, or a tuple/list of Tensors) to a tuple."""
    if isinstance(inp, (tuple, list)):
        return tuple(inp)
    return (inp,)


def _remove_file(file_name):
    """Delete file_name if it exists, making it writable first."""
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)


# (test id / export file name, network, input) entries; input is a Tensor or a tuple of Tensors.
net_cfgs = [
    ('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
    ('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
    ('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
    ('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32))),
]


def get_id(cfg):
    """Return the pytest ids for a list of net configs (the first field of each entry)."""
    return [entry[0] for entry in cfg]


# use `pytest test_onnx.py::test_onnx_export[name]` or `pytest test_onnx.py::test_onnx_export -k name` to run single ut
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
def test_onnx_export(name, net, inp):
    """Export each configured net to ONNX and check that `<name>.onnx` is produced."""
    export(net, *_as_input_tuple(inp), file_name=name, file_format='ONNX')

    onnx_file = name + ".onnx"
    try:
        assert os.path.exists(onnx_file)
    finally:
        _remove_file(onnx_file)


@run_on_onnxruntime
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
def test_onnx_export_load_run(name, net, inp):
    """Export each net to ONNX, validate it with onnx and run it with onnxruntime.

    The model is run twice: first feeding only the data inputs, then additionally
    overriding every trainable parameter with an all-ones array.
    """
    import inspect

    import onnx
    import onnxruntime as ort

    inputs = _as_input_tuple(inp)
    export(net, *inputs, file_name=name, file_format='ONNX')
    onnx_file = name + ".onnx"
    try:
        assert os.path.exists(onnx_file)

        print('--------------------- onnx load ---------------------')
        # Load the ONNX model
        model = onnx.load(onnx_file)
        # Check that the IR is well formed
        onnx.checker.check_model(model)
        # Print a human readable representation of the graph
        print(onnx.helper.printable_graph(model.graph))

        print('------------------ onnxruntime run ------------------')
        ort_session = ort.InferenceSession(onnx_file)
        # NOTE(review): assumes the exporter names the graph's data inputs after the
        # parameters of `construct` (single-input nets were previously fed as 'x').
        input_names = list(inspect.signature(net.construct).parameters)
        input_map = {key: tensor.asnumpy() for key, tensor in zip(input_names, inputs)}
        # provide only the data inputs to run model
        outputs = ort_session.run(None, input_map)
        print(outputs[0])
        # overwrite default weight to run model
        for item in net.trainable_params():
            default_value = item.data.asnumpy()
            input_map[item.name] = np.ones(default_value.shape, dtype=default_value.dtype)
        outputs = ort_session.run(None, input_map)
        print(outputs[0])
    finally:
        _remove_file(onnx_file)