# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing concatenate op
"""

import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans

# NOTE: ``np.float`` was merely an alias of the builtin ``float`` (float64). It was
# deprecated in NumPy 1.20 and removed in NumPy 1.24, so ``np.float64`` is spelled
# out explicitly below; the resulting dtype is identical on every NumPy version.


def _assert_single_output(data, expected):
    """Assert that *data* yields exactly one row whose first column equals *expected*.

    Checking the row count guards against an empty pipeline, which would otherwise
    let a bare ``for`` loop over the iterator pass without comparing anything.
    """
    outputs = [row[0] for row in data.create_tuple_iterator(output_numpy=True)]
    assert len(outputs) == 1, "expected exactly one output row, got {}".format(len(outputs))
    np.testing.assert_array_equal(outputs[0], expected)


def test_concatenate_op_all():
    """Concatenate along axis 0 with both a prepend and an append tensor."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    _assert_single_output(data, expected)


def test_concatenate_op_none():
    """Concatenate with no prepend/append tensors leaves the input unchanged."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate()

    data = data.map(operations=concatenate_op, input_columns=["col"])
    _assert_single_output(data, np.array([5., 6., 7., 8.], dtype=np.float64))


def test_concatenate_op_string():
    """Concatenate byte-string tensors with both a prepend and an append tensor."""
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S')
    _assert_single_output(data, expected)


def test_concatenate_op_multi_input_string():
    """Concatenate two string input columns (plus prepend/append) into one output column."""
    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')

    data = ([["1", "2", "d"]], [["3", "4", "e"]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])

    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S')
    _assert_single_output(data, expected)


def test_concatenate_op_multi_input_numeric():
    """Concatenate two numeric input columns (plus a prepend tensor) into one output column."""
    prepend_tensor = np.array([3, 5])

    data = ([[1, 2]], [[3, 4]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])

    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array([3, 5, 1, 2, 3, 4])
    _assert_single_output(data, expected)


def test_concatenate_op_type_mismatch():
    """A string prepend tensor on a float input raises RuntimeError during iteration."""
    def gen():
        yield (np.array([3, 4], dtype=np.float64),)

    prepend_tensor = np.array(["ss", "ad"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "input datatype does not match" in str(error_info.value)


def test_concatenate_op_type_mismatch2():
    """A float prepend tensor on a string input raises RuntimeError during iteration."""
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    prepend_tensor = np.array([3, 5], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "input datatype does not match" in str(error_info.value)


def test_concatenate_op_incorrect_dim():
    """A 2-D input column raises RuntimeError during iteration."""
    def gen():
        yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),)

    prepend_tensor = np.array(["ss", "ss"], dtype='S')
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
    data = ds.GeneratorDataset(gen, column_names=["col"])

    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "only 1D input supported" in str(error_info.value)


def test_concatenate_op_wrong_axis():
    """Constructing Concatenate with axis=2 raises ValueError."""
    with pytest.raises(ValueError) as error_info:
        data_trans.Concatenate(2)
    assert "only 1D concatenation supported." in str(error_info.value)


def test_concatenate_op_negative_axis():
    """axis=-1 produces the same result as axis=0 for 1-D inputs."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    _assert_single_output(data, expected)


def test_concatenate_op_incorrect_input_dim():
    """Constructing Concatenate with a 2-D prepend tensor raises ValueError."""
    prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')

    with pytest.raises(ValueError) as error_info:
        data_trans.Concatenate(0, prepend_tensor)
    assert "can only prepend 1D arrays." in str(error_info.value)


if __name__ == "__main__":
    test_concatenate_op_all()
    test_concatenate_op_none()
    test_concatenate_op_string()
    test_concatenate_op_multi_input_string()
    test_concatenate_op_multi_input_numeric()
    test_concatenate_op_type_mismatch()
    test_concatenate_op_type_mismatch2()
    test_concatenate_op_incorrect_dim()
    test_concatenate_op_negative_axis()
    test_concatenate_op_wrong_axis()
    test_concatenate_op_incorrect_input_dim()