• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing fill op
17"""
18import numpy as np
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.transforms.c_transforms as data_trans
22
23
def test_fillop_basic():
    """Fill(3) should replace every element of a uint8 column with 3.

    The generator yields exactly one row, and the test asserts that exactly one
    row comes out of the pipeline.
    """
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([3, 3, 3, 3], dtype=np.uint8)
    # Count rows so the test cannot pass vacuously if the pipeline yields nothing.
    num_rows = 0
    for data_row in data:
        np.testing.assert_array_equal(data_row[0].asnumpy(), expected)
        num_rows += 1
    assert num_rows == 1
35
36
def test_fillop_down_type_cast():
    """Fill(-3) on a uint8 column should cast the fill value to uint8.

    The expected output is 253, i.e. -3 wrapped modulo 256, so the test pins the
    behavior that the fill value is converted to the column's dtype.
    """
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(-3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([253, 253, 253, 253], dtype=np.uint8)
    # Count rows so the test cannot pass vacuously if the pipeline yields nothing.
    num_rows = 0
    for data_row in data:
        np.testing.assert_array_equal(data_row[0].asnumpy(), expected)
        num_rows += 1
    assert num_rows == 1
48
49
def test_fillop_up_type_cast():
    """Fill with an integer value on a float64 column should produce 3.0 everywhere.

    Uses ``np.float64`` rather than the ``np.float`` alias: that alias was
    deprecated in NumPy 1.20 and removed in NumPy 1.24, and as a dtype it meant
    float64.
    """
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.float64),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([3., 3., 3., 3.], dtype=np.float64)
    # Count rows so the test cannot pass vacuously if the pipeline yields nothing.
    num_rows = 0
    for data_row in data:
        np.testing.assert_array_equal(data_row[0].asnumpy(), expected)
        num_rows += 1
    assert num_rows == 1
61
62
def test_fillop_string():
    """Fill with a Python ``str`` on a bytes ('S') column should yield b'error' per element.

    Rows are read back as NumPy arrays via ``create_tuple_iterator(output_numpy=True)``.
    """
    def gen():
        yield (np.array(["45555", "45555"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill("error")

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array(['error', 'error'], dtype='S')
    # Count rows so the test cannot pass vacuously if the pipeline yields nothing.
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    assert num_rows == 1
74
75
def test_fillop_bytes():
    """Fill with a ``bytes`` value on a bytes ('S') column should yield b'abc' per element.

    Rows are read back as NumPy arrays via ``create_tuple_iterator(output_numpy=True)``.
    """
    def gen():
        yield (np.array(["A", "B", "C"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(b'abc')

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([b'abc', b'abc', b'abc'], dtype='S')
    # Count rows so the test cannot pass vacuously if the pipeline yields nothing.
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    assert num_rows == 1
87
88
def test_fillop_error_handling():
    """A string fill value applied to a numeric column must fail at iteration time.

    ``map`` itself is expected to succeed; the RuntimeError is expected while the
    pipeline is being consumed, and its message must explain the dtype mismatch.
    """
    def numeric_rows():
        yield (np.array([4, 4, 4, 4]),)

    dataset = ds.GeneratorDataset(numeric_rows, column_names=["col"])
    dataset = dataset.map(operations=data_trans.Fill("words"), input_columns=["col"])

    with pytest.raises(RuntimeError) as exc_info:
        # Drain the pipeline; the op error surfaces here, not at map() time.
        for _ in dataset:
            pass

    message = str(exc_info.value)
    assert "fill datatype is string but the input datatype is not string" in message
101
102
if __name__ == "__main__":
    # Run every test in the same order as before when executed as a script.
    for test_fn in (
            test_fillop_basic,
            test_fillop_up_type_cast,
            test_fillop_down_type_cast,
            test_fillop_string,
            test_fillop_bytes,
            test_fillop_error_handling,
    ):
        test_fn()
110