# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ut for model serialize(save/load)"""
import os
import platform
import secrets
import shutil
import stat
import time

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
    export, _save_graph, load
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine

context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")


class Net(nn.Cell):
    """Net definition.

    conv1 is zero-initialised so the load tests can detect a successful
    parameter load (its first weight goes from 0 to 1).
    """

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        # Integer arithmetic; equals the original int(224 * 224 * 64 / 16).
        self.fc = nn.Dense(224 * 224 * 64 // 16, num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


_input_x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
_cur_dir = os.path.dirname(os.path.realpath(__file__))
# Checkpoint written by the list-based save tests and read by test_load_checkpoint.
_PARAMETERS_CKPT = os.path.join(_cur_dir, 'parameters.ckpt')


def _remove_file(file_name):
    """Remove ``file_name`` if it exists, making it writable first.

    Follows the chmod-then-remove pattern used for every artifact in this
    module, so read-only outputs can still be deleted.
    """
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)


def _create_empty_file(file_name):
    """Create (or truncate) ``file_name`` as an empty owner read/write file."""
    with open(file_name, 'w'):
        os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


def _build_parameter_list():
    """Return the three named float32 tensors saved by the list-based checkpoint tests."""
    named_shapes = (("param_test", [1, 3, 224, 224]),
                    ("param", [12, 1024]),
                    ("new_param", [12, 1024, 1]))
    return [{'name': name, 'data': Tensor(np.random.randint(0, 255, shape), dtype=mstype.float32)}
            for name, shape in named_shapes]


def setup_module():
    """Remove leftovers from a previous run before any test executes."""
    if os.path.exists('./test_files'):
        shutil.rmtree('./test_files')


def test_save_graph():
    """ test_exec_save_graph """

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            z = self.add(x, y)
            return z

    net = Net1()
    net.set_train()
    x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
    y = Tensor(np.array([1.2]).astype(np.float32))
    # Execute the network once (graph mode) before saving its graph.
    net(x, y)
    output_file = "net-graph.meta"
    _save_graph(network=net, file_name=output_file)
    assert os.path.exists(output_file)
    _remove_file(output_file)


def test_save_checkpoint_for_list():
    """ test save_checkpoint for list"""
    parameter_list = _build_parameter_list()
    # Remove any stale copy at exactly the path written below.
    _remove_file(_PARAMETERS_CKPT)
    save_checkpoint(parameter_list, _PARAMETERS_CKPT)


def test_save_checkpoint_for_list_append_info():
    """ test save_checkpoint for list append info"""
    parameter_list = _build_parameter_list()
    append_dict = {"lr": 0.01, "epoch": 20, "train": True}
    _remove_file(_PARAMETERS_CKPT)
    save_checkpoint(parameter_list, _PARAMETERS_CKPT, append_dict=append_dict)


def test_load_checkpoint_error_filename():
    """A non-string checkpoint file name is rejected with ValueError."""
    ckpt_file_name = 1
    with pytest.raises(ValueError):
        load_checkpoint(ckpt_file_name)


def test_load_checkpoint():
    """Load the checkpoint written by test_save_checkpoint_for_list_append_info.

    Relies on pytest's default in-file ordering: that test stores three
    tensors plus a three-entry append_dict, hence six loaded entries.
    """
    par_dict = load_checkpoint(_PARAMETERS_CKPT)

    assert isinstance(par_dict, dict)
    assert len(par_dict) == 6
    assert par_dict['param_test'].name == 'param_test'
    assert par_dict['param_test'].data.dtype == mstype.float32
    assert par_dict['param_test'].data.shape == (1, 3, 224, 224)


def test_checkpoint_manager():
    """ test_checkpoint_manager """
    ckp_mgr = _CheckpointManager()

    ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
    _create_empty_file(ckpt_file_name)

    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 1

    ckp_mgr.remove_ckpoint_file(ckpt_file_name)
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 0
    assert not os.path.exists(ckpt_file_name)

    another_file_name = os.path.join(_cur_dir, './test-2_1.ckpt')
    another_file_name = os.path.realpath(another_file_name)
    _create_empty_file(another_file_name)

    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 1
    ckp_mgr.remove_oldest_ckpoint_file()
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 0
    assert not os.path.exists(another_file_name)

    # test keep_one_ckpoint_per_minutes
    file1 = os.path.realpath(os.path.join(_cur_dir, './time_file-1_1.ckpt'))
    file2 = os.path.realpath(os.path.join(_cur_dir, './time_file-2_1.ckpt'))
    file3 = os.path.realpath(os.path.join(_cur_dir, './time_file-3_1.ckpt'))
    time_files = (file1, file2, file3)
    for file_name in time_files:
        _create_empty_file(file_name)
    time1 = time.time()
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
    assert ckp_mgr.ckpoint_num == 3
    ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
    assert ckp_mgr.ckpoint_num == 1
    # Exactly one file survives, and which one is up to the manager, so clean
    # every candidate rather than assuming it is file1.
    for file_name in time_files:
        _remove_file(file_name)


def test_load_param_into_net_error_net():
    """A non-Cell network argument raises TypeError."""
    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(TypeError):
        load_param_into_net('', parameter_dict)


def test_load_param_into_net_error_dict():
    """A non-dict parameter argument raises TypeError."""
    net = Net(10)
    with pytest.raises(TypeError):
        load_param_into_net(net, '')


