• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import sys
16import pytest
17import numpy as np
18import pandas as pd
19import mindspore.dataset as de
20from mindspore import log as logger
21import mindspore.dataset.vision.c_transforms as vision
22
23
def test_numpy_slices_list_1():
    """Slice a flat list with shuffling off and compare each row to its source element."""
    logger.info("Test Slicing a 1D list.")

    source = [1, 2, 3]
    dataset = de.NumpySlicesDataset(source, shuffle=False)

    for index, row in enumerate(dataset):
        first_column = row[0].asnumpy()
        assert first_column == source[index]
32
33
def test_numpy_slices_list_2():
    """Slice a 2D list under a single named column; row i must equal the i-th inner list."""
    logger.info("Test Slicing a 2D list into 1D list.")

    rows = [[1, 2], [3, 4]]
    dataset = de.NumpySlicesDataset(rows, column_names=["col1"], shuffle=False)

    for index, item in enumerate(dataset):
        actual = item[0].asnumpy()
        assert np.equal(actual, rows[index]).all()
42
43
def test_numpy_slices_list_3():
    """Slice a 3D nested list along its first dimension; row i must equal the i-th 2D block."""
    logger.info("Test Slicing list in the first dimension.")

    blocks = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    dataset = de.NumpySlicesDataset(blocks, column_names=["col1"], shuffle=False)

    for index, item in enumerate(dataset):
        actual = item[0].asnumpy()
        assert np.equal(actual, blocks[index]).all()
52
53
def test_numpy_slices_numpy():
    """Slice a 3D NumPy array; row i of the dataset must equal array[i]."""
    logger.info("Test NumPy structure data.")

    array = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
    dataset = de.NumpySlicesDataset(array, column_names=["col1"], shuffle=False)

    for index, item in enumerate(dataset):
        actual = item[0].asnumpy()
        assert np.equal(actual, array[index]).all()
62
63
def test_numpy_slices_list_append():
    """Collect decoded, resized images from a TFRecord file and slice the resulting list.

    The images are read through a dict iterator with NumPy output, gathered
    into a list, and that list is fed to NumpySlicesDataset. Each row read
    back must equal the image collected at the same position.
    """
    logger.info("Test reading data of image list.")

    tfrecord_files = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    target_size = (2, 2)

    image_source = de.TFRecordDataset(tfrecord_files)
    resize = vision.Resize(target_size)
    image_source = image_source.map(operations=[vision.Decode(True), resize], input_columns=["image"])

    images = [
        record["image"]
        for record in image_source.create_dict_iterator(num_epochs=1, output_numpy=True)
    ]

    sliced = de.NumpySlicesDataset(images, column_names=["col1"], shuffle=False)

    for index, row in enumerate(sliced.create_tuple_iterator(output_numpy=True)):
        assert np.equal(row, images[index]).all()
82
83
def test_numpy_slices_dict_1():
    """Slice a dict of two lists; row i is expected to hold the i-th value of "a" then of "b"."""
    logger.info("Test Dictionary structure data.")

    columns = {"a": [1, 2], "b": [3, 4]}
    dataset = de.NumpySlicesDataset(columns, shuffle=False)
    expected_rows = [[1, 3], [2, 4]]

    for index, row in enumerate(dataset):
        want_first, want_second = expected_rows[index]
        assert row[0].asnumpy() == want_first
        assert row[1].asnumpy() == want_second
94
95
def test_numpy_slices_tuple_1():
    """Slice a list of three tuples; each row must equal its source tuple and three rows are expected."""
    logger.info("Test slicing a list of tuple.")

    pairs = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
    dataset = de.NumpySlicesDataset(pairs, shuffle=False)

    for index, row in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        assert np.equal(row, pairs[index]).all()

    row_total = sum(1 for _ in dataset)
    assert row_total == 3
