• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for model serialization: checkpoint save/load, graph saving and model export."""
16import os
17import platform
18import stat
19import time
20import secrets
21
22import numpy as np
23import pytest
24
25import mindspore.common.dtype as mstype
26import mindspore.nn as nn
27from mindspore import context
28from mindspore.common.parameter import Parameter
29from mindspore.common.tensor import Tensor
30from mindspore.nn import SoftmaxCrossEntropyWithLogits
31from mindspore.nn import WithLossCell, TrainOneStepCell
32from mindspore.nn.optim.momentum import Momentum
33from mindspore.ops import operations as P
34from mindspore.train.callback import _CheckpointManager
35from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
36     export, _save_graph, load
37from tests.security_utils import security_off_wrap
38from ..ut_filter import non_graph_engine
39
# Every test in this module runs in graph mode; output of the Print operator is
# written to print/print.pb (the ``print`` directory is removed in teardown_module).
context.set_context(
    mode=context.GRAPH_MODE,
    print_file_path="print/print.pb",
)
41
42
class Net(nn.Cell):
    """Conv -> BatchNorm -> ReLU -> MaxPool -> Flatten -> Dense classifier used as a checkpoint fixture.

    Args:
        num_classes: number of output features of the final Dense layer.
    """

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        # conv1 is zero-initialised; the load_param_into_net tests assert its weight starts at 0.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        # 224 * 224 * 64 / 16 == 200704 input features.
        # NOTE(review): this module only uses Net's parameters (its forward pass is never run),
        # so this size has not been checked against the real flattened activation shape.
        self.fc = nn.Dense(224 * 224 * 64 // 16, num_classes)

    def construct(self, x):
        return self.fc(self.flatten(self.maxpool(self.relu(self.bn1(self.conv1(x))))))
63
64
# Random float32 input of shape (1, 3, 224, 224) with values in [0, 255).
# NOTE(review): not referenced by any test in this module.
_input_x = Tensor(np.random.randint(0, 255, size=(1, 3, 224, 224)).astype(np.float32))
# Directory containing this test file; several checkpoint files are written relative to it.
_cur_dir = os.path.split(os.path.realpath(__file__))[0]
67
68
def setup_module():
    """pytest module setup: delete a stale ``./test_files`` directory if one exists."""
    import shutil
    stale_dir = './test_files'
    if not os.path.exists(stale_dir):
        return
    shutil.rmtree(stale_dir)
73
74
def test_save_graph():
    """Run a tiny Add network, dump its graph with ``_save_graph`` and check that the file appears."""

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            z = self.add(x, y)
            return z

    net = Net1()
    net.set_train()
    x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
    y = Tensor(np.array([1.2]).astype(np.float32))
    # Execute the network once before saving (presumably so its graph has been compiled).
    net(x, y)
    output_file = "net-graph.meta"
    try:
        _save_graph(network=net, file_name=output_file)
        assert os.path.exists(output_file)
    finally:
        # Remove the dumped graph even if the assertion above fails.
        if os.path.exists(output_file):
            os.chmod(output_file, stat.S_IWRITE)
            os.remove(output_file)
99
100
def test_save_checkpoint_for_list():
    """Save a checkpoint from a plain list of ``{'name': ..., 'data': Tensor}`` dicts."""
    parameter_list = [
        {'name': "param_test",
         'data': Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)},
        {'name': "param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)},
        {'name': "new_param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)},
    ]

    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
    # Clear a stale checkpoint at the path that is actually written below (not the CWD);
    # chmod first in case an earlier run left the file read-only.
    if os.path.exists(ckpt_file_name):
        os.chmod(ckpt_file_name, stat.S_IWRITE)
        os.remove(ckpt_file_name)

    save_checkpoint(parameter_list, ckpt_file_name)
123
124
def test_save_checkpoint_for_list_append_info():
    """Save a checkpoint from a parameter list together with extra ``append_dict`` entries."""
    parameter_list = [
        {'name': "param_test",
         'data': Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)},
        {'name': "param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)},
        {'name': "new_param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)},
    ]
    append_dict = {"lr": 0.01, "epoch": 20, "train": True}

    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
    # Clear a stale checkpoint at the path that is actually written below (not the CWD);
    # chmod first in case an earlier run left the file read-only.
    if os.path.exists(ckpt_file_name):
        os.chmod(ckpt_file_name, stat.S_IWRITE)
        os.remove(ckpt_file_name)

    save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict)
147
148
def test_load_checkpoint_error_filename():
    """load_checkpoint must reject a non-string file name (here the int 1) with ValueError."""
    bad_name = 1
    with pytest.raises(ValueError):
        load_checkpoint(bad_name)
153
154
def test_load_checkpoint():
    """Load ``parameters.ckpt`` from this test's directory and check its contents.

    NOTE(review): depends on the checkpoint having been written by an earlier test in this
    module; the expected 6 entries presumably are the 3 saved parameters plus the 3
    ``append_dict`` items from test_save_checkpoint_for_list_append_info.
    """
    par_dict = load_checkpoint(os.path.join(_cur_dir, './parameters.ckpt'))

    assert isinstance(par_dict, dict)
    assert len(par_dict) == 6
    param = par_dict['param_test']
    assert param.name == 'param_test'
    assert param.data.dtype == mstype.float32
    assert param.data.shape == (1, 3, 224, 224)
164
165
def test_checkpoint_manager():
    """Exercise _CheckpointManager file tracking, removal by name, oldest-file removal and pruning."""

    def _touch(path):
        # Create an empty file at ``path`` that only the owner can read and write.
        with open(path, 'w'):
            os.chmod(path, stat.S_IWUSR | stat.S_IRUSR)

    ckp_mgr = _CheckpointManager()

    # A single tracked file can be removed by name.
    ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
    _touch(ckpt_file_name)

    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 1

    ckp_mgr.remove_ckpoint_file(ckpt_file_name)
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 0
    assert not os.path.exists(ckpt_file_name)

    # remove_oldest_ckpoint_file removes the only tracked file.
    another_file_name = os.path.realpath(os.path.join(_cur_dir, './test-2_1.ckpt'))
    _touch(another_file_name)

    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 1
    ckp_mgr.remove_oldest_ckpoint_file()
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
    assert ckp_mgr.ckpoint_num == 0
    assert not os.path.exists(another_file_name)

    # test keep_one_ckpoint_per_minutes
    time_files = [os.path.realpath(os.path.join(_cur_dir, name))
                  for name in ('./time_file-1_1.ckpt', './time_file-2_1.ckpt', './time_file-3_1.ckpt')]
    for path in time_files:
        _touch(path)
    time1 = time.time()
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
    assert ckp_mgr.ckpoint_num == 3
    ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
    assert ckp_mgr.ckpoint_num == 1

    # Exactly one of the three files survives pruning; remove whichever one it is.
    for path in time_files:
        if os.path.exists(path):
            os.chmod(path, stat.S_IWRITE)
            os.remove(path)
213
214
def test_load_param_into_net_error_net():
    """load_param_into_net must raise TypeError when ``net`` is an empty string instead of a network."""
    weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                       name="conv1.weight")
    params = {"conv1.weight": weight}
    with pytest.raises(TypeError):
        load_param_into_net('', params)
222
223
def test_load_param_into_net_error_dict():
    """load_param_into_net must raise TypeError when the parameter dict is an empty string."""
    network = Net(10)
    with pytest.raises(TypeError):
        load_param_into_net(network, '')
228
229
def test_load_param_into_net_erro_dict_param():
    """A dict value that is not a Parameter (here an empty string) must raise TypeError."""
    network = Net(10)
    network.init_parameters_data()
    # conv1 is built with weight_init="zeros".
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 0

    bad_dict = {"conv1.weight": ''}
    with pytest.raises(TypeError):
        load_param_into_net(network, bad_dict)
240
241
def test_load_param_into_net_has_more_param():
    """An extra dict entry that matches no network parameter does not stop the matching ones loading."""
    network = Net(10)
    network.init_parameters_data()
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 0

    matching = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                         name="conv1.weight")
    # NOTE(review): this extra Parameter is also named "conv1.weight" although its dict key is
    # "conv1.w"; confirm whether the shared name is intentional.
    extra = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                      name="conv1.weight")
    parameter_dict = {"conv1.weight": matching, "conv1.w": extra}

    load_param_into_net(network, parameter_dict)
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 1
257
258
def test_load_param_into_net_param_type_and_shape_error():
    """A parameter whose dtype AND shape both differ from the network's must raise RuntimeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    # The network accepts conv1.weight as float32 with shape (64, 3, 7, 7) (see test_load_param_into_net);
    # mismatch both: int32 dtype and shape (64, 3, 7).
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7)), dtype=mstype.int32), name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(RuntimeError):
        load_param_into_net(net, parameter_dict)
