# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test debug hook of debug mode
"""

import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.debug as dbg

# Need to run all these tests in separate processes since
# the global configuration setting of debug_mode may impact other tests running in parallel.
pytestmark = pytest.mark.forked


@pytest.mark.parametrize("debug_mode_flag, debug_hook_list",
                         [(True, [dbg.PrintMetaDataHook()]),
                          (True, [dbg.PrintDataHook()]),
                          (True, [])])
def test_debug_mode_hook(debug_mode_flag, debug_hook_list):
    """
    Feature: Test the debug mode setter function
    Description: Test valid debug hook case for debug mode
    Expectation: Success
    """
    # Snapshot the global dataset configs so they can be restored when the test ends.
    origin_debug_mode = ds.config.get_debug_mode()
    origin_seed = ds.config.get_seed()

    try:
        # set debug flag and hook
        ds.config.set_debug_mode(debug_mode_flag=debug_mode_flag, debug_hook_list=debug_hook_list)
        dataset = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=5)
        dataset = dataset.map(operations=[vision.Decode(False), vision.CenterCrop((225, 225))])
        # Drain the pipeline so every hook runs on every row, and check that the
        # requested number of samples actually flowed through it.
        num_rows = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1))
        assert num_rows == 5
    finally:
        # Restore configs even if the pipeline raised. Otherwise debug mode (and its
        # hooks) would leak into whatever runs next in this process. The seed is
        # restored as well, presumably because enabling debug mode may alter it.
        ds.config.set_debug_mode(origin_debug_mode)
        ds.config.set_seed(origin_seed)


if __name__ == '__main__':
    test_debug_mode_hook(True, [dbg.PrintMetaDataHook()])
    test_debug_mode_hook(True, [dbg.PrintDataHook()])
    test_debug_mode_hook(True, [])