# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``Dataset.flat_map`` over text files that list image folders."""
import numpy as np

import mindspore.dataset as ds

DATA_FILE = "../data/dataset/test_flat_map/images1.txt"
INDEX_FILE = "../data/dataset/test_flat_map/image_index.txt"


def _first_field_as_str(x):
    """Return ``x[0]`` as a Python ``str``.

    The original tests always decoded this value from UTF-8 bytes. A value
    that is already a ``str`` is passed through unchanged, so the helper
    keeps working if the pipeline hands over decoded text.
    """
    value = x[0].item()
    if isinstance(value, bytes):
        value = value.decode('utf8')
    return value


def _image_folder_from_row(x):
    """flat_map callback: load an ImageFolderDataset from the directory in ``x[0]``."""
    return ds.ImageFolderDataset(_first_field_as_str(x))


def _count_numpy_rows(data):
    """Return how many rows ``data`` yields.

    The rows are iterated as numpy tuples. For every row, the helper also
    asserts that the first column is an ``np.ndarray``.
    """
    count = 0
    for row in data.create_tuple_iterator(output_numpy=True):
        assert isinstance(row[0], np.ndarray)
        count += 1
    return count


def test_flat_map_1():
    """DATA_FILE records the paths of image folders; load the images from them."""
    data = ds.TextFileDataset(DATA_FILE)
    data = data.flat_map(_image_folder_from_row)

    assert _count_numpy_rows(data) == 52


def test_flat_map_2():
    """Flatten 3D-structured data.

    INDEX_FILE lists text files. Each text file lists image folders, so two
    nested flat_map stages produce the images.
    """

    def text_file_to_images(x):
        # Each row names a text file whose own rows are image-folder paths.
        d = ds.TextFileDataset(_first_field_as_str(x))
        return d.flat_map(_image_folder_from_row)

    data = ds.TextFileDataset(INDEX_FILE)
    data = data.flat_map(text_file_to_images)

    assert _count_numpy_rows(data) == 104


if __name__ == "__main__":
    test_flat_map_1()
    test_flat_map_2()