• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import mindspore.common.dtype as mstype
16import mindspore.dataset as ds
17from mindspore import log as logger
18
19# just a basic test with parallel random data op
def test_randomdataset_basic1():
    """Iterate a RandomDataset built from an explicit schema using 4 parallel workers.

    The pipeline generates 50 rows and repeats them 4 times, so the dict
    iterator is expected to yield 200 rows in total.
    """
    logger.info("Test randomdataset basic 1")

    # Both columns are uint8; only the column name and shape differ.
    schema = ds.Schema()
    for column_name, column_shape in (('image', [2]), ('label', [1])):
        schema.add_column(column_name, de_type=mstype.uint8, shape=column_shape)

    # Build the pipeline: random data source followed by a repeat.
    pipeline = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4).repeat(4)

    # Each row is a dictionary with the keys "image" and "label".
    row_count = 0
    for row_count, row in enumerate(pipeline.create_dict_iterator(num_epochs=1), start=1):
        row_index = row_count - 1  # zero-based index used in the log lines
        logger.info("{} image: {}".format(row_index, row["image"]))
        logger.info("{} label: {}".format(row_index, row["label"]))

    logger.info("Number of data in ds1: {}".format(row_count))
    assert row_count == 200
    logger.info("Test randomdataset basic 1 complete")
41
42
# Single-worker test with a large image column (about 1 MB per row)
def test_randomdataset_basic2():
    """Iterate a single-worker RandomDataset whose image column is 640x480x3 uint8.

    The pipeline generates 10 rows and repeats them 4 times, so the dict
    iterator is expected to yield 40 rows in total.
    """
    logger.info("Test randomdataset basic 2")

    # 640 * 480 * 3 = 921600 bytes, i.e. a bit less than 1 MB per image.
    image_shape = [640, 480, 3]

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8, shape=image_shape)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # Make up 10 rows and repeat them 4 times.
    pipeline = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1).repeat(4)

    # Each row is a dictionary with the keys "image" and "label"; only the label is logged.
    row_count = 0
    for row_count, row in enumerate(pipeline.create_dict_iterator(num_epochs=1), start=1):
        logger.info("printing the label: {}".format(row["label"]))

    logger.info("Number of data in ds1: {}".format(row_count))
    assert row_count == 40
    logger.info("Test randomdataset basic 2 complete")
66
67
# Test with no schema supplied, iterated with a tuple iterator
def test_randomdataset_basic3():
    """Iterate a RandomDataset created without an explicit schema.

    The pipeline generates 10 rows and repeats them twice, so the tuple
    iterator is expected to yield 20 rows in total.
    """
    logger.info("Test randomdataset basic 3")

    # No schema is passed, so the schema itself is randomly created and the
    # columns are named "c0", "c1", "c2" etc. A tuple iterator is used so the
    # column names are not needed to walk the rows.
    pipeline = ds.RandomDataset(total_rows=10, num_parallel_workers=1).repeat(2)

    # Only the number of rows matters here; the row contents are discarded.
    row_count = sum(1 for _ in pipeline.create_tuple_iterator(num_epochs=1))

    logger.info("Number of data in ds1: {}".format(row_count))
    assert row_count == 20
    logger.info("Test randomdataset basic 3 Complete")
86
if __name__ == '__main__':
    # Run every test in definition order when the file is executed as a script.
    for run_test in (test_randomdataset_basic1,
                     test_randomdataset_basic2,
                     test_randomdataset_basic3):
        run_test()
91