# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_run_config: tests for mindspore.train.callback.CheckpointConfig."""
import pytest

from mindspore.train.callback import CheckpointConfig


def test_init():
    """Check that explicitly passed CheckpointConfig arguments are stored.

    The arguments are passed by keyword so each value binds to the intended
    parameter regardless of the constructor's positional order (the second
    positional parameter is not necessarily ``keep_checkpoint_max``). The
    values deliberately differ from the defaults asserted in
    ``test_arguments_values`` so the assertions fail if an argument is
    misbound or silently ignored.
    """
    save_checkpoint_steps = 2
    keep_checkpoint_max = 3

    config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                              keep_checkpoint_max=keep_checkpoint_max)

    assert config.save_checkpoint_steps == save_checkpoint_steps
    assert config.keep_checkpoint_max == keep_checkpoint_max
    # The policy dict must reflect the configured retention limit as well.
    policy = config.get_checkpoint_policy()
    assert policy['keep_checkpoint_max'] == keep_checkpoint_max


def test_arguments_values():
    """Check CheckpointConfig defaults and constructor argument validation.

    A default-constructed config is compared against the expected defaults.
    Then each parameter is fed a non-integer value (expected to raise
    TypeError) and a negative value (expected to raise ValueError).
    """
    config = CheckpointConfig()
    assert config.save_checkpoint_steps == 1
    assert config.save_checkpoint_seconds is None
    assert config.keep_checkpoint_max == 5
    assert config.keep_checkpoint_per_n_minutes is None

    # Non-integer inputs are rejected with TypeError.
    with pytest.raises(TypeError):
        CheckpointConfig(save_checkpoint_steps='abc')
    with pytest.raises(TypeError):
        CheckpointConfig(save_checkpoint_seconds='abc')
    with pytest.raises(TypeError):
        CheckpointConfig(keep_checkpoint_max='abc')
    with pytest.raises(TypeError):
        CheckpointConfig(keep_checkpoint_per_n_minutes='abc')

    # Negative integers are rejected with ValueError.
    with pytest.raises(ValueError):
        CheckpointConfig(save_checkpoint_steps=-1)
    with pytest.raises(ValueError):
        CheckpointConfig(save_checkpoint_seconds=-1)
    with pytest.raises(ValueError):
        CheckpointConfig(keep_checkpoint_max=-1)
    with pytest.raises(ValueError):
        CheckpointConfig(keep_checkpoint_per_n_minutes=-1)