• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test callback function."""
16import os
17import platform
18import stat
19import secrets
20from unittest import mock
21
22import numpy as np
23import pytest
24
25import mindspore.common.dtype as mstype
26import mindspore.nn as nn
27from mindspore.common.api import ms_function
28from mindspore.common.tensor import Tensor
29from mindspore.nn import TrainOneStepCell, WithLossCell
30from mindspore.nn.optim import Momentum
31from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
32    _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
33from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist
34
35
class Net(nn.Cell):
    """Conv -> BatchNorm -> ReLU -> Flatten -> Dense network used as a test fixture.

    NOTE(review): unlike LossNet, Conv2d is built here without an explicit
    pad_mode, while the Dense layer still expects 64 * 222 * 222 inputs.
    Confirm the two agree before running this net forward.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the order they are applied in construct().
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(64 * 222 * 222, 3)

    @ms_function
    def construct(self, x):
        """Apply every layer to x in sequence and return the Dense output."""
        conv_out = self.conv(x)
        normed = self.bn(conv_out)
        activated = self.relu(normed)
        flat = self.flatten(activated)
        return self.fc(flat)
55
56
class LossNet(nn.Cell):
    """Same layer stack as Net, with a softmax cross-entropy loss appended.

    construct() takes the input and the label and returns the loss value.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # No padding: a 3x3 kernel shrinks the 224x224 inputs used in this file to 222x222.
        self.fc = nn.Dense(64 * 222 * 222, 3)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()

    @ms_function
    def construct(self, x, y):
        """Compute the logits for x and return their loss against y."""
        conv_out = self.conv(x)
        normed = self.bn(conv_out)
        activated = self.relu(normed)
        flat = self.flatten(activated)
        logits = self.fc(flat)
        return self.loss(logits, y)
78
79
def test_model_checkpoint_prefix_invalid():
    """Check which ModelCheckpoint constructor arguments are rejected and which are accepted."""
    bad_prefix = 123
    bad_config = 'type_error'

    # An integer as the first positional argument is expected to raise ValueError.
    with pytest.raises(ValueError):
        ModelCheckpoint(bad_prefix)
    ModelCheckpoint(directory="./")

    # A string passed as config is expected to raise TypeError.
    with pytest.raises(TypeError):
        ModelCheckpoint(config=bad_config)
    ModelCheckpoint(config=CheckpointConfig())

    ModelCheckpoint(prefix="ckpt_2", directory="./test_files")
89
90
def test_save_checkpoint():
    """Drive ModelCheckpoint through begin() and step_end() on a small training network.

    NOTE(review): the cleanup only targets 'test_ckpt-model.pkl'; whether the
    callback ever writes that file cannot be told from this test - confirm.
    """
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    params = _InternalCallbackParam()

    # Net wrapped with a loss and a Momentum optimizer into a one-step training cell.
    backbone = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(backbone.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(backbone, loss_fn), optimizer)

    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 0
    params.batch_num = 32

    checkpointer = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=ckpt_config)
    context = RunContext(params)
    checkpointer.begin(context)
    checkpointer.step_end(context)

    # Remove the leftover file if present; make it writable first so the removal can succeed.
    leftover = './test_files/test_ckpt-model.pkl'
    if os.path.exists(leftover):
        os.chmod(leftover, stat.S_IWRITE)
        os.remove(leftover)
116
117
def test_loss_monitor_sink_mode():
    """Run a LossMonitor through every callback hook via _CallbackManager (sink-mode scenario)."""
    params = _InternalCallbackParam()
    params.cur_epoch_num = 4
    params.epoch_num = 4
    params.cur_step_num = 2
    params.batch_num = 2
    params.net_outputs = Tensor(2.0)
    context = RunContext(params)

    monitors = [LossMonitor(1)]
    with _CallbackManager(monitors) as manager:
        # Hooks are invoked in the order a training run would reach them.
        for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
            getattr(manager, hook)(context)
136
137
def test_loss_monitor_normal_mode():
    """Call every LossMonitor hook directly, without a manager (non-sink scenario)."""
    params = _InternalCallbackParam()
    # The context is created before the fields below are filled in, as in the original flow.
    context = RunContext(params)
    monitor = LossMonitor(1)

    params.cur_epoch_num = 4
    params.epoch_num = 4
    params.cur_step_num = 1
    params.batch_num = 1
    params.net_outputs = Tensor(2.0)

    for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
        getattr(monitor, hook)(context)
154
155
def test_chg_ckpt_file_name_if_same_exist():
    """Smoke test: call _chg_ckpt_file_name_if_same_exist for prefix 'ckpt' in ./test_files."""
    _chg_ckpt_file_name_if_same_exist(prefix="ckpt", directory="./test_files")
159
160
def test_checkpoint_cb_for_save_op():
    """Smoke test: pass a single named parameter to _checkpoint_cb_for_save_op."""
    random_weight = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
    _checkpoint_cb_for_save_op([{'name': "conv1.weight", 'data': random_weight}])
169
170
def test_checkpoint_cb_for_save_op_update_net():
    """Check that _checkpoint_cb_for_save_op writes the given data into the current net.

    An all-ones tensor is passed under the name "conv.weight"; afterwards the
    first element of the registered net's conv weight must equal 1.
    """
    ones_weight = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
    params_to_save = [{'name': "conv.weight", 'data': ones_weight}]

    target_net = Net()
    _set_cur_net(target_net)
    _checkpoint_cb_for_save_op(params_to_save)

    updated_weight = target_net.conv.weight.data.asnumpy()
    assert updated_weight[0][0][0][0] == 1
182
183
def test_internal_callback_param():
    """Check that attributes set on _InternalCallbackParam can be read back unchanged."""
    params = _InternalCallbackParam()
    params.member1 = 1
    params.member2 = "abc"

    assert params.member1 == 1
    assert params.member2 == "abc"
191
192
def test_checkpoint_save_ckpt_steps():
    """Run two step-configured ModelCheckpoint callbacks at different step counts.

    The config uses save_checkpoint_steps=16. The first callback runs at step
    160 (a multiple of 16), the second at step 15 (not a multiple).
    """
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    first_cb = ModelCheckpoint(config=ckpt_config)
    params = _InternalCallbackParam()

    # Net wrapped with a loss and a Momentum optimizer into a one-step training cell.
    backbone = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(backbone.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(backbone, loss_fn), optimizer)

    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 160
    params.batch_num = 32
    context = RunContext(params)
    first_cb.begin(context)
    first_cb.step_end(context)

    # Second callback shares the config and context but sees an earlier epoch/step.
    second_cb = ModelCheckpoint(config=ckpt_config)
    params.cur_epoch_num = 1
    params.cur_step_num = 15
    second_cb.begin(context)
    second_cb.step_end(context)
220
221
def test_checkpoint_save_ckpt_seconds():
    """Run two ModelCheckpoint callbacks with a time-based checkpoint config.

    The config sets save_checkpoint_seconds=100, keep_checkpoint_max=0 and
    keep_checkpoint_per_n_minutes=1. The callbacks run at steps 128 and 16.
    """
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=100,
        keep_checkpoint_max=0,
        keep_checkpoint_per_n_minutes=1)
    first_cb = ModelCheckpoint(config=ckpt_config)
    params = _InternalCallbackParam()

    # Net wrapped with a loss and a Momentum optimizer into a one-step training cell.
    backbone = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(backbone.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(backbone, loss_fn), optimizer)

    params.epoch_num = 10
    params.cur_epoch_num = 4
    params.cur_step_num = 128
    params.batch_num = 32
    context = RunContext(params)
    first_cb.begin(context)
    first_cb.step_end(context)

    # Second callback shares the config and context but sees an earlier epoch/step.
    second_cb = ModelCheckpoint(config=ckpt_config)
    params.cur_epoch_num = 1
    params.cur_step_num = 16
    second_cb.begin(context)
    second_cb.step_end(context)
249
250
def test_checkpoint_save_ckpt_with_encryption():
    """Run ModelCheckpoint callbacks configured with an AES-GCM encryption key.

    A random 16-byte key is generated for each run. On Windows the second
    callback is expected to raise NotImplementedError; elsewhere it must
    complete normally.
    """
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0,
        enc_key=secrets.token_bytes(16),
        enc_mode="AES-GCM")
    first_cb = ModelCheckpoint(config=ckpt_config)
    params = _InternalCallbackParam()

    # Net wrapped with a loss and a Momentum optimizer into a one-step training cell.
    backbone = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(backbone.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(backbone, loss_fn), optimizer)

    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 160
    params.batch_num = 32
    context = RunContext(params)
    first_cb.begin(context)
    first_cb.step_end(context)

    second_cb = ModelCheckpoint(config=ckpt_config)
    params.cur_epoch_num = 1
    params.cur_step_num = 15

    on_windows = platform.system().lower() == "windows"
    if not on_windows:
        second_cb.begin(context)
        second_cb.step_end(context)
    else:
        # NOTE(review): both calls sit inside one raises-block, so this does not
        # pin down which of them raises - confirm the intended call.
        with pytest.raises(NotImplementedError):
            second_cb.begin(context)
            second_cb.step_end(context)
286
287
def test_CallbackManager():
    """Check that _CallbackManager raises TypeError for lists containing non-callback items."""
    checkpoint_cb = ModelCheckpoint()
    loss_monitor = LossMonitor(1)

    invalid_lists = (
        [None],
        ['Error'],
        # Valid callbacks mixed with invalid entries must still be rejected.
        [checkpoint_cb, loss_monitor, 'Error', None],
    )
    for callbacks in invalid_lists:
        with pytest.raises(TypeError):
            _CallbackManager(callbacks)
304
305
def test_CallbackManager_exit_called():
    """Check that leaving _CallbackManager normally calls __exit__ once per callback.

    Each call must receive (None, None, None) as its exception arguments.
    """
    with mock.patch.object(Callback, '__exit__', return_value=None) as exit_mock:
        first, second = Callback(), Callback()
        with _CallbackManager([first, second]):
            pass

    # mock.ANY stands in for the first positional argument of each call.
    expected_call = mock.call(mock.ANY, None, None, None)
    for recorded_call in exit_mock.call_args_list:
        assert recorded_call == expected_call
    assert exit_mock.call_count == 2
314
315
def test_CallbackManager_exit_called_when_raises():
    """Check that __exit__ is still called once per callback when the managed body raises.

    The ValueError raised inside the manager must propagate to the caller.
    """
    with mock.patch.object(Callback, '__exit__', return_value=None) as exit_mock:
        first, second = Callback(), Callback()
        with pytest.raises(ValueError):
            with _CallbackManager([first, second]):
                raise ValueError()

    # Four positional arguments of any value are accepted for each call.
    expected_call = mock.call(mock.ANY, mock.ANY, mock.ANY, mock.ANY)
    for recorded_call in exit_mock.call_args_list:
        assert recorded_call == expected_call
    assert exit_mock.call_count == 2
325
326
def test_CallbackManager_begin_called():
    """Check that _CallbackManager.begin forwards the context to every callback's begin()."""
    context = {}
    with mock.patch.object(Callback, 'begin', return_value=None) as begin_mock:
        first, second = Callback(), Callback()
        with _CallbackManager([first, second]) as manager:
            manager.begin(context)

    expected_call = mock.call(context)
    for recorded_call in begin_mock.call_args_list:
        assert recorded_call == expected_call
    assert begin_mock.call_count == 2
