# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing ToPIL op in DE
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import save_and_check_md5

GENERATE_GOLDEN = False

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# Side length passed to CenterCrop in every pipeline below; the golden md5
# files were generated with this value, so changing it invalidates them.
CROP_SIZE = 375


def _load_images():
    """Build the TFRecord test dataset restricted to the "image" column.

    shuffle=False keeps row order fixed so the md5 comparison is reproducible.
    """
    return ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)


def _to_pil_pipeline(*leading_ops):
    """Compose ``leading_ops`` followed by ToPIL -> CenterCrop -> ToTensor.

    Args:
        *leading_ops: py_transforms ops to run before ToPIL (may be empty).

    Returns:
        A ``py_transforms.Compose`` wrapping the full op list.
    """
    transforms = list(leading_ops) + [
        py_vision.ToPIL(),
        py_vision.CenterCrop(CROP_SIZE),
        py_vision.ToTensor(),
    ]
    return mindspore.dataset.transforms.py_transforms.Compose(transforms)


def test_to_pil_01():
    """
    Test ToPIL Op with md5 comparison: input is already PIL image
    Expected to pass
    """
    logger.info("test_to_pil_01")

    # Generate dataset. py_vision.Decode runs first inside the same Compose,
    # so ToPIL receives input that is already a PIL image.
    data1 = _load_images()
    transform = _to_pil_pipeline(py_vision.Decode())
    data1 = data1.map(operations=transform, input_columns=["image"])

    # Compare with expected md5 from images
    filename = "to_pil_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_to_pil_02():
    """
    Test ToPIL Op with md5 comparison: input is not PIL image
    Expected to pass
    """
    logger.info("test_to_pil_02")

    # Generate dataset. Decoding happens in a separate map with the C++
    # Decode op, so ToPIL receives non-PIL input (presumably a NumPy array —
    # confirm against the c_transforms.Decode docs).
    data1 = _load_images()
    data1 = data1.map(operations=c_vision.Decode(), input_columns=["image"])
    data1 = data1.map(operations=_to_pil_pipeline(), input_columns=["image"])

    # Compare with expected md5 from images
    filename = "to_pil_02_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


if __name__ == "__main__":
    test_to_pil_01()
    test_to_pil_02()