• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test dataset helper."""
16
17import pytest
18import numpy as np
19import mindspore.context as context
20from mindspore.communication.management import init
21from mindspore.train.dataset_helper import DatasetHelper
22from mindspore.communication._comm_helper import GlobalComm
23from ....dataset_mock import MindData
24
def get_dataset(batch_size=1):
    """Return a small ``MindData`` mock used by the DatasetHelper tests.

    The mock has ``size=2`` and seven int32 columns: three of shape
    ``(batch_size, 128)``, one of shape ``(batch_size, 1)`` and three of shape
    ``(batch_size, 20)``. It is built with ``input_indexs=(0, 1)``.

    Args:
        batch_size: Leading dimension of every column shape; also forwarded
            to ``MindData`` as its ``batch_size``.
    """
    column_types = (np.int32,) * 7
    column_shapes = (
        ((batch_size, 128),) * 3
        + ((batch_size, 1),)
        + ((batch_size, 20),) * 3
    )
    return MindData(size=2, batch_size=batch_size, np_types=column_types,
                    output_shapes=column_shapes, input_indexs=(0, 1))
33
34
def test_dataset_helper_dataset_sink_mode_str():
    """A string ``dataset_sink_mode`` (even "True") must raise TypeError."""
    mock_data = get_dataset(32)
    bad_sink_mode = "True"
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=bad_sink_mode)
39
40
def test_dataset_helper_dataset_sink_mode_int():
    """An integer ``dataset_sink_mode`` (even 1) must raise TypeError."""
    mock_data = get_dataset(32)
    bad_sink_mode = 1
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=bad_sink_mode)
45
46
def test_dataset_helper_sink_size_bool():
    """``sink_size=True`` must raise TypeError in sink mode.

    ``bool`` is a subclass of ``int`` in Python, so this checks that a boolean
    is not silently accepted as a size.
    """
    mock_data = get_dataset(32)
    bad_sink_size = True
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=bad_sink_size)
51
52
def test_dataset_helper_sink_size_float():
    """A float ``sink_size`` (even an integral one like 1.0) must raise TypeError."""
    mock_data = get_dataset(32)
    bad_sink_size = 1.0
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=bad_sink_size)
57
58
def test_dataset_helper_sink_size_negative():
    """``sink_size=-2`` has the right type but an invalid value: ValueError."""
    mock_data = get_dataset(32)
    bad_sink_size = -2
    with pytest.raises(ValueError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=bad_sink_size)
63
64
def test_dataset_iter_normal():
    """Iterate a non-sink DatasetHelper for two passes and check the step total.

    The underlying mock dataset is reset after every pass so the helper can
    be iterated again; the combined number of yielded items must be 6.
    """
    mock_data = get_dataset(32)
    helper = DatasetHelper(mock_data, dataset_sink_mode=False)
    total_steps = 0
    for _epoch in range(2):
        total_steps += sum(1 for _ in helper)
        mock_data.reset()
    assert total_steps == 6
74
75
@pytest.mark.skipif('not context.get_context("enable_ge")',
                    reason='requires context "enable_ge" to be set')
def test_dataset_iter_ge():
    """Iterate a sink-mode DatasetHelper (sink_size=10) twice when GE is enabled.

    Initializes the "hccl" communication backend first, then expects exactly
    one item per pass over the helper (two in total).
    """
    # GlobalComm.CHECK_ENVS is disabled only around init("hccl") (presumably
    # so init can run without a real HCCL environment -- confirm). Restore it
    # in `finally` so a failing init cannot leak the disabled flag into
    # later tests.
    GlobalComm.CHECK_ENVS = False
    try:
        init("hccl")
    finally:
        GlobalComm.CHECK_ENVS = True

    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for _ in dataset_helper:
            count += 1
    assert count == 2
88
89
@pytest.mark.skipif('context.get_context("enable_ge")',
                    reason='not applicable when context "enable_ge" is set')
def test_dataset_iter_ms_loop_sink():
    """Iterate a sink-mode DatasetHelper with ``enable_loop_sink=True``.

    Each pass is expected to yield exactly one item, and that item must be
    an empty tuple; two passes therefore yield two items.
    """
    # GlobalComm.CHECK_ENVS is disabled only around init("hccl") (presumably
    # so init can run without a real HCCL environment -- confirm). Restore it
    # in `finally` so a failing init cannot leak the disabled flag into
    # later tests.
    GlobalComm.CHECK_ENVS = False
    try:
        init("hccl")
    finally:
        GlobalComm.CHECK_ENVS = True

    # NOTE(review): this global context change is not restored after the
    # test; test_dataset_iter_ms sets it back to False explicitly.
    context.set_context(enable_loop_sink=True)
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for inputs in dataset_helper:
            count += 1
            assert inputs == ()
    assert count == 2
104
105
@pytest.mark.skipif('context.get_context("enable_ge")',
                    reason='not applicable when context "enable_ge" is set')
def test_dataset_iter_ms():
    """Smoke test: build a sink-mode DatasetHelper with loop sink disabled.

    Only checks that construction succeeds; the helper is not iterated.
    """
    # GlobalComm.CHECK_ENVS is disabled only around init("hccl") (presumably
    # so init can run without a real HCCL environment -- confirm). Restore it
    # in `finally` so a failing init cannot leak the disabled flag into
    # later tests.
    GlobalComm.CHECK_ENVS = False
    try:
        init("hccl")
    finally:
        GlobalComm.CHECK_ENVS = True

    context.set_context(enable_loop_sink=False)
    dataset = get_dataset(32)
    DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
114