# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing SlidingWindow in mindspore.dataset
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.text as text


def test_sliding_window_callable():
    """
    Test sliding window op is callable
    """
    op = text.SlidingWindow(2, 0)

    input1 = ["大", "家", "早", "上", "好"]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])
    result = op(input1)
    assert np.array_equal(result, expect)

    # test 2D input
    input2 = [["大", "家", "早", "上", "好"]]
    with pytest.raises(RuntimeError) as info:
        _ = op(input2)
    assert "SlidingWindow: SlidingWindow supports 1D input only for now." in str(info.value)

    # test input multiple tensors
    with pytest.raises(RuntimeError) as info:
        _ = op(input1, input1)
    assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)


def test_sliding_window_string():
    """
    Test SlidingWindow(width=2, axis=0) on a 1D string column inside a pipeline
    """
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])

    # Collect every window from every row into one flat list. String items come back
    # as bytes, so each one is decoded before comparison.
    result = []
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        result.extend([item.decode('utf8') for item in window] for window in data['text'])
    np.testing.assert_array_equal(np.array(result), expect)


def test_sliding_window_number():
    """
    Test SlidingWindow(width=1, axis=-1) on a single-element numeric column
    """
    inputs = [1]
    expect = np.array([[1]])

    def gen(nums):
        yield (np.array(nums),)

    dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"])
    dataset = dataset.map(operations=text.SlidingWindow(1, -1), input_columns=["number"])

    num_rows = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(data['number'], expect)
        num_rows += 1
    # Without this check the test would pass vacuously if the pipeline yielded no rows.
    assert num_rows == 1


def test_sliding_window_big_width():
    """
    Test SlidingWindow with a width larger than the input length yields an empty result
    """
    inputs = [[1, 2, 3, 4, 5]]
    expect = np.array([])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False)
    dataset = dataset.map(operations=text.SlidingWindow(30, 0), input_columns=["number"])

    num_rows = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(data['number'], expect)
        num_rows += 1
    # Without this check the test would pass vacuously if the pipeline yielded no rows.
    assert num_rows == 1


def test_sliding_window_exception():
    """
    Test SlidingWindow rejects invalid arguments and non-1D input
    """
    # Invalid constructor arguments are rejected eagerly.
    with pytest.raises(ValueError):
        _ = text.SlidingWindow(0, 0)

    with pytest.raises(TypeError):
        _ = text.SlidingWindow("1", 0)

    with pytest.raises(TypeError):
        _ = text.SlidingWindow(1, "0")

    # An unsupported axis is reported as a RuntimeError by the pipeline. Construction
    # is kept inside the context so that an error raised at any stage is still caught,
    # matching the original try/except coverage.
    with pytest.raises(RuntimeError) as info:
        inputs = [[1, 2, 3, 4, 5]]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(operations=text.SlidingWindow(3, -100), input_columns=["text"])
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "axis supports 0 or -1 only for now." in str(info.value)

    # Slicing a 1D list of strings yields scalar (0D) rows, which SlidingWindow rejects.
    with pytest.raises(RuntimeError) as info:
        inputs = ["aa", "bb", "cc"]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "SlidingWindow supports 1D input only for now." in str(info.value)


if __name__ == '__main__':
    test_sliding_window_callable()
    test_sliding_window_string()
    test_sliding_window_number()
    test_sliding_window_big_width()
    test_sliding_window_exception()