# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test callback function."""
import os
import platform
import stat
import secrets
from unittest import mock

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
    _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist


class Net(nn.Cell):
    """Net definition: Conv2d -> BatchNorm2d -> ReLU -> Flatten -> Dense."""

    def __init__(self):
        super(Net, self).__init__()
        # NOTE(review): unlike LossNet, this conv does not pass pad_mode='valid',
        # yet the Dense layer below is sized for a 222x222 feature map. The tests
        # in this file only build this network and read its parameters; confirm
        # the sizes before calling it on real data.
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(64 * 222 * 222, 3)

    @ms_function
    def construct(self, x):
        """Apply conv, batch norm, ReLU, flatten and the dense layer to ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.flatten(x)
        out = self.fc(x)
        return out


class LossNet(nn.Cell):
    """ LossNet definition: the same layer stack as Net, followed by a softmax cross-entropy loss. """

    def __init__(self):
        super(LossNet, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(64 * 222 * 222, 3)  # padding=0
        self.loss = nn.SoftmaxCrossEntropyWithLogits()

    @ms_function
    def construct(self, x, y):
        """Run the layer stack on ``x`` and return the loss against labels ``y``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        out = self.loss(x, y)
        return out


def test_model_checkpoint_prefix_invalid():
    """Test ModelCheckpoint prefix invalid."""
    # A non-string prefix is expected to be rejected with ValueError.
    with pytest.raises(ValueError):
        ModelCheckpoint(123)
    ModelCheckpoint(directory="./")
    # A config that is not a CheckpointConfig is expected to raise TypeError.
    with pytest.raises(TypeError):
        ModelCheckpoint(config='type_error')
    ModelCheckpoint(config=CheckpointConfig())
    ModelCheckpoint(prefix="ckpt_2", directory="./test_files")


