# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

FLICKR30K_DATASET_DIR = "../data/dataset/testFlickrData/flickr30k/flickr30k-images"
FLICKR30K_ANNOTATION_FILE_1 = "../data/dataset/testFlickrData/flickr30k/test1.token"
FLICKR30K_ANNOTATION_FILE_2 = "../data/dataset/testFlickrData/flickr30k/test2.token"


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples.

    Saves one image file per sample, titled with that sample's captions.

    Args:
        images: iterable of image arrays, each passed to ``plt.imshow``.
        labels: iterable of caption sequences (parallel to ``images``); every
            caption is a ``bytes`` object that is decoded as UTF-8.

    Writes ``./flickr_<i>.jpg`` for i = 1, 2, ... into the current directory.
    """
    for i, (image, captions) in enumerate(zip(images, labels), start=1):
        # Use a fresh figure per sample and always close it: drawing every
        # sample into one shared figure stacks each new image on top of all
        # previous ones and keeps them alive for the whole process.
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            plt.title('\n'.join(s.decode('utf-8') for s in captions))
            plt.savefig('./flickr_' + str(i) + '.jpg')
        finally:
            plt.close(fig)


def test_flickr30k_dataset_train(plot=False):
    """
    Read test1.token with decoded images and check that it yields 2 rows.

    Args:
        plot: when True, also write the samples to disk via visualize_dataset.
    """
    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    images_list = []
    annotation_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(item["image"]))
        images_list.append(item['image'])
        annotation_list.append(item['annotation'])
    # One image is collected per row, so the list length is the row count.
    assert len(images_list) == 2
    if plot:
        visualize_dataset(images_list, annotation_list)


def test_flickr30k_dataset_annotation_check():
    """
    Check that test1.token is read in file order (shuffle=False) and that each
    row's captions decode to the expected strings.
    """
    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True, shuffle=False)
    count = 0
    expect_annotation_arr = [
        np.array([
            r'This is \*a banana.',
            'This is a yellow banana.',
            'This is a banana on the table.',
            'The banana is yellow.',
            'The banana is very big.',
        ]),
        np.array([
            'This is a pen.',
            'This is a red and black pen.',
            'This is a pen on the table.',
            'The color of the pen is red and black.',
            'The pen has two colors.',
        ])
    ]
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        # Fail with a clear message rather than an IndexError on extra rows.
        assert count < len(expect_annotation_arr), "dataset yielded more rows than expected"
        annotation = [s.decode("utf8") for s in item["annotation"]]
        np.testing.assert_array_equal(annotation, expect_annotation_arr[count])
        logger.info("item[image] is {}".format(item["image"]))
        count = count + 1
    assert count == 2


def _resized_flickr30k_dataset(annotation_file):
    """Build an unshuffled, decoded FlickrDataset with images resized to 100x100."""
    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, annotation_file, decode=True, shuffle=False)
    resize_op = c_vision.Resize((100, 100))
    return data.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)


def _count_rows(dataset):
    """Return how many rows one epoch of ``dataset`` produces."""
    return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))


def test_flickr30k_dataset_basic():
    """Exercise num_samples, repeat, and batch (with and without drop_remainder)."""
    # case 1: test num_samples
    data1 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, num_samples=2, decode=True)
    assert _count_rows(data1) == 2

    # case 2: test repeat
    data2 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    data2 = data2.repeat(5)
    assert _count_rows(data2) == 10

    # case 3: test batch with drop_remainder=False
    data3 = _resized_flickr30k_dataset(FLICKR30K_ANNOTATION_FILE_2)
    assert data3.get_dataset_size() == 3
    assert data3.get_batch_size() == 1
    data3 = data3.batch(batch_size=2)  # drop_remainder is default to be False
    assert data3.get_dataset_size() == 2
    assert data3.get_batch_size() == 2
    assert _count_rows(data3) == 2

    # case 4: test batch with drop_remainder=True
    data4 = _resized_flickr30k_dataset(FLICKR30K_ANNOTATION_FILE_2)
    assert data4.get_dataset_size() == 3
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=2, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data4.get_dataset_size() == 1
    assert data4.get_batch_size() == 2
    assert _count_rows(data4) == 1


def _assert_pyfunc_failure_reported(input_column):
    """
    Map a Python function that always raises over ``input_column`` and assert
    that iterating the pipeline surfaces it as a RuntimeError with the
    expected message.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    data = data.map(operations=exception_func, input_columns=[input_column], num_parallel_workers=1)
    # Keep the try body to the iteration only: construction errors should not
    # be mistaken for the PyFunc failure under test.
    try:
        for _ in data.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
    else:
        # Explicit raise (unlike `assert False`) is not stripped under `python -O`.
        raise AssertionError("expected RuntimeError from failing PyFunc on column " + input_column)


def test_flickr30k_dataset_exception():
    """A failing map function on either output column must raise RuntimeError."""
    _assert_pyfunc_failure_reported("image")
    _assert_pyfunc_failure_reported("annotation")


if __name__ == "__main__":
    test_flickr30k_dataset_train(False)
    test_flickr30k_dataset_annotation_check()
    test_flickr30k_dataset_basic()
    test_flickr30k_dataset_exception()