269
270
def test_load_param_into_net_param_type_error():
    """A correctly shaped int32 parameter for the float32 conv1.weight must raise RuntimeError."""
    network = Net(10)
    network.init_parameters_data()
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 0

    int_weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
                           name="conv1.weight")
    with pytest.raises(RuntimeError):
        load_param_into_net(network, {"conv1.weight": int_weight})
282
283
def test_load_param_into_net_param_shape_error():
    """A parameter with the right dtype but the wrong shape must raise RuntimeError."""
    net = Net(10)
    net.init_parameters_data()
    assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0

    parameter_dict = {}
    # Keep the dtype (float32) that the network accepts so that only the shape is wrong:
    # (64, 3, 7) instead of (64, 3, 7, 7).
    one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.float32),
                          name="conv1.weight")
    parameter_dict["conv1.weight"] = one_param
    with pytest.raises(RuntimeError):
        load_param_into_net(net, parameter_dict)
295
296
def test_load_param_into_net():
    """Loading a matching float32 conv1.weight of ones overwrites the zero-initialised weight."""
    network = Net(10)
    network.init_parameters_data()
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 0

    new_weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                           name="conv1.weight")
    load_param_into_net(network, {"conv1.weight": new_weight})
    assert network.conv1.weight.data.asnumpy()[0, 0, 0, 0] == 1
