• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import mindspore.dataset as ds
16from mindspore import log as logger
17
18
def test_num_samples():
    """Check the row count of a sharded ManifestDataset limited by num_samples.

    The manifest is opened with ``num_samples=1`` as shard 1 of 3, and a
    single-epoch dict iterator over it is expected to produce exactly one row.
    """
    sample_limit = 1
    dataset = ds.ManifestDataset(
        "../data/dataset/testManifestData/test5trainimgs.json",
        num_samples=sample_limit,
        num_shards=3,
        shard_id=1,
    )
    rows = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    # Drain the iterator; only the number of rows matters, not their content.
    rows_seen = sum(1 for _ in rows)
    assert rows_seen == 1
30
31
def test_num_samples_tf():
    """Check the row count of a TFRecordDataset limited by num_samples.

    The dataset is opened with ``num_samples=2`` and a single-epoch tuple
    iterator over it is expected to produce exactly two rows.
    """
    logger.info("test_tfrecord_read_all_dataset")
    record_files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_path = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
    # Original author's note: num_samples is the number of rows per shard,
    # and the data file holds 12 rows in total.
    dataset = ds.TFRecordDataset(record_files, schema_path, num_samples=2)
    rows_seen = sum(1 for _ in dataset.create_tuple_iterator(num_epochs=1))
    assert rows_seen == 2
42
43
def test_num_samples_image_folder():
    """Check the row count of a sharded ImageFolderDataset limited by num_samples.

    The folder is opened with ``num_samples=2`` as shard 0 of 2, and a
    single-epoch tuple iterator over it is expected to produce exactly two rows.
    """
    dataset = ds.ImageFolderDataset(
        "../data/dataset/testPK/data",
        num_samples=2,
        num_shards=2,
        shard_id=0,
    )
    rows_seen = sum(1 for _ in dataset.create_tuple_iterator(num_epochs=1))
    assert rows_seen == 2
51
52
if __name__ == "__main__":
    # Run every test in this module, in definition order, when executed as a script.
    for _test_case in (
        test_num_samples,
        test_num_samples_tf,
        test_num_samples_image_folder,
    ):
        _test_case()
57