def test_load_param_into_net_erro_dict_param():
    """A dict value that is not a Parameter raises TypeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = ''
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(TypeError):
        load_param_into_net(net, parameter_dict)


def test_load_param_into_net_has_more_param():
    """ test_load_param_into_net_has_more_param """
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    two_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                          name="conv1.weight")
    # Extra key with no matching network parameter.
    parameter_dict["conv1.w"] = two_param
    load_param_into_net(net, parameter_dict)
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1


def test_load_param_into_net_param_type_and_shape_error():
    """A parameter whose dtype differs from the network's raises RuntimeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32), name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(RuntimeError):
        load_param_into_net(net, parameter_dict)


def test_load_param_into_net_param_type_error():
    """An int32 parameter for a float32 weight raises RuntimeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(RuntimeError):
        load_param_into_net(net, parameter_dict)


def test_load_param_into_net_param_shape_error():
    """A parameter with the wrong rank (and dtype) raises RuntimeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.int32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(RuntimeError):
        load_param_into_net(net, parameter_dict)


def test_load_param_into_net():
    """A matching parameter is copied into the network."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    load_param_into_net(net, parameter_dict)
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1


def test_save_checkpoint_for_network():
    """ test save_checkpoint for network"""
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")

    load_checkpoint("new_ckpt.ckpt")


def test_load_checkpoint_empty_file():
    """Loading a zero-byte checkpoint raises ValueError."""
    # os.mknod is Unix-only and raises FileExistsError when a previous run left
    # the file behind; create/truncate it portably instead.
    _create_empty_file("empty.ckpt")
    with pytest.raises(ValueError):
        load_checkpoint("empty.ckpt")


def test_save_and_load_checkpoint_for_network_with_encryption():
    """ test save and checkpoint for network with encryption"""
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    key = secrets.token_bytes(16)
    mode = "AES-GCM"
    ckpt_path = "./encrypt_ckpt.ckpt"
    if platform.system().lower() == "windows":
        # Encryption is expected to be unsupported here; nothing after the
        # raising call would run, so only the save is exercised.
        with pytest.raises(NotImplementedError):
            save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
    else:
        save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
        param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
        load_param_into_net(net, param_dict)
    _remove_file(ckpt_path)


class MYNET(nn.Cell):
    """ NET definition """

    def __init__(self):
        super(MYNET, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(64 * 222 * 222, 3)  # padding=0

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.flatten(x)
        out = self.fc(x)
        return out


@non_graph_engine
def test_export():
    """AIR export is expected to raise ValueError in this environment."""
    net = MYNET()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    with pytest.raises(ValueError):
        export(net, input_data, file_name="./me_export.pb", file_format="AIR")


@non_graph_engine
def test_mindir_export():
    """Plain MINDIR export completes without error."""
    net = MYNET()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")


@non_graph_engine
def test_mindir_export_and_load_with_encryption():
    """An encrypted MINDIR export can be loaded back with the same key."""
    net = MYNET()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    key = secrets.token_bytes(16)
    export(net, input_data, file_name="./me_cipher_binary_export.mindir", file_format="MINDIR", enc_key=key)
    load("./me_cipher_binary_export.mindir", dec_key=key)


class PrintNet(nn.Cell):
    """Prints each input tensor once with a dtype label and returns them unchanged."""

    def __init__(self):
        super(PrintNet, self).__init__()
        self.print = P.Print()

    def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
                  scale1, scale2):
        self.print('============tensor int8:==============', int8)
        self.print('============tensor uint8:==============', uint8)
        self.print('============tensor int16:==============', int16)
        self.print('============tensor uint16:==============', uint16)
        self.print('============tensor int32:==============', int32)
        self.print('============tensor uint32:==============', uint32)
        self.print('============tensor int64:==============', int64)
        self.print('============tensor uint64:==============', uint64)
        self.print('============tensor float16:==============', flt16)
        self.print('============tensor float32:==============', flt32)
        self.print('============tensor float64:==============', flt64)
        self.print('============tensor bool:==============', bool_)
        self.print('============tensor scale1:==============', scale1)
        self.print('============tensor scale2:==============', scale2)
        return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2


@security_off_wrap
def test_print():
    """Print tensors of every supported dtype (output goes to print/print.pb)."""
    print_net = PrintNet()
    int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8"))
    uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8"))
    int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16"))
    uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16"))
    int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32"))
    uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32"))
    int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64"))
    uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64"))
    float16 = Tensor(np.random.rand(224, 224).astype(np.float16))
    float32 = Tensor(np.random.rand(224, 224).astype(np.float32))
    float64 = Tensor(np.random.rand(224, 224).astype(np.float64))
    bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_))
    scale1 = Tensor(np.array(1))
    scale2 = Tensor(np.array(0.1))
    print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1,
              scale2)


def teardown_module():
    """Delete every artifact the tests in this module may leave behind."""
    # Files written relative to the current working directory.
    files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'encrypt_ckpt.ckpt',
             'me_binary_export.mindir', 'me_cipher_binary_export.mindir']
    for item in files:
        _remove_file('./' + item)
    # The list-based save tests write next to this file, not into the cwd.
    _remove_file(_PARAMETERS_CKPT)
    if os.path.exists('./print'):
        shutil.rmtree('./print')