308
309
def test_save_checkpoint_for_network():
    """Save a checkpoint of a full training network and load it back."""
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    # One spelling of the path for both save and load (removed in teardown_module).
    ckpt_file_name = "./new_ckpt.ckpt"
    save_checkpoint(train_network, ckpt_file_name=ckpt_file_name)

    param_dict = load_checkpoint(ckpt_file_name)
    # The round trip must produce a non-empty dict of parameters.
    assert isinstance(param_dict, dict)
    assert param_dict
321
322
def test_load_checkpoint_empty_file():
    """Loading a zero-byte checkpoint file must raise ValueError."""
    # Create (or truncate) the empty file with open() rather than os.mknod: mknod is not available
    # on Windows, may be restricted on other non-Linux systems, and fails if a previous run left
    # empty.ckpt behind. The file is removed in teardown_module.
    with open("empty.ckpt", "w"):
        pass
    with pytest.raises(ValueError):
        load_checkpoint("empty.ckpt")
327
328
def test_save_and_load_checkpoint_for_network_with_encryption():
    """Save and reload an AES-GCM encrypted checkpoint of a training network.

    On Windows the encrypted save/load path is expected to raise NotImplementedError.
    """
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    key = secrets.token_bytes(16)
    mode = "AES-GCM"
    ckpt_path = "./encrypt_ckpt.ckpt"
    try:
        if platform.system().lower() == "windows":
            with pytest.raises(NotImplementedError):
                save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
                param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode=mode)
                load_param_into_net(net, param_dict)
        else:
            save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
            param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode=mode)
            load_param_into_net(net, param_dict)
    finally:
        # Always remove the encrypted checkpoint, even when loading fails; restore write
        # permission first in case the file was left read-only.
        if os.path.exists(ckpt_path):
            os.chmod(ckpt_path, stat.S_IWRITE)
            os.remove(ckpt_path)
