• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing ToPIL op in DE
17"""
18import mindspore.dataset as ds
19import mindspore.dataset.transforms.py_transforms
20import mindspore.dataset.vision.c_transforms as c_vision
21import mindspore.dataset.vision.py_transforms as py_vision
22from mindspore import log as logger
23from util import save_and_check_md5
24
# Forwarded to save_and_check_md5 as `generate_golden`; presumably True
# rewrites the golden .npz files instead of comparing against them.
GENERATE_GOLDEN = False

# TFRecord data file(s) and the matching schema used by every test below.
DATA_DIR = [
    "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data",
]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
29
30
def test_to_pil_01():
    """
    Test ToPIL Op with md5 comparison when the input is already a PIL image.

    The "image" column is run through one Python Compose pipeline:
    py_vision Decode -> ToPIL -> CenterCrop(375) -> ToTensor. The result is
    checked against the golden file "to_pil_01_result.npz".
    Expected to pass
    """
    logger.info("test_to_pil_01")

    # Read the raw images in a fixed order so the md5 comparison is stable.
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    py_transforms = mindspore.dataset.transforms.py_transforms
    # py_vision.Decode already yields a PIL image, so ToPIL sees PIL input here.
    pipeline = py_transforms.Compose([
        py_vision.Decode(),
        py_vision.ToPIL(),
        py_vision.CenterCrop(375),
        py_vision.ToTensor(),
    ])
    dataset = dataset.map(operations=pipeline, input_columns=["image"])

    # Compare the pipeline output against the stored md5 golden data.
    save_and_check_md5(dataset, "to_pil_01_result.npz", generate_golden=GENERATE_GOLDEN)
53
def test_to_pil_02():
    """
    Test ToPIL Op with md5 comparison when the input is not a PIL image.

    Images are first decoded by the c_transforms Decode op in their own map.
    A second map then applies the Python pipeline
    ToPIL -> CenterCrop(375) -> ToTensor. The result is checked against the
    golden file "to_pil_02_result.npz".
    Expected to pass
    """
    logger.info("test_to_pil_02")

    # Read the raw images in a fixed order so the md5 comparison is stable.
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    c_decode = c_vision.Decode()
    py_transforms = mindspore.dataset.transforms.py_transforms
    # The c_transforms Decode output is not a PIL image, so ToPIL has to convert it.
    pipeline = py_transforms.Compose([
        py_vision.ToPIL(),
        py_vision.CenterCrop(375),
        py_vision.ToTensor(),
    ])

    # Two separate map stages: C decode first, then the Python pipeline.
    for stage in (c_decode, pipeline):
        dataset = dataset.map(operations=stage, input_columns=["image"])

    # Compare the pipeline output against the stored md5 golden data.
    save_and_check_md5(dataset, "to_pil_02_result.npz", generate_golden=GENERATE_GOLDEN)
77
if __name__ == "__main__":
    # Run the tests directly (outside pytest) in their declared order.
    for run_test in (test_to_pil_01, test_to_pil_02):
        run_test()
81