# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

# ImageFolder test data used by the ImageFolderDataset cases below.
DATA_DIR = "../data/dataset/testPK/data"
22
23
# Generate 1d int numpy array from 0 - 64
def generator_1d():
    """Yield 64 one-element tuples, each holding a 1-D array [i] for i in 0..63."""
    idx = 0
    while idx < 64:
        yield (np.array([idx]),)
        idx += 1
28
29
def test_apply_generator_case():
    """Dataset.apply(fn) must produce the same rows as chaining the ops directly."""
    applied = ds.GeneratorDataset(generator_1d, ["data"])
    chained = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # repeat twice, then batch in groups of 4
        return ds_.repeat(2).batch(4)

    applied = applied.apply(dataset_fn)
    chained = chained.repeat(2).batch(4)

    for row_a, row_b in zip(applied.create_dict_iterator(num_epochs=1, output_numpy=True),
                            chained.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(row_a["data"], row_b["data"])
46
47
def test_apply_imagefolder_case():
    """Dataset.apply(fn) with map/repeat ops must equal building the pipeline inline."""
    applied = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    direct = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds_):
        # decode -> normalize -> repeat(2), all inside one apply() call
        return ds_.map(operations=decode_op).map(operations=normalize_op).repeat(2)

    applied = applied.apply(dataset_fn)
    direct = direct.map(operations=decode_op).map(operations=normalize_op).repeat(2)

    for row_a, row_b in zip(applied.create_dict_iterator(num_epochs=1, output_numpy=True),
                            direct.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(row_a["image"], row_b["image"])
70
71
def test_apply_flow_case_0(id_=0):
    """apply() with control flow inside dataset_fn; default branch: batch(4)."""
    pipeline = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # Pick the pipeline shape based on the case id.
        if id_ == 0:
            return ds_.batch(4)
        if id_ == 1:
            return ds_.repeat(2)
        if id_ == 2:
            return ds_.batch(4).repeat(2)
        return ds_.shuffle(buffer_size=4)

    pipeline = pipeline.apply(dataset_fn)
    row_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))

    # 64 source rows: batch(4) -> 16, repeat(2) -> 128, batch+repeat -> 32, shuffle -> 64
    assert row_count == {0: 16, 1: 128, 2: 32}.get(id_, 64)
101
102
def test_apply_flow_case_1(id_=1):
    """apply() with control flow inside dataset_fn; default branch: repeat(2)."""
    pipeline = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # Pick the pipeline shape based on the case id.
        if id_ == 0:
            return ds_.batch(4)
        if id_ == 1:
            return ds_.repeat(2)
        if id_ == 2:
            return ds_.batch(4).repeat(2)
        return ds_.shuffle(buffer_size=4)

    pipeline = pipeline.apply(dataset_fn)
    row_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))

    # 64 source rows: batch(4) -> 16, repeat(2) -> 128, batch+repeat -> 32, shuffle -> 64
    assert row_count == {0: 16, 1: 128, 2: 32}.get(id_, 64)
132
133
def test_apply_flow_case_2(id_=2):
    """apply() with control flow inside dataset_fn; default branch: batch(4)+repeat(2)."""
    pipeline = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # Pick the pipeline shape based on the case id.
        if id_ == 0:
            return ds_.batch(4)
        if id_ == 1:
            return ds_.repeat(2)
        if id_ == 2:
            return ds_.batch(4).repeat(2)
        return ds_.shuffle(buffer_size=4)

    pipeline = pipeline.apply(dataset_fn)
    row_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))

    # 64 source rows: batch(4) -> 16, repeat(2) -> 128, batch+repeat -> 32, shuffle -> 64
    assert row_count == {0: 16, 1: 128, 2: 32}.get(id_, 64)
163
164
def test_apply_flow_case_3(id_=3):
    """apply() with control flow inside dataset_fn; default branch: shuffle(4)."""
    pipeline = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # Pick the pipeline shape based on the case id.
        if id_ == 0:
            return ds_.batch(4)
        if id_ == 1:
            return ds_.repeat(2)
        if id_ == 2:
            return ds_.batch(4).repeat(2)
        return ds_.shuffle(buffer_size=4)

    pipeline = pipeline.apply(dataset_fn)
    row_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))

    # 64 source rows: batch(4) -> 16, repeat(2) -> 128, batch+repeat -> 32, shuffle -> 64
    assert row_count == {0: 16, 1: 128, 2: 32}.get(id_, 64)
194
195
def test_apply_exception_case():
    """apply() error handling: a non-callable argument, a function that does
    not build a dataset, and re-applying the same function to one pipeline."""
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        # A valid pipeline function: repeat, then batch.
        repeated = ds_.repeat(2)
        return repeated.batch(4)

    def exception_fn():
        # Invalid for apply(): wrong arity and returns an ndarray, not a dataset.
        return np.array([[0], [1], [3], [4], [5]])

    # Passing a non-callable must raise TypeError.
    try:
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except TypeError:
        pass

    # A function that cannot accept the dataset must raise TypeError.
    try:
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except TypeError:
        pass

    # Applying the same function twice on one pipeline must raise ValueError.
    try:
        data2 = data1.apply(dataset_fn)
        _ = data1.apply(dataset_fn)
        for _, _ in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
            pass
        assert False
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
231
232
if __name__ == '__main__':
    # Run each case directly when invoked as a script (outside pytest).
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    for flow_case in (test_apply_flow_case_0, test_apply_flow_case_1,
                      test_apply_flow_case_2, test_apply_flow_case_3):
        flow_case()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()
248