# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset helper."""

import pytest
import numpy as np
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.communication._comm_helper import GlobalComm
from ....dataset_mock import MindData


def get_dataset(batch_size=1):
    """Build a ``MindData`` mock dataset for ``DatasetHelper`` tests.

    The mock has seven int32 columns, each with a leading ``batch_size`` dimension.

    Args:
        batch_size: Leading dimension of every column shape; also forwarded
            to ``MindData`` as its ``batch_size``.

    Returns:
        A ``MindData`` instance created with ``size=2`` and ``input_indexs=(0, 1)``.
    """
    dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
    dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
                      (batch_size, 20), (batch_size, 20), (batch_size, 20))

    dataset = MindData(size=2, batch_size=batch_size, np_types=dataset_types,
                       output_shapes=dataset_shapes, input_indexs=(0, 1))
    return dataset


def _init_hccl_without_env_check():
    """Call ``init("hccl")`` with ``GlobalComm.CHECK_ENVS`` temporarily disabled.

    The flag is set back to ``True`` in a ``finally`` clause, so a failing
    ``init`` cannot leave the environment check switched off for every test
    that runs afterwards in the same session.
    """
    GlobalComm.CHECK_ENVS = False
    try:
        init("hccl")
    finally:
        GlobalComm.CHECK_ENVS = True


def test_dataset_helper_dataset_sink_mode_str():
    """A string ``dataset_sink_mode`` is rejected with ``TypeError``."""
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode="True")


def test_dataset_helper_dataset_sink_mode_int():
    """An int ``dataset_sink_mode`` is rejected with ``TypeError``."""
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=1)


def test_dataset_helper_sink_size_bool():
    """A bool ``sink_size`` is rejected with ``TypeError``."""
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=True)


def test_dataset_helper_sink_size_float():
    """A float ``sink_size`` is rejected with ``TypeError``."""
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=1.0)


def test_dataset_helper_sink_size_negative():
    """A ``sink_size`` of -2 is rejected with ``ValueError``."""
    dataset = get_dataset(32)
    with pytest.raises(ValueError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-2)


def test_dataset_iter_normal():
    """Iterate a non-sink ``DatasetHelper`` for two epochs, resetting the dataset between them."""
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False)
    count = 0
    for _ in range(2):
        for _ in dataset_helper:
            count += 1
        dataset.reset()
    # NOTE(review): the expected total of 6 depends on how the MindData mock
    # (size=2) iterates; confirm against dataset_mock.
    assert count == 6


@pytest.mark.skipif('not context.get_context("enable_ge")')
def test_dataset_iter_ge():
    """Iterate a sink-mode ``DatasetHelper`` (sink_size=10) twice when GE is enabled."""
    _init_hccl_without_env_check()
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for _ in dataset_helper:
            count += 1
    # NOTE(review): expects one step per pass in sink mode; confirm against
    # DatasetHelper's sink-mode iteration semantics.
    assert count == 2


@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms_loop_sink():
    """With loop sink enabled, each sink-mode step yields an empty input tuple."""
    _init_hccl_without_env_check()
    # NOTE(review): this mutates the global context and does not restore it;
    # test_dataset_iter_ms sets the value explicitly to avoid order dependence.
    context.set_context(enable_loop_sink=True)
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for inputs in dataset_helper:
            count += 1
            assert inputs == tuple()
    assert count == 2


@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms():
    """A sink-mode ``DatasetHelper`` can be constructed with loop sink disabled."""
    _init_hccl_without_env_check()
    context.set_context(enable_loop_sink=False)
    dataset = get_dataset(32)
    DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)