• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""
16Testing Ngram in mindspore.dataset
17"""
18import numpy as np
19import mindspore.dataset as ds
20import mindspore.dataset.text as text
21
22
def test_ngram_callable():
    """
    Test that the Ngram op can be called eagerly (outside a dataset pipeline),
    both on a numpy bytes array and on a plain Python list of strings.
    """
    ngram_op = text.Ngram(2, separator="-")

    # The leading space makes split(" ") produce an empty first token, so the
    # first bigram begins with the separator.
    bytes_tokens = np.array(" WildRose Country".split(" "), dtype='S')
    assert np.array_equal(ngram_op(bytes_tokens), ['-WildRose', 'WildRose-Country'])

    str_tokens = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
    expected_bigrams = ["WildRose Country-Canada's Ocean Playground",
                        "Canada's Ocean Playground-Land of Living Skies"]
    assert np.array_equal(ngram_op(str_tokens), expected_bigrams)
39
40
def test_multiple_ngrams():
    """
    Test Ngram where n is a list of integers ([1, 2, 3]).

    The left and right pads are both ("_", 2) and the separator is " ". Each
    input line must produce all of its 1-, 2- and 3-grams in that order.
    """
    plates_mottos = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
    n_gram_mottos = [
        ['WildRose', 'Country', '_ WildRose', 'WildRose Country', 'Country _', '_ _ WildRose', '_ WildRose Country',
         'WildRose Country _', 'Country _ _'],
        ["Canada's", 'Ocean', 'Playground', "_ Canada's", "Canada's Ocean", 'Ocean Playground', 'Playground _',
         "_ _ Canada's", "_ Canada's Ocean", "Canada's Ocean Playground", 'Ocean Playground _', 'Playground _ _'],
        ['Land', 'of', 'Living', 'Skies', '_ Land', 'Land of', 'of Living', 'Living Skies', 'Skies _', '_ _ Land',
         '_ Land of', 'Land of Living', 'of Living Skies', 'Living Skies _', 'Skies _ _'],
    ]

    def gen(texts):
        # One row per line: the line's space-separated tokens as a bytes array.
        for line in texts:
            yield (np.array(line.split(" "), dtype='S'),)

    dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
    dataset = dataset.map(operations=text.Ngram([1, 2, 3], ("_", 2), ("_", 2), " "), input_columns="text")

    # Collect every output row and compare the whole list. This fails if the
    # pipeline drops or adds rows. A per-row loop with a manual counter passes
    # vacuously when fewer rows (or none) are produced.
    results = [[d.decode("utf8") for d in data["text"]]
               for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]
    assert results == n_gram_mottos
66
67
def test_simple_ngram():
    """
    Test Ngram with a single n value (n=3, separator " ", default padding arguments).

    A line with fewer than n tokens ("Friendly Manitoba") is expected to yield a
    single empty string.
    """
    plates_mottos = ["Friendly Manitoba", "Yours to Discover", "Land of Living Skies",
                     "Birthplace of the Confederation"]
    n_gram_mottos = [
        [""],
        ["Yours to Discover"],
        ['Land of Living', 'of Living Skies'],
        ['Birthplace of the', 'of the Confederation'],
    ]

    def gen(texts):
        # One row per line: the line's space-separated tokens as a bytes array.
        for line in texts:
            yield (np.array(line.split(" "), dtype='S'),)

    dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
    dataset = dataset.map(operations=text.Ngram(3, separator=" "), input_columns="text")

    results = [[d.decode("utf8") for d in data["text"]]
               for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]
    # Check the row count explicitly. Per-row assertions alone would pass
    # vacuously if the pipeline yielded fewer rows (or none).
    assert len(results) == len(n_gram_mottos), len(results)
    for i, (actual, expected) in enumerate(zip(results, n_gram_mottos)):
        assert actual == expected, i
88
89
def test_corner_cases():
    """
    Test various Ngram corner cases and the error messages produced for invalid arguments.
    """

    def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
        """
        Run Ngram over a single line of text and report the outcome.

        Args:
            input_line: one line of text; it is split on " " into tokens.
            n, l_pad, r_pad, sep: passed to text.Ngram as n, left pad, right pad and separator.

        Returns:
            The decoded n-grams of the single output row, or the message of the
            ValueError/TypeError raised while building or running the pipeline.

        Raises:
            AssertionError: if the pipeline runs without error but yields no row.
        """
        def gen(line):
            # A single-row dataset: the line's space-separated tokens as a bytes array.
            yield (np.array(line.split(" "), dtype='S'),)

        try:
            dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
            dataset = dataset.map(operations=text.Ngram(n, l_pad, r_pad, separator=sep), input_columns=["text"])
            for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                return [d.decode("utf8") for d in data["text"]]
        except (ValueError, TypeError) as e:
            return str(e)
        # Reaching here means the pipeline yielded no rows. Fail loudly instead of
        # returning None, which would turn the `in` checks below into an
        # unrelated TypeError.
        raise AssertionError("Ngram pipeline produced no rows for input {!r}".format(input_line))

    # test tensor length smaller than n
    assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
    # test empty separator
    assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
    # test separator with longer length
    assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
    # test left pad != right pad
    assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
    # test invalid n
    assert "gram[1] with value [1] is not of type [<class 'int'>]" in test_config("Yours to Discover", [1, [1]])
    assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
    # test invalid pad
    assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
    assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
    # test 0 as invalid input
    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
    assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])
123
124
if __name__ == '__main__':
    # Run every test in this module, in definition order, when executed as a script.
    for test_fn in (test_ngram_callable, test_multiple_ngrams, test_simple_ngram, test_corner_cases):
        test_fn()