def test_save_checkpoint():
    """Test save checkpoint."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32
    ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config)
    run_context = RunContext(cb_params)
    ckpoint_cb.begin(run_context)
    ckpoint_cb.step_end(run_context)
    # Clean up the model file if the callback produced one; make it writable
    # first so that the removal cannot fail on a read-only file.
    model_file = './test_files/test_ckpt-model.pkl'
    if os.path.exists(model_file):
        os.chmod(model_file, stat.S_IWRITE)
        os.remove(model_file)


def test_loss_monitor_sink_mode():
    """Test loss monitor sink mode."""
    cb_params = _InternalCallbackParam()
    cb_params.cur_epoch_num = 4
    cb_params.epoch_num = 4
    cb_params.cur_step_num = 2
    cb_params.batch_num = 2
    cb_params.net_outputs = Tensor(2.0)
    run_context = RunContext(cb_params)
    loss_cb = LossMonitor(1)
    callbacks = [loss_cb]
    # Drive the full callback life cycle through the manager.
    with _CallbackManager(callbacks) as callbacklist:
        callbacklist.begin(run_context)
        callbacklist.epoch_begin(run_context)
        callbacklist.step_begin(run_context)
        callbacklist.step_end(run_context)
        callbacklist.epoch_end(run_context)
        callbacklist.end(run_context)


def test_loss_monitor_normal_mode():
    """Test loss monitor normal(non-sink) mode."""
    cb_params = _InternalCallbackParam()
    run_context = RunContext(cb_params)
    loss_cb = LossMonitor(1)
    cb_params.cur_epoch_num = 4
    cb_params.epoch_num = 4
    cb_params.cur_step_num = 1
    cb_params.batch_num = 1
    cb_params.net_outputs = Tensor(2.0)
    # Drive the full callback life cycle directly on the callback object.
    loss_cb.begin(run_context)
    loss_cb.epoch_begin(run_context)
    loss_cb.step_begin(run_context)
    loss_cb.step_end(run_context)
    loss_cb.epoch_end(run_context)
    loss_cb.end(run_context)


def test_chg_ckpt_file_name_if_same_exist():
    """Test chg ckpt file name if same exist."""
    _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")


def test_checkpoint_cb_for_save_op():
    """Test checkpoint cb for save op."""
    parameter_list = []
    one_param = {}
    one_param['name'] = "conv1.weight"
    one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
    parameter_list.append(one_param)
    _checkpoint_cb_for_save_op(parameter_list)


def test_checkpoint_cb_for_save_op_update_net():
    """Test checkpoint cb for save op updates the weights of the current net."""
    parameter_list = []
    one_param = {}
    one_param['name'] = "conv.weight"
    one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
    parameter_list.append(one_param)
    net = Net()
    # Register the net so that the save-op callback has a network to write into.
    _set_cur_net(net)
    _checkpoint_cb_for_save_op(parameter_list)
    assert net.conv.weight.data.asnumpy()[0][0][0][0] == 1


def test_internal_callback_param():
    """Test Internal CallbackParam."""
    cb_params = _InternalCallbackParam()
    cb_params.member1 = 1
    cb_params.member2 = "abc"
    assert cb_params.member1 == 1
    assert cb_params.member2 == "abc"


def test_checkpoint_save_ckpt_steps():
    """Test checkpoint save ckpt steps."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    ckpt_cb = ModelCheckpoint(config=train_config)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 160
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    ckpt_cb.step_end(run_context)
    # Second callback on the same run context, at a step number (15) that is
    # not a multiple of save_checkpoint_steps (16).
    ckpt_cb2 = ModelCheckpoint(config=train_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 15
    ckpt_cb2.begin(run_context)
    ckpt_cb2.step_end(run_context)


def test_checkpoint_save_ckpt_seconds():
    """Test checkpoint save ckpt seconds."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=100,
        keep_checkpoint_max=0,
        keep_checkpoint_per_n_minutes=1)
    ckpt_cb = ModelCheckpoint(config=train_config)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 4
    cb_params.cur_step_num = 128
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    ckpt_cb.step_end(run_context)
    # Second callback on the same run context with a different epoch/step.
    ckpt_cb2 = ModelCheckpoint(config=train_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 16
    ckpt_cb2.begin(run_context)
    ckpt_cb2.step_end(run_context)


def test_checkpoint_save_ckpt_with_encryption():
    """Test checkpoint save ckpt with encryption."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0,
        enc_key=secrets.token_bytes(16),
        enc_mode="AES-GCM")
    ckpt_cb = ModelCheckpoint(config=train_config)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 160
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    ckpt_cb.step_end(run_context)
    ckpt_cb2 = ModelCheckpoint(config=train_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 15

    # On Windows the second callback is expected to raise NotImplementedError;
    # on every other platform the same calls must complete without raising.
    if platform.system().lower() == "windows":
        with pytest.raises(NotImplementedError):
            ckpt_cb2.begin(run_context)
            ckpt_cb2.step_end(run_context)
    else:
        ckpt_cb2.begin(run_context)
        ckpt_cb2.step_end(run_context)


def test_CallbackManager():
    """TestCallbackManager."""
    ck_obj = ModelCheckpoint()
    loss_cb_1 = LossMonitor(1)

    # Every list below contains at least one entry that is not a Callback.
    callbacks = [None]
    with pytest.raises(TypeError):
        _CallbackManager(callbacks)

    callbacks = ['Error']
    with pytest.raises(TypeError):
        _CallbackManager(callbacks)

    callbacks = [ck_obj, loss_cb_1, 'Error', None]
    with pytest.raises(TypeError):
        _CallbackManager(callbacks)


def test_CallbackManager_exit_called():
    """Test that leaving _CallbackManager normally calls __exit__ on each callback."""
    with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
        cb1, cb2 = Callback(), Callback()
        with _CallbackManager([cb1, cb2]):
            pass
    # No exception was raised, so the three exception arguments must be None.
    for call_args in mock_exit.call_args_list:
        assert call_args == mock.call(mock.ANY, None, None, None)
    assert mock_exit.call_count == 2


def test_CallbackManager_exit_called_when_raises():
    """Test that __exit__ is still called on each callback when the body raises."""
    with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
        cb1, cb2 = Callback(), Callback()
        with pytest.raises(ValueError):
            with _CallbackManager([cb1, cb2]):
                raise ValueError()
    # Only the arity is checked here: four positional arguments of any value.
    for call_args in mock_exit.call_args_list:
        assert call_args == mock.call(*[mock.ANY] * 4)
    assert mock_exit.call_count == 2


def test_CallbackManager_begin_called():
    """Test that _CallbackManager.begin forwards the context to each callback."""
    context = {}
    with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
        cb1, cb2 = Callback(), Callback()
        with _CallbackManager([cb1, cb2]) as cm:
            cm.begin(context)
    for call_args in mock_begin.call_args_list:
        assert call_args == mock.call(context)
    assert mock_begin.call_count == 2


def test_RunContext():
    """Test RunContext."""
    # A context that is not a callback-parameter object must be rejected.
    context_err = 666
    with pytest.raises(TypeError):
        RunContext(context_err)

    cb_params = _InternalCallbackParam()
    cb_params.member1 = 1
    cb_params.member2 = "abc"

    run_context = RunContext(cb_params)
    run_context.original_args()
    assert cb_params.member1 == 1
    assert cb_params.member2 == "abc"

    run_context.request_stop()
    should_stop = run_context.get_stop_requested()
    assert should_stop


def test_Checkpoint_Config():
    """Test CheckpointConfig all None or 0."""
    with pytest.raises(ValueError):
        CheckpointConfig(0, 0, 0, 0, True)

    with pytest.raises(ValueError):
        CheckpointConfig(0, None, 0, 0, True)


def test_step_end_save_graph():
    """Test that step_end writes the graph meta file once and does not recreate it."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()
    net = LossNet()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    input_label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    # Run the network once before handing it to the callback.
    net(input_data, input_label)
    cb_params.train_network = net
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32
    ckpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files', config=train_config)
    run_context = RunContext(cb_params)
    ckpoint_cb.begin(run_context)
    ckpoint_cb.step_end(run_context)
    graph_meta = './test_files/test-graph.meta'
    assert os.path.exists(graph_meta)
    if os.path.exists(graph_meta):
        # Make the file writable first so that the removal cannot fail.
        os.chmod(graph_meta, stat.S_IWRITE)
        os.remove(graph_meta)
    # A second step_end on the same callback must not write the file again.
    ckpoint_cb.step_end(run_context)
    assert not os.path.exists(graph_meta)