# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing fill op (mindspore.dataset.transforms.c_transforms.Fill).
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans


def _assert_single_row_equal(rows, expected):
    """Assert that *rows* yields exactly one row whose first column equals *expected*.

    Every generator in this module yields a single row. Counting the rows
    guards against the comparison loop passing vacuously when the pipeline
    produces no rows at all.

    Args:
        rows: iterable of rows whose first column (``row[0]``) is a numpy
            array, e.g. ``dataset.create_tuple_iterator(output_numpy=True)``.
        expected: numpy array the first column must equal element-wise.
    """
    num_rows = 0
    for row in rows:
        np.testing.assert_array_equal(row[0], expected)
        num_rows += 1
    assert num_rows == 1, f"expected exactly 1 row, got {num_rows}"


def test_fillop_basic():
    """Fill(3) on a uint8 column replaces every element with 3."""
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([3, 3, 3, 3], dtype=np.uint8)
    _assert_single_row_equal(data.create_tuple_iterator(output_numpy=True), expected)


def test_fillop_down_type_cast():
    """A negative fill value on a uint8 column is expected to wrap: -3 -> 253."""
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(-3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([253, 253, 253, 253], dtype=np.uint8)
    _assert_single_row_equal(data.create_tuple_iterator(output_numpy=True), expected)


def test_fillop_up_type_cast():
    """An integer fill value on a float64 column yields float values (3 -> 3.0)."""
    # np.float64 rather than np.float: the np.float alias was deprecated in
    # NumPy 1.20 and removed in 1.24 (it raised AttributeError there).
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.float64),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([3., 3., 3., 3.], dtype=np.float64)
    _assert_single_row_equal(data.create_tuple_iterator(output_numpy=True), expected)


def test_fillop_string():
    """A str fill value on a bytes-string column fills every element with it."""
    def gen():
        yield (np.array(["45555", "45555"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill("error")

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array(['error', 'error'], dtype='S')
    _assert_single_row_equal(data.create_tuple_iterator(output_numpy=True), expected)


def test_fillop_bytes():
    """A bytes fill value on a bytes-string column fills every element with it."""
    def gen():
        yield (np.array(["A", "B", "C"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(b'abc')

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([b'abc', b'abc', b'abc'], dtype='S')
    _assert_single_row_equal(data.create_tuple_iterator(output_numpy=True), expected)


def test_fillop_error_handling():
    """Filling a numeric column with a string raises RuntimeError during iteration."""
    def gen():
        yield (np.array([4, 4, 4, 4]),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill("words")
    data = data.map(operations=fill_op, input_columns=["col"])

    # The op is applied lazily, so the error only surfaces when rows are pulled.
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "fill datatype is string but the input datatype is not string" in str(error_info.value)


if __name__ == "__main__":
    test_fillop_basic()
    test_fillop_up_type_cast()
    test_fillop_down_type_cast()
    test_fillop_string()
    test_fillop_bytes()
    test_fillop_error_handling()