# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mindspore.dataset.text Vocab construction (from_list / from_file /
from_dict) and the Lookup operation."""

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype
from mindspore import log as logger

# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"


def test_lookup_callable():
    """
    Test lookup is callable
    """
    logger.info("test_lookup_callable")
    vocab = text.Vocab.from_list(['深', '圳', '欢', '迎', '您'])
    lookup = text.Lookup(vocab)
    word = "迎"
    assert lookup(word) == 3


def test_from_list_tutorial():
    """Vocab built from a python list with special tokens prepended; look up DATA_FILE."""
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
    lookup = text.Lookup(vocab, "<unk>")
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    ind = 0
    res = [2, 1, 4, 5, 6, 7]
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert d["text"] == res[ind], ind
        ind += 1


def test_from_file_tutorial():
    """Vocab built from a comma-delimited file with special tokens prepended; look up DATA_FILE."""
    vocab = text.Vocab.from_file(VOCAB_FILE, ",", None, ["<pad>", "<unk>"], True)
    lookup = text.Lookup(vocab)
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    ind = 0
    res = [10, 11, 12, 15, 13, 14]
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert d["text"] == res[ind], ind
        ind += 1


def test_from_dict_tutorial():
    """Vocab built from an explicit word->id dict; unknown words map to the <unk> id."""
    vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
    lookup = text.Lookup(vocab, "<unk>")  # any unknown token will be mapped to the id of <unk>
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    res = [3, 6, 2, 4, 5, 6]
    ind = 0
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert d["text"] == res[ind], ind
        ind += 1


def test_from_dict_exception():
    """A negative word id passed to from_dict must raise ValueError.

    The original version silently passed when no exception was thrown (the
    `if not vocab` fallback never fires for a real Vocab object); the
    assert-False below makes a missing exception fail the test.
    """
    try:
        text.Vocab.from_dict({"home": -1, "behind": 0})
        assert False, "from_dict with a negative word id should have raised ValueError"
    except ValueError as e:
        assert "is not within the required interval" in str(e)


def test_from_list():
    """Exercise from_list configurations: special token placement, unknown-token
    lookup, and the error messages for invalid inputs."""
    def gen(texts):
        for word in texts.split(" "):
            yield (np.array(word, dtype='S'),)

    def test_config(lookup_str, vocab_input, special_tokens, special_first, unknown_token):
        # Returns the looked-up ids on success, or the exception message on failure.
        try:
            vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first)
            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
            data = data.map(operations=text.Lookup(vocab, unknown_token), input_columns=["text"])
            res = []
            for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                res.append(d["text"].item())
            return res
        except (ValueError, RuntimeError, TypeError) as e:
            return str(e)

    # test basic default config, special_token=None, unknown_token=None
    assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2]
    # test normal operations
    assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]
    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True, "w1") == [2, 1, 0]
    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False, "w1") == [2, 1, 0]
    # test unknown token lookup
    assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["<pad>", "<unk>"], True, "<unk>") == [2, 1, 4, 1]
    assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["<pad>", "<unk>"], False, "<unk>") == [0, 4, 2, 4]

    # test exceptions
    assert "doesn't exist in vocab." in test_config("un1", ["w1"], [], False, "unk")
    assert "doesn't exist in vocab and no unknown token is specified." in test_config("un1", ["w1"], [], False, None)
    assert "doesn't exist in vocab" in test_config("un1", ["w1"], [], False, None)
    assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True, "w1")
    assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True, "w1")
    assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True, "w1")
    assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123)


def test_from_list_lookup_empty_string():
    # "" is a valid word in vocab, which can be looked up by LookupOp
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
    lookup = text.Lookup(vocab, "")
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    ind = 0
    res = [2, 1, 4, 5, 6, 7]
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert d["text"] == res[ind], ind
        ind += 1

    # unknown_token of LookUp is None, it will convert to std::nullopt in C++,
    # so it has nothing to do with "" in vocab and C++ will skip looking up unknown_token
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
    lookup = text.Lookup(vocab)
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    try:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
        # The original silently passed when no exception was raised; an
        # out-of-vocab token with no unknown_token must raise RuntimeError.
        assert False, "lookup of an out-of-vocab token with no unknown_token should have raised RuntimeError"
    except RuntimeError as e:
        assert "token: \"is\" doesn't exist in vocab and no unknown token is specified" in str(e)


def test_from_file():
    """Exercise from_file configurations: special token placement, vocab_size
    truncation, and error messages for invalid arguments."""
    def gen(texts):
        for word in texts.split(" "):
            yield (np.array(word, dtype='S'),)

    def test_config(lookup_str, vocab_size, special_tokens, special_first):
        # Returns the looked-up ids on success, or the exception message on failure.
        try:
            vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, vocab_size=vocab_size, special_tokens=special_tokens,
                                         special_first=special_first)
            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
            data = data.map(operations=text.Lookup(vocab, "s2"), input_columns=["text"])
            res = []
            for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                res.append(d["text"].item())
            return res
        except ValueError as e:
            return str(e)

    # test special tokens are prepended
    assert test_config("w1 w2 w3 s1 s2 s3", None, ["s1", "s2", "s3"], True) == [3, 4, 5, 0, 1, 2]
    # test special tokens are appended
    assert test_config("w1 w2 w3 s1 s2 s3", None, ["s1", "s2", "s3"], False) == [0, 1, 2, 8, 9, 10]
    # test special tokens are prepended when not all words in file are used
    assert test_config("w1 w2 w3 s1 s2 s3", 3, ["s1", "s2", "s3"], False) == [0, 1, 2, 3, 4, 5]
    # test exception special_words contains duplicate words
    assert "special_tokens contains duplicate" in test_config("w1", None, ["s1", "s1"], True)
    # test exception when vocab_size is negative
    assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True)
    assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True)


def test_lookup_cast_type():
    """Lookup output dtype follows the data_type argument (default int32); string
    and non-type arguments are rejected."""
    def gen(texts):
        for word in texts.split(" "):
            yield (np.array(word, dtype='S'),)

    def test_config(lookup_str, data_type=None):
        # Returns the output numpy dtype on success, or the exception message on failure.
        try:
            vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
            # if data_type is None, test the default value of data_type
            op = text.Lookup(vocab, "<unk>") if data_type is None else text.Lookup(vocab, "<unk>", data_type)
            data = data.map(operations=op, input_columns=["text"])
            res = []
            for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                res.append(d["text"])
            return res[0].dtype
        except (ValueError, RuntimeError, TypeError) as e:
            return str(e)

    # test result is correct
    assert test_config("w1", mstype.int8) == np.dtype("int8")
    assert test_config("w2", mstype.int32) == np.dtype("int32")
    assert test_config("w3", mstype.int64) == np.dtype("int64")
    assert test_config("unk", mstype.float32) != np.dtype("int32")
    assert test_config("unk") == np.dtype("int32")
    # test exception, data_type isn't the correct type
    assert "tldr is not of type [<class 'mindspore._c_expression.typing.Type'>]" in test_config("unk", "tldr")
    assert "Lookup : The parameter data_type must be numeric including bool." in \
           test_config("w1", mstype.string)


if __name__ == '__main__':
    test_lookup_callable()
    test_from_dict_exception()
    test_from_list_tutorial()
    test_from_file_tutorial()
    test_from_dict_tutorial()
    test_from_list()
    test_from_file()
    test_lookup_cast_type()