• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_run_config """
16import pytest
17
18from mindspore.train.callback import CheckpointConfig
19
20
def test_init():
    """Check that explicit constructor arguments are stored and reported.

    Both arguments are passed by keyword and use non-default values, so the
    assertions fail if an argument is bound to the wrong parameter or is
    silently ignored in favour of the default.
    """
    # Deliberately differ from the defaults (save_checkpoint_steps == 1,
    # keep_checkpoint_max == 5); matching values would let a dropped or
    # misbound argument pass unnoticed.
    save_checkpoint_steps = 2
    keep_checkpoint_max = 3

    # Pass by keyword: positionally, the second value binds to whatever
    # CheckpointConfig's second parameter is.
    # NOTE(review): in MindSpore's documented signature that parameter is
    # save_checkpoint_seconds, not keep_checkpoint_max — confirm against the
    # installed version.
    config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                              keep_checkpoint_max=keep_checkpoint_max)

    assert config.save_checkpoint_steps == save_checkpoint_steps
    assert config.keep_checkpoint_max == keep_checkpoint_max
    policy = config.get_checkpoint_policy()
    assert policy['keep_checkpoint_max'] == keep_checkpoint_max
33
34
def test_arguments_values():
    """Check CheckpointConfig defaults and its rejection of bad argument values.

    A no-argument ``CheckpointConfig`` is inspected first. Then each keyword
    argument is given a string (expected to raise ``TypeError``), and after
    that a negative integer (expected to raise ``ValueError``).
    """
    default_config = CheckpointConfig()
    assert default_config.save_checkpoint_steps == 1
    assert default_config.save_checkpoint_seconds is None
    assert default_config.keep_checkpoint_max == 5
    assert default_config.keep_checkpoint_per_n_minutes is None

    arg_names = ('save_checkpoint_steps',
                 'save_checkpoint_seconds',
                 'keep_checkpoint_max',
                 'keep_checkpoint_per_n_minutes')

    # A non-integer value for any argument must be rejected with TypeError.
    for arg_name in arg_names:
        with pytest.raises(TypeError):
            CheckpointConfig(**{arg_name: 'abc'})

    # A negative value for any argument must be rejected with ValueError.
    for arg_name in arg_names:
        with pytest.raises(ValueError):
            CheckpointConfig(**{arg_name: -1})
60