351
352
class MYNET(nn.Cell):
    """Conv -> BatchNorm -> ReLU -> Flatten -> Dense (3 outputs) network used by the export tests."""

    def __init__(self):
        super(MYNET, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # A 'valid' (unpadded) 3x3 conv on the 224x224 test input leaves 222x222 positions.
        self.fc = nn.Dense(64 * 222 * 222, 3)

    def construct(self, x):
        features = self.flatten(self.relu(self.bn(self.conv(x))))
        return self.fc(features)
371
372
@non_graph_engine
def test_export():
    """Exporting MYNET in AIR format is expected to raise ValueError in this test environment."""
    network = MYNET()
    dummy_input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    with pytest.raises(ValueError):
        export(network, dummy_input, file_name="./me_export.pb", file_format="AIR")
379
380
@non_graph_engine
def test_mindir_export():
    """Export MYNET in MINDIR format without encryption."""
    network = MYNET()
    dummy_input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    export(network, dummy_input,
           file_name="./me_binary_export",
           file_format="MINDIR")
386
387
@non_graph_engine
def test_mindir_export_and_load_with_encryption():
    """Export MYNET to an encrypted MINDIR file and load it back with the same random key."""
    network = MYNET()
    dummy_input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    secret_key = secrets.token_bytes(16)
    cipher_file = "./me_cipher_binary_export.mindir"
    export(network, dummy_input, file_name=cipher_file, file_format="MINDIR", enc_key=secret_key)
    load(cipher_file, dec_key=secret_key)
395
396
397
class PrintNet(nn.Cell):
    """Network that prints each of its inputs with a labelled Print op and returns them unchanged."""

    def __init__(self):
        super(PrintNet, self).__init__()
        self.print = P.Print()

    def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
                  scale1, scale2):
        # Exactly one labelled print per input, in argument order.
        self.print('============tensor int8:==============', int8)
        self.print('============tensor uint8:==============', uint8)
        self.print('============tensor int16:==============', int16)
        self.print('============tensor uint16:==============', uint16)
        self.print('============tensor int32:==============', int32)
        self.print('============tensor uint32:==============', uint32)
        self.print('============tensor int64:==============', int64)
        self.print('============tensor uint64:==============', uint64)
        self.print('============tensor float16:==============', flt16)
        self.print('============tensor float32:==============', flt32)
        self.print('============tensor float64:==============', flt64)
        self.print('============tensor bool:==============', bool_)
        self.print('============tensor scale1:==============', scale1)
        self.print('============tensor scale2:==============', scale2)
        return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2
421
422
@security_off_wrap
def test_print():
    """Feed integer, float, bool and scalar tensors through PrintNet."""
    print_net = PrintNet()
    int_dtypes = ("int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64")
    int_tensors = [Tensor(np.random.randint(100, size=(10, 10), dtype=dt)) for dt in int_dtypes]
    float_tensors = [Tensor(np.random.rand(224, 224).astype(ft))
                     for ft in (np.float16, np.float32, np.float64)]
    bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_))
    scale1 = Tensor(np.array(1))
    scale2 = Tensor(np.array(0.1))
    print_net(*int_tensors, *float_tensors, bool_, scale1, scale2)
442
443
def teardown_module():
    """pytest module teardown: remove files and directories the tests in this module leave behind."""
    import shutil
    leftovers = [
        # The list-based save tests write parameters.ckpt next to this module, not into the CWD.
        os.path.join(_cur_dir, 'parameters.ckpt'),
        './parameters.ckpt',
        './new_ckpt.ckpt',
        './empty.ckpt',
        './encrypt_ckpt.ckpt',
        # MINDIR exports. NOTE(review): export() presumably appends ".mindir" when the name lacks
        # it, so both spellings of the unencrypted export are checked.
        './me_binary_export',
        './me_binary_export.mindir',
        './me_cipher_binary_export.mindir',
    ]
    for file_name in leftovers:
        if not os.path.isfile(file_name):
            continue
        # Restore write permission first in case the file was left read-only.
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)
    if os.path.exists('./print'):
        shutil.rmtree('./print')
455