# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing PadEnd op in DE
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops


# Extensive testing of PadEnd is already done in batch with Pad test cases

def pad_compare(array, pad_shape, pad_value, res):
    """
    Apply PadEnd to a one-row dataset built from `array` and compare the result with `res`.

    Args:
        array: Contents of the single row; it is wrapped in a list before being handed
            to NumpySlicesDataset, so the dataset holds exactly one row.
        pad_shape: Value passed as PadEnd's pad_shape argument.
        pad_value: Value passed as PadEnd's pad_value argument, or None to construct
            PadEnd without a pad value.
        res: Expected content of the row after PadEnd has been applied.

    Raises:
        AssertionError: If the output row differs from `res`, or if the pipeline does
            not yield exactly one row.

    Any exception raised while constructing PadEnd or while iterating the pipeline is
    left to propagate; test_pad_end_exceptions relies on this.
    """
    data = ds.NumpySlicesDataset([array])
    if pad_value is not None:
        data = data.map(operations=ops.PadEnd(pad_shape, pad_value))
    else:
        data = data.map(operations=ops.PadEnd(pad_shape))

    num_rows = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(res, d[0])
        num_rows += 1
    # Without this check an empty pipeline would skip every comparison and pass vacuously.
    assert num_rows == 1, "expected exactly one output row, got {}".format(num_rows)


def test_pad_end_basics():
    """
    PadEnd on integer rows: pad a shorter row, keep an equal-length row, truncate a
    longer row, and pad with 0 when no pad_value is given.
    """
    pad_compare([1, 2], [3], -1, [1, 2, -1])
    pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
    pad_compare([1, 2, 3], [2], -1, [1, 2])
    pad_compare([1, 2, 3], [5], None, [1, 2, 3, 0, 0])


def test_pad_end_str():
    """
    PadEnd on bytes rows: same cases as the integer test; the default pad is b"".
    """
    pad_compare([b"1", b"2"], [3], b"-1", [b"1", b"2", b"-1"])
    pad_compare([b"1", b"2", b"3"], [3], b"-1", [b"1", b"2", b"3"])
    pad_compare([b"1", b"2", b"3"], [2], b"-1", [b"1", b"2"])
    pad_compare([b"1", b"2", b"3"], [5], None, [b"1", b"2", b"3", b"", b""])


def test_pad_end_exceptions():
    """
    PadEnd error paths:
    - a pad_value whose type differs from the data raises RuntimeError;
    - a non-integer entry in pad_shape raises TypeError;
    - a pad_shape that is not a list raises TypeError.
    """
    with pytest.raises(RuntimeError) as info:
        pad_compare([1, 2], [3], "-1", [])
    assert "pad_value and item of dataset are not of the same type" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
    assert "pad_value and item of dataset are not of the same type" in str(info.value)

    with pytest.raises(TypeError) as info:
        pad_compare([3, 4, 5], ["2"], 1, [])
    assert "a value in the list is not an integer." in str(info.value)

    with pytest.raises(TypeError) as info:
        pad_compare([1, 2], 3, -1, [1, 2, -1])
    assert "Argument pad_shape with value 3 is not of type [<class 'list'>]" in str(info.value)


if __name__ == "__main__":
    test_pad_end_basics()
    test_pad_end_str()
    test_pad_end_exceptions()