106
107
def test_numpy_slices_tuple_2():
    """Slice a tuple of three two-element lists; expects the rows [1, 3, 5] and [2, 4, 6]."""
    logger.info("Test slicing a tuple of list.")

    columns = ([1, 2], [3, 4], [5, 6])
    wanted_rows = [[1, 3, 5], [2, 4, 6]]
    dataset = de.NumpySlicesDataset(columns, shuffle=False)

    for index, row in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        assert np.equal(row, wanted_rows[index]).all()

    row_total = sum(1 for _ in dataset)
    assert row_total == 2
119
120
def test_numpy_slices_tuple_3():
    """Slice a (features, labels) tuple of random arrays shaped (5, 2) and (5, 1).

    Row i must carry features[i] in the first column and labels[i] in the second.
    """
    logger.info("Test reading different dimension of tuple data.")
    features = np.random.sample((5, 2))
    labels = np.random.sample((5, 1))
    paired = (features, labels)

    dataset = de.NumpySlicesDataset(paired, column_names=["col1", "col2"], shuffle=False)

    for index, row in enumerate(dataset):
        assert np.equal(row[0].asnumpy(), features[index]).all()
        assert row[1].asnumpy() == labels[index]
131
132
def test_numpy_slices_csv_value():
    """Load heart.csv with pandas and slice its (feature values, target values) tuple.

    The "target" column is split off as the second element and the "state"
    column is dropped before the remaining values become the first element.
    """
    logger.info("Test loading value of csv file.")
    csv_path = "../data/dataset/testNumpySlicesDataset/heart.csv"

    frame = pd.read_csv(csv_path)
    target_column = frame.pop("target")
    frame.pop("state")
    arrays = (frame.values, target_column.values)

    dataset = de.NumpySlicesDataset(arrays, column_names=["col1", "col2"], shuffle=False)

    for index, row in enumerate(dataset):
        feature_row, target_value = arrays[0][index], arrays[1][index]
        assert np.equal(feature_row, row[0].asnumpy()).all()
        assert np.equal(target_value, row[1].asnumpy()).all()
147
148
def test_numpy_slices_csv_dict():
    """Load heart.csv with pandas, drop "state", and slice the frame converted to a dict.

    Each row from the tuple iterator must equal the matching row of the frame's values.
    """
    logger.info("Test loading csv file as dict.")

    csv_path = "../data/dataset/testNumpySlicesDataset/heart.csv"
    frame = pd.read_csv(csv_path)
    frame.pop("state")
    expected_rows = frame.values

    dataset = de.NumpySlicesDataset(dict(frame), shuffle=False)

    for index, row in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        assert np.equal(row, expected_rows[index]).all()
161
162
def test_numpy_slices_num_samplers():
    """Check that num_samples=2 limits an eight-row dataset to its first two rows.

    The function name keeps its historical spelling so existing callers and
    test selectors still resolve; the argument being exercised is
    ``num_samples``, and the log message now says so.
    """
    logger.info("Test num_samples.")

    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False, num_samples=2)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i]).all()

    # Count rows with a generator so no throwaway list is built.
    assert sum(1 for _ in ds) == 2
173
174
def test_numpy_slices_distributed_sampler():
    """Read shard 0 of 4 from eight rows; row i is compared with source row i * 4 and two rows are expected."""
    logger.info("Test distributed sampler.")

    source = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    shard = de.NumpySlicesDataset(source, shuffle=False, shard_id=0, num_shards=4)

    for index, row in enumerate(shard):
        wanted = source[index * 4]
        assert np.equal(row[0].asnumpy(), wanted).all()

    row_total = sum(1 for _ in shard)
    assert row_total == 2
185
186
def test_numpy_slices_distributed_shard_limit():
    """Check that a num_shards value of sys.maxsize is rejected with ValueError.

    The expected error text states the accepted interval as [1, 2147483647].
    """
    logger.info("Test num_shards above the supported upper limit.")

    np_data = [1, 2, 3]
    # sys.maxsize is larger than 2147483647 on 64-bit builds, which puts it
    # outside the interval quoted in the expected error message.
    num = sys.maxsize
    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False)
    assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
