• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tests for the PadEnd op in DE.
"""
18import numpy as np
19import pytest
20
21import mindspore.dataset as ds
22import mindspore.dataset.transforms.c_transforms as ops
23
24
# PadEnd is already tested extensively by the batch-with-pad test cases,
# so only the basic behavior and argument errors are covered here.
26
def pad_compare(array, pad_shape, pad_value, res):
    """Run PadEnd over a one-row dataset and check the output.

    Args:
        array: Input data. It is wrapped in a list so the dataset yields the
            whole array as a single row.
        pad_shape: Target shape forwarded to ``ops.PadEnd``.
        pad_value: Fill value forwarded to ``ops.PadEnd``. When None, the op
            is built without it so PadEnd's own default fill is used.
        res: Expected padded (or truncated) output.

    Raises:
        AssertionError: If the output differs from ``res`` or the pipeline
            does not yield exactly one row.
    """
    data = ds.NumpySlicesDataset([array])
    # Only pass pad_value when given, so PadEnd falls back to its default.
    if pad_value is not None:
        pad_op = ops.PadEnd(pad_shape, pad_value)
    else:
        pad_op = ops.PadEnd(pad_shape)
    data = data.map(operations=pad_op)

    num_rows = 0
    for row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(res, row[0])
        num_rows += 1
    # Without this check an empty pipeline would skip the comparison and the
    # calling test would pass vacuously.
    assert num_rows == 1, f"expected exactly one row, got {num_rows}"
35
36
def test_pad_end_basics():
    """PadEnd on 1-D integer data: pad, unchanged, truncate, default fill."""
    cases = (
        # (input, pad_shape, pad_value, expected)
        ([1, 2], [3], -1, [1, 2, -1]),            # pad with an explicit value
        ([1, 2, 3], [3], -1, [1, 2, 3]),          # shape already matches
        ([1, 2, 3], [2], -1, [1, 2]),             # shorter target truncates
        ([1, 2, 3], [5], None, [1, 2, 3, 0, 0]),  # default fill value
    )
    for source, shape, fill, expected in cases:
        pad_compare(source, shape, fill, expected)
42
43
def test_pad_end_str():
    """PadEnd on 1-D bytes data: pad, unchanged, truncate, default fill."""
    cases = (
        # (input, pad_shape, pad_value, expected)
        ([b"1", b"2"], [3], b"-1", [b"1", b"2", b"-1"]),
        ([b"1", b"2", b"3"], [3], b"-1", [b"1", b"2", b"3"]),
        ([b"1", b"2", b"3"], [2], b"-1", [b"1", b"2"]),
        # With no pad_value, bytes data is filled with empty bytes.
        ([b"1", b"2", b"3"], [5], None, [b"1", b"2", b"3", b"", b""]),
    )
    for source, shape, fill, expected in cases:
        pad_compare(source, shape, fill, expected)
49
50
def test_pad_end_exceptions():
    """Invalid PadEnd arguments raise the expected error type and message."""
    type_mismatch = "pad_value and item of dataset are not of the same type"
    expectations = (
        # int data with a str pad_value
        (RuntimeError, ([1, 2], [3], "-1", []), type_mismatch),
        # bytes data with an int pad_value
        (RuntimeError, ([b"1", b"2", b"3", b"4", b"5"], [2], 1, []), type_mismatch),
        # non-integer entry in pad_shape
        (TypeError, ([3, 4, 5], ["2"], 1, []), "a value in the list is not an integer."),
        # pad_shape that is not a list
        (TypeError, ([1, 2], 3, -1, [1, 2, -1]),
         "Argument pad_shape with value 3 is not of type [<class 'list'>]"),
    )
    for error_type, call_args, message in expectations:
        with pytest.raises(error_type) as info:
            pad_compare(*call_args)
        assert message in str(info.value)
67
68
if __name__ == "__main__":
    # Run every test in this module, in order, when executed as a script.
    for run_test in (test_pad_end_basics, test_pad_end_str, test_pad_end_exceptions):
        run_test()
73