336
337
def test_RunContext():
    """Check RunContext argument validation and its stop-request flag."""
    # A plain integer is not an accepted argument object.
    invalid_args = 666
    with pytest.raises(TypeError):
        RunContext(invalid_args)

    params = _InternalCallbackParam()
    params.member1 = 1
    params.member2 = "abc"

    context = RunContext(params)
    context.original_args()
    # The values set on the params object must still be there after wrapping it.
    assert params.member1 == 1
    assert params.member2 == "abc"

    context.request_stop()
    stop_requested = context.get_stop_requested()
    assert stop_requested
356
357
def test_Checkpoint_Config():
    """Check that CheckpointConfig raises ValueError when its first four arguments are all 0 or None."""
    rejected_argument_sets = (
        (0, 0, 0, 0, True),
        (0, None, 0, 0, True),
    )
    for arguments in rejected_argument_sets:
        with pytest.raises(ValueError):
            CheckpointConfig(*arguments)
365
366
def test_step_end_save_graph():
    """Check that the graph meta file is written by the first step_end() only.

    After the first step_end() './test_files/test-graph.meta' must exist. It is
    then deleted, and a second step_end() must not create it again.
    """
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    params = _InternalCallbackParam()

    # Run the network once on random data before handing it to the callback.
    loss_net = LossNet()
    images = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    labels = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    loss_net(images, labels)

    params.train_network = loss_net
    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 0
    params.batch_num = 32

    checkpointer = ModelCheckpoint(prefix="test", directory='./test_files', config=ckpt_config)
    context = RunContext(params)
    checkpointer.begin(context)
    checkpointer.step_end(context)

    graph_meta = './test_files/test-graph.meta'
    assert os.path.exists(graph_meta)
    if os.path.exists(graph_meta):
        # Make the file writable first so the removal can succeed.
        os.chmod(graph_meta, stat.S_IWRITE)
        os.remove(graph_meta)

    checkpointer.step_end(context)
    assert not os.path.exists(graph_meta)
394