# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Ngram in mindspore.dataset
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text


def _tokenize_rows(texts):
    """Yield one single-column row per line: a bytes array of the line's space-separated tokens."""
    for line in texts:
        yield (np.array(line.split(" "), dtype='S'),)


def _assert_rows_equal(dataset, expected_rows):
    """Assert that ``dataset`` yields exactly ``expected_rows``, comparing each row's decoded "text" column."""
    actual_rows = [[token.decode("utf8") for token in row["text"]]
                   for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]
    # Check the count first: a row-by-row comparison alone passes vacuously if the pipeline drops rows.
    assert len(actual_rows) == len(expected_rows), (len(actual_rows), len(expected_rows))
    for idx, (actual, expected) in enumerate(zip(actual_rows, expected_rows)):
        assert actual == expected, idx


def test_ngram_callable():
    """
    Test ngram op is callable
    """
    op = text.Ngram(2, separator="-")

    # The leading space produces an empty first token, so the first bigram begins with the separator.
    input1 = np.array(" WildRose Country".split(" "), dtype='S')
    expect1 = ['-WildRose', 'WildRose-Country']
    result1 = op(input1)
    assert np.array_equal(result1, expect1)

    # A plain list of str is also accepted; each list element is treated as one token.
    input2 = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
    expect2 = ["WildRose Country-Canada's Ocean Playground", "Canada's Ocean Playground-Land of Living Skies"]
    result2 = op(input2)
    assert np.array_equal(result2, expect2)


def test_multiple_ngrams():
    """ test n-gram where n is a list of integers"""
    plates_mottos = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
    # Each row holds all 1-grams, then all 2-grams, then all 3-grams, with "_" padding on both sides.
    n_gram_mottos = [
        ['WildRose', 'Country', '_ WildRose', 'WildRose Country', 'Country _', '_ _ WildRose', '_ WildRose Country',
         'WildRose Country _', 'Country _ _'],
        ["Canada's", 'Ocean', 'Playground', "_ Canada's", "Canada's Ocean", 'Ocean Playground', 'Playground _',
         "_ _ Canada's", "_ Canada's Ocean", "Canada's Ocean Playground", 'Ocean Playground _', 'Playground _ _'],
        ['Land', 'of', 'Living', 'Skies', '_ Land', 'Land of', 'of Living', 'Living Skies', 'Skies _', '_ _ Land',
         '_ Land of', 'Land of Living', 'of Living Skies', 'Living Skies _', 'Skies _ _'],
    ]

    dataset = ds.GeneratorDataset(_tokenize_rows(plates_mottos), column_names=["text"])
    dataset = dataset.map(operations=text.Ngram([1, 2, 3], ("_", 2), ("_", 2), " "), input_columns="text")

    _assert_rows_equal(dataset, n_gram_mottos)


def test_simple_ngram():
    """ test simple gram with only one n value"""
    plates_mottos = ["Friendly Manitoba", "Yours to Discover", "Land of Living Skies",
                     "Birthplace of the Confederation"]
    n_gram_mottos = [
        [""],  # only two tokens, fewer than n=3, so the op yields a single empty string
        ["Yours to Discover"],
        ['Land of Living', 'of Living Skies'],
        ['Birthplace of the', 'of the Confederation'],
    ]

    dataset = ds.GeneratorDataset(_tokenize_rows(plates_mottos), column_names=["text"])
    dataset = dataset.map(operations=text.Ngram(3, separator=" "), input_columns="text")

    _assert_rows_equal(dataset, n_gram_mottos)


def test_corner_cases():
    """ testing various corner cases and exceptions"""

    def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
        """
        Run Ngram over the tokens of a single line.

        Returns the decoded n-grams of that row, or the message of any ValueError/TypeError
        raised while building or iterating the pipeline (e.g. from Ngram argument validation).
        """
        try:
            dataset = ds.GeneratorDataset(_tokenize_rows([input_line]), column_names=["text"])
            dataset = dataset.map(operations=text.Ngram(n, l_pad, r_pad, separator=sep), input_columns=["text"])
            for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                return [d.decode("utf8") for d in data["text"]]
        except (ValueError, TypeError) as e:
            return str(e)
        return None

    # test tensor length smaller than n
    assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
    # test empty separator
    assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
    # test separator with longer length
    assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
    # test left pad != right pad
    assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
    # test invalid n
    assert "gram[1] with value [1] is not of type [<class 'int'>]" in test_config("Yours to Discover", [1, [1]])
    assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
    # test invalid pad
    assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
    assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
    # test 0 as invalid input
    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
    assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])


if __name__ == '__main__':
    test_ngram_callable()
    test_multiple_ngrams()
    test_simple_ngram()
    test_corner_cases()