# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

FLICKR30K_DATASET_DIR = "../data/dataset/testFlickrData/flickr30k/flickr30k-images"
FLICKR30K_ANNOTATION_FILE_1 = "../data/dataset/testFlickrData/flickr30k/test1.token"
FLICKR30K_ANNOTATION_FILE_2 = "../data/dataset/testFlickrData/flickr30k/test2.token"


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples

    Each (image, annotations) pair is drawn with the annotations joined by
    newlines as the title and saved to ``./flickr_<i>.jpg``, where ``i``
    starts at 1.

    Args:
        images: iterable of images accepted by ``plt.imshow``.
        labels: iterable, parallel to ``images``, of sequences of byte
            strings; each byte string is decoded as UTF-8.
    """
    for i, (image, annotations) in enumerate(zip(images, labels), start=1):
        # Use a fresh figure per sample: drawing every sample onto one shared
        # figure leaves the earlier images underneath the later ones.
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            plt.title('\n'.join(s.decode('utf-8') for s in annotations))
            plt.savefig('./flickr_' + str(i) + '.jpg')
        finally:
            # Always release the figure, even if drawing or saving fails,
            # so pyplot does not keep accumulating open figures.
            plt.close(fig)


def test_flickr30k_dataset_train(plot=False):
    """
    Iterate FlickrDataset built from FLICKR30K_ANNOTATION_FILE_1 and check
    that it yields exactly 2 rows.

    Args:
        plot: when True, the collected samples are saved as images through
            ``visualize_dataset``.
    """
    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    count = 0
    images_list = []
    annotation_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(item["image"]))
        images_list.append(item['image'])
        annotation_list.append(item['annotation'])
        count = count + 1
    assert count == 2
    if plot:
        visualize_dataset(images_list, annotation_list)


def test_flickr30k_dataset_annotation_check():
    """
    Read FLICKR30K_ANNOTATION_FILE_1 without shuffling and compare the decoded
    annotations of each row against the expected captions, in order.
    """
    data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True, shuffle=False)
    count = 0
    expect_annotation_arr = [
        np.array([
            r'This is \*a banana.',
            'This is a yellow banana.',
            'This is a banana on the table.',
            'The banana is yellow.',
            'The banana is very big.',
        ]),
        np.array([
            'This is a pen.',
            'This is a red and black pen.',
            'This is a pen on the table.',
            'The color of the pen is red and black.',
            'The pen has two colors.',
        ])
    ]
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        annotation = [s.decode("utf8") for s in item["annotation"]]
        np.testing.assert_array_equal(annotation, expect_annotation_arr[count])
        logger.info("item[image] is {}".format(item["image"]))
        count = count + 1
    assert count == 2


def test_flickr30k_dataset_basic():
    """
    Check row and batch counts of FlickrDataset combined with num_samples,
    repeat, and batch (with and without drop_remainder).
    """
    # case 1: test num_samples
    data1 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, num_samples=2, decode=True)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter1 += 1
    assert num_iter1 == 2

    # case 2: test repeat
    data2 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    data2 = data2.repeat(5)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter2 += 1
    assert num_iter2 == 10

    # case 3: test batch with drop_remainder=False
    data3 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
    resize_op = c_vision.Resize((100, 100))
    data3 = data3.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
    assert data3.get_dataset_size() == 3
    assert data3.get_batch_size() == 1
    data3 = data3.batch(batch_size=2)  # drop_remainder is default to be False
    assert data3.get_dataset_size() == 2
    assert data3.get_batch_size() == 2
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter3 += 1
    assert num_iter3 == 2

    # case 4: test batch with drop_remainder=True
    data4 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
    resize_op = c_vision.Resize((100, 100))
    data4 = data4.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
    assert data4.get_dataset_size() == 3
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=2, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data4.get_dataset_size() == 1
    assert data4.get_batch_size() == 2
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter4 += 1
    assert num_iter4 == 1


def test_flickr30k_dataset_exception():
    """
    Check that an exception raised by a Python function inside map surfaces
    as a RuntimeError with the expected message, for both the "image" and the
    "annotation" columns.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    # The same scenario is exercised once per column, in the original order.
    for column in ("image", "annotation"):
        try:
            data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in data.create_dict_iterator():
                pass
            # Reaching this point means the pipeline did not raise; the
            # resulting AssertionError is not caught below and fails the test.
            assert False
        except RuntimeError as e:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


if __name__ == "__main__":
    test_flickr30k_dataset_train(False)
    test_flickr30k_dataset_annotation_check()
    test_flickr30k_dataset_basic()
    test_flickr30k_dataset_exception()