• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Test debug hook of debug mode
17"""
18
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.vision as vision
22import mindspore.dataset.debug as dbg
23
# Every test in this module changes the global dataset debug_mode configuration.
# Mark them all as `forked` so that each one runs in a separate process and the
# changed setting cannot affect other tests running in parallel.
pytestmark = pytest.mark.forked
27
28
@pytest.mark.parametrize("debug_mode_flag, debug_hook_list",
                         [(True, [dbg.PrintMetaDataHook()]),
                          (True, [dbg.PrintDataHook()]),
                          (True, [])])
def test_debug_mode_hook(debug_mode_flag, debug_hook_list):
    """
    Feature: Test the debug mode setter function
    Description: Test valid debug hook case for debug mode. The given flag and hook list
        are passed to ds.config.set_debug_mode, then an ImageFolderDataset pipeline with
        Decode and CenterCrop map operations is iterated for one epoch.
    Expectation: Success
    """
    # Get the original global configs so they can be restored after the run.
    origin_debug_mode = ds.config.get_debug_mode()
    origin_seed = ds.config.get_seed()

    try:
        # Set debug flag and hook list under test.
        ds.config.set_debug_mode(debug_mode_flag=debug_mode_flag, debug_hook_list=debug_hook_list)
        dataset = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=5)
        dataset = dataset.map(operations=[vision.Decode(False), vision.CenterCrop((225, 225))])
        # Only drive the pipeline; the row contents are not inspected here.
        for _ in dataset.create_dict_iterator(num_epochs=1):
            pass
    finally:
        # Restore configs even when the pipeline raises, so a failing case
        # does not leave the global configuration modified.
        ds.config.set_debug_mode(origin_debug_mode)
        ds.config.set_seed(origin_seed)
52
53
if __name__ == '__main__':
    # Script entry: run the same three cases as the parametrized test, in the same order.
    # Each hook is instantiated right before the call that uses it.
    for hook_classes in ((dbg.PrintMetaDataHook,), (dbg.PrintDataHook,), ()):
        test_debug_mode_hook(True, [hook_class() for hook_class in hook_classes])
58