• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing SlidingWindow in mindspore.dataset
17"""
18import numpy as np
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.text as text
22
23
def test_sliding_window_callable():
    """
    Test that SlidingWindow can be invoked directly as a callable
    """
    sliding_window = text.SlidingWindow(2, 0)
    chars = ["大", "家", "早", "上", "好"]

    # 1D input: each pair of neighbouring characters forms one window
    expected_windows = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])
    assert np.array_equal(sliding_window(chars), expected_windows)

    # test 2D input
    with pytest.raises(RuntimeError) as rank_error:
        sliding_window([chars])
    assert "SlidingWindow: SlidingWindow supports 1D input only for now." in str(rank_error.value)

    # test input multiple tensors
    with pytest.raises(RuntimeError) as arity_error:
        sliding_window(chars, chars)
    assert "The op is OneToOne, can only accept one tensor as input." in str(arity_error.value)
45
46
def test_sliding_window_string():
    """ test sliding_window with string type"""
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])

    # Collect the decoded windows in a plain list and convert to an array only
    # once, after the iteration: converting inside the loop would leave an
    # ndarray (which has no append) for any subsequent row.
    decoded = []
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # each cell is decoded from utf8 so it can be compared with the str literals in expect
        decoded.extend([cell.decode('utf8') for cell in window] for window in data['text'])
    np.testing.assert_array_equal(np.array(decoded), expect)
63
64
def test_sliding_window_number():
    """ test sliding_window with a one-element numeric column, width 1 and axis -1"""
    def single_sample(values):
        # yields exactly one row holding the whole list as a numpy array
        yield (np.array(values),)

    source = single_sample([1])
    pipeline = ds.GeneratorDataset(source, column_names=["number"])
    pipeline = pipeline.map(operations=text.SlidingWindow(1, -1), input_columns=["number"])

    expected = np.array([[1]])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(row['number'], expected)
77
78
def test_sliding_window_big_width():
    """ test sliding_window with a width (30) larger than the input length (5); an empty array is expected"""
    pipeline = ds.NumpySlicesDataset([[1, 2, 3, 4, 5]], column_names=["number"], shuffle=False)
    pipeline = pipeline.map(operations=text.SlidingWindow(30, 0), input_columns=["number"])

    empty = np.array([])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(row['number'], empty)
88
89
def test_sliding_window_exception():
    """
    Test that SlidingWindow reports errors for invalid arguments and inputs

    Checked cases: width 0 (ValueError), a str passed as width or as axis
    (TypeError), axis -100 (RuntimeError while running the pipeline) and a
    dataset of plain strings (RuntimeError while running the pipeline).

    pytest.raises is used instead of try / assert False / except so the checks
    still fail when no exception is raised even if asserts are stripped (-O).
    """
    with pytest.raises(ValueError):
        _ = text.SlidingWindow(0, 0)

    with pytest.raises(TypeError):
        _ = text.SlidingWindow("1", 0)

    with pytest.raises(TypeError):
        _ = text.SlidingWindow(1, "0")

    # The whole pipeline (build + iterate) stays inside the raises block so the
    # RuntimeError is accepted from whichever stage reports it.
    with pytest.raises(RuntimeError) as info:
        inputs = [[1, 2, 3, 4, 5]]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(operations=text.SlidingWindow(3, -100), input_columns=["text"])
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "axis supports 0 or -1 only for now." in str(info.value)

    with pytest.raises(RuntimeError) as info:
        inputs = ["aa", "bb", "cc"]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "SlidingWindow supports 1D input only for now." in str(info.value)
128
129
if __name__ == '__main__':
    # Run every test in this module, in definition order, when executed as a script.
    for run_test in (
            test_sliding_window_callable,
            test_sliding_window_string,
            test_sliding_window_number,
            test_sliding_window_big_width,
            test_sliding_window_exception,
    ):
        run_test()
136