• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing concatenate op
"""
18
19import numpy as np
20import pytest
21
22import mindspore.dataset as ds
23import mindspore.dataset.transforms.c_transforms as data_trans
24
25
def test_concatenate_op_all():
    """Concatenate along axis 0 with both a prepend and an append tensor.

    The single 1-D float row produced by the generator must come back as
    prepend_tensor + row + append_tensor.
    """
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    # np.float was merely an alias of the builtin float (float64) and was
    # removed in NumPy 1.24, so the dtype is spelled explicitly.
    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the generator yields exactly one row.
    assert num_rows == 1
39
40
def test_concatenate_op_none():
    """Concatenate() constructed with no arguments leaves the row unchanged."""
    def gen():
        # np.float64 replaces the np.float alias removed in NumPy 1.24.
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate()

    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([5., 6., 7., 8.], dtype=np.float64)
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the generator yields exactly one row.
    assert num_rows == 1
51
52
def test_concatenate_op_string():
    """Concatenate along axis 0 works for bytes-string ('S' dtype) tensors."""
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S')
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the generator yields exactly one row.
    assert num_rows == 1
66
67
def test_concatenate_op_multi_input_string():
    """Two string input columns are concatenated into the single column 'out1'.

    The expected result is prepend + col1 + col2 + append.
    """
    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')

    # One row per column: col1 = ["1", "2", "d"], col2 = ["3", "4", "e"].
    data = ([["1", "2", "d"]], [["3", "4", "e"]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])

    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S')
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the sliced input holds exactly one row.
    assert num_rows == 1
82
83
def test_concatenate_op_multi_input_numeric():
    """Two numeric input columns are concatenated into 'out1' after a prepend."""
    prepend_tensor = np.array([3, 5])

    # One row per column: col1 = [1, 2], col2 = [3, 4].
    data = ([[1, 2]], [[3, 4]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])

    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array([3, 5, 1, 2, 3, 4])
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the sliced input holds exactly one row.
    assert num_rows == 1
97
98
def test_concatenate_op_type_mismatch():
    """Prepending a string tensor to a float row fails when the pipeline runs.

    The op is constructed and mapped without error; the RuntimeError is
    expected only once rows are pulled through the dataset.
    """
    def gen():
        # np.float64 replaces the np.float alias removed in NumPy 1.24.
        yield (np.array([3, 4], dtype=np.float64),)

    prepend_tensor = np.array(["ss", "ad"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "input datatype does not match" in str(error_info.value)
112
113
def test_concatenate_op_type_mismatch2():
    """Prepending a float tensor to a string row fails when the pipeline runs.

    This is the mirror image of test_concatenate_op_type_mismatch: here the row
    is bytes strings and the prepend tensor is float.
    """
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    # np.float64 replaces the np.float alias removed in NumPy 1.24.
    prepend_tensor = np.array([3, 5], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)

    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "input datatype does not match" in str(error_info.value)
127
128
def test_concatenate_op_incorrect_dim():
    """A 2-D input row is rejected at run time with "only 1D input supported"."""
    def two_dim_rows():
        matrix = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
        yield (matrix,)

    prefix = np.array(["ss", "ss"], dtype='S')
    op = data_trans.Concatenate(0, prefix)
    source = ds.GeneratorDataset(two_dim_rows, column_names=["col"])

    mapped = source.map(operations=op, input_columns=["col"])
    # The failure surfaces only while rows are being pulled from the pipeline.
    with pytest.raises(RuntimeError) as exc_info:
        for _ in mapped:
            pass
    message = str(exc_info.value)
    assert "only 1D input supported" in message
142
143
def test_concatenate_op_wrong_axis():
    """Constructing Concatenate with axis=2 raises ValueError immediately."""
    with pytest.raises(ValueError) as exc_info:
        data_trans.Concatenate(2)
    message = str(exc_info.value)
    assert "only 1D concatenation supported." in message
148
149
def test_concatenate_op_negative_axis():
    """axis=-1 on a 1-D row gives the same result as axis=0.

    Inputs and the expected output are identical to test_concatenate_op_all;
    only the axis argument differs.
    """
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)

    # np.float64 replaces the np.float alias removed in NumPy 1.24.
    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    num_rows = 0
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)
        num_rows += 1
    # Guard against a vacuous pass: the generator yields exactly one row.
    assert num_rows == 1
163
164
def test_concatenate_op_incorrect_input_dim():
    """A 2-D prepend tensor is rejected by the Concatenate constructor."""
    two_dim_prefix = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')

    with pytest.raises(ValueError) as exc_info:
        data_trans.Concatenate(0, two_dim_prefix)
    message = str(exc_info.value)
    assert "can only prepend 1D arrays." in message
171
172
if __name__ == "__main__":
    # Run every test in this module in its original order when executed directly.
    for test_fn in (
            test_concatenate_op_all,
            test_concatenate_op_none,
            test_concatenate_op_string,
            test_concatenate_op_multi_input_string,
            test_concatenate_op_multi_input_numeric,
            test_concatenate_op_type_mismatch,
            test_concatenate_op_type_mismatch2,
            test_concatenate_op_incorrect_dim,
            test_concatenate_op_negative_axis,
            test_concatenate_op_wrong_axis,
            test_concatenate_op_incorrect_input_dim,
    ):
        test_fn()
185