195
196
def test_numpy_slices_distributed_zero_shard():
    """Check that num_shards=0 is rejected with ValueError.

    The expected error text states the accepted interval as [1, 2147483647].
    """
    logger.info("Test num_shards of zero.")

    np_data = [1, 2, 3]
    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False)
    assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
204
205
def test_numpy_slices_sequential_sampler():
    """Read eight rows through a SequentialSampler repeated twice.

    Row i must equal source row i % 8. The total row count is also checked,
    because the per-row loop alone would pass if the repeated dataset yielded
    no rows or fewer rows than expected.
    """
    logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")

    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    repeat_count = 2
    ds = de.NumpySlicesDataset(np_data, sampler=de.SequentialSampler()).repeat(repeat_count)

    row_count = 0
    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i % len(np_data)]).all()
        row_count += 1

    # Every source row must appear once per repetition.
    assert row_count == repeat_count * len(np_data)
214
215
def test_numpy_slices_invalid_column_names_type():
    """A non-string entry in column_names must raise TypeError naming the offending element."""
    logger.info("Test incorrect column_names input")
    source = [1, 2, 3]

    with pytest.raises(TypeError) as exc_info:
        de.NumpySlicesDataset(source, column_names=[1], shuffle=False)

    message = str(exc_info.value)
    assert "Argument column_names[0] with value 1 is not of type [<class 'str'>]" in message
223
224
def test_numpy_slices_invalid_column_names_string():
    """An empty-string entry in column_names must raise ValueError."""
    logger.info("Test incorrect column_names input")
    source = [1, 2, 3]

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(source, column_names=[""], shuffle=False)

    message = str(exc_info.value)
    assert "column_names[0] should not be empty" in message
232
233
def test_numpy_slices_invalid_empty_column_names():
    """An empty column_names list must raise ValueError."""
    logger.info("Test incorrect column_names input")
    source = [1, 2, 3]

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(source, column_names=[], shuffle=False)

    message = str(exc_info.value)
    assert "column_names should not be empty" in message
241
242
def test_numpy_slices_invalid_empty_data_column():
    """Check that an empty data list is rejected with ValueError."""
    # This test passes no column_names at all; the invalid input is the
    # empty data list, so the log message describes that.
    logger.info("Test incorrect empty data input")
    np_data = []

    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, shuffle=False)
    assert "Argument data cannot be empty" in str(err.value)
250
251
def test_numpy_slice_empty_output_shape():
    """Batch a one-row dataset by three with drop_remainder=True; output_shapes() must be an empty list."""
    logger.info("running test_numpy_slice_empty_output_shape")
    single_row = de.NumpySlicesDataset([[[1, 2], [3, 4]]], column_names=["col1"])
    batched = single_row.batch(batch_size=3, drop_remainder=True)
    shapes = batched.output_shapes()
    assert shapes == []
257
258
# Run every test in this module directly, in definition order, when the file
# is executed as a script rather than collected by pytest.
if __name__ == "__main__":
    test_numpy_slices_list_1()
    test_numpy_slices_list_2()
    test_numpy_slices_list_3()
    # test_numpy_slices_numpy was previously left out of this list.
    test_numpy_slices_numpy()
    test_numpy_slices_list_append()
    test_numpy_slices_dict_1()
    test_numpy_slices_tuple_1()
    test_numpy_slices_tuple_2()
    test_numpy_slices_tuple_3()
    test_numpy_slices_csv_value()
    test_numpy_slices_csv_dict()
    test_numpy_slices_num_samplers()
    test_numpy_slices_distributed_sampler()
    test_numpy_slices_distributed_shard_limit()
    test_numpy_slices_distributed_zero_shard()
    test_numpy_slices_sequential_sampler()
    test_numpy_slices_invalid_column_names_type()
    test_numpy_slices_invalid_column_names_string()
    test_numpy_slices_invalid_empty_column_names()
    test_numpy_slices_invalid_empty_data_column()
    test_numpy_slice_empty_output_shape()
280