1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/text/vocab.h"
18
19 #include <fstream>
20 #include <unordered_set>
21 #include <unordered_map>
22 #include <utility>
23 #include <algorithm>
24
25 #include "utils/file_utils.h"
26 #ifndef ENABLE_ANDROID
27 #include "utils/log_adapter.h"
28 #else
29 #include "mindspore/lite/src/common/log_adapter.h"
30 #endif
31
32 namespace mindspore {
33 namespace dataset {
Vocab(std::unordered_map<WordType,WordIdType> word2id)34 Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }
35
Lookup(const WordType & word) const36 WordIdType Vocab::Lookup(const WordType &word) const {
37 auto itr = word2id_.find(word);
38 return itr == word2id_.end() ? kNoTokenExists : itr->second;
39 }
40
41 #ifdef ENABLE_PYTHON
BuildFromPyList(const py::list & words,const py::list & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)42 Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
43 std::shared_ptr<Vocab> *vocab) {
44 if (vocab == nullptr) {
45 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromPyList: input vocab can not be null");
46 }
47 // check of duplication on both words and special_tokens will be performed in python
48 // special_tokens and words both need to be unique, and shouldn't overlap
49 std::unordered_map<WordType, WordIdType> word2id;
50 // if special is added in front, normal words id will start from number of special tokens
51 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
52
53 for (auto word : words) {
54 word2id[py::str(word)] = word_id++;
55 }
56
57 word_id = prepend_special ? 0 : word2id.size();
58
59 for (auto special_token : special_tokens) {
60 word2id[py::str(special_token)] = word_id++;
61 }
62
63 *vocab = std::make_shared<Vocab>(std::move(word2id));
64 return Status::OK();
65 }
66
BuildFromPyDict(const py::dict & words,std::shared_ptr<Vocab> * vocab)67 Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
68 if (vocab == nullptr) {
69 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromPyDict: input vocab can not be null");
70 }
71 std::unordered_map<WordType, WordIdType> word2id;
72 for (auto p : words) {
73 word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
74 }
75 *vocab = std::make_shared<Vocab>(std::move(word2id));
76 return Status::OK();
77 }
78 #endif
79
append_word(const std::string & word)80 void Vocab::append_word(const std::string &word) {
81 if (word2id_.find(word) == word2id_.end()) {
82 word2id_[word] = word2id_.size();
83 }
84 }
85
BuildFromUnorderedMap(const std::unordered_map<WordType,WordIdType> & words,std::shared_ptr<Vocab> * vocab)86 Status Vocab::BuildFromUnorderedMap(const std::unordered_map<WordType, WordIdType> &words,
87 std::shared_ptr<Vocab> *vocab) {
88 if (vocab == nullptr) {
89 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromUnorderedMap: input vocab can not be null");
90 }
91 // Validate parameters and build map
92 std::unordered_map<WordType, WordIdType> word2id;
93 for (auto p : words) {
94 if (p.second < 0) {
95 RETURN_STATUS_UNEXPECTED("from_dict: index can not be negetive, but got " + std::to_string(p.second));
96 }
97 word2id[p.first] = p.second;
98 }
99 *vocab = std::make_shared<Vocab>(std::move(word2id));
100 return Status::OK();
101 }
102
BuildFromVector(const std::vector<WordType> & words,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)103 Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vector<WordType> &special_tokens,
104 bool prepend_special, std::shared_ptr<Vocab> *vocab) {
105 if (vocab == nullptr) {
106 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromVector: input vocab can not be null");
107 }
108 std::unordered_map<WordType, WordIdType> word2id;
109
110 // if special is added in front, normal words id will start from number of special tokens
111 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
112 for (auto word : words) {
113 if (word2id.find(word) != word2id.end()) {
114 RETURN_STATUS_UNEXPECTED("from_list: word_list contains duplicate word: " + word + ".");
115 }
116 word2id[word] = word_id++;
117 }
118
119 word_id = prepend_special ? 0 : word2id.size();
120
121 for (auto special_token : special_tokens) {
122 if (word2id.find(special_token) != word2id.end()) {
123 RETURN_STATUS_UNEXPECTED(
124 "from_list: "
125 "special_tokens and word_list contain duplicate word: " +
126 special_token + ".");
127 }
128 word2id[special_token] = word_id++;
129 }
130
131 *vocab = std::make_shared<Vocab>(std::move(word2id));
132 return Status::OK();
133 }
134
BuildFromFileCpp(const std::string & path,const std::string & delimiter,int32_t vocab_size,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)135 Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delimiter, int32_t vocab_size,
136 const std::vector<WordType> &special_tokens, bool prepend_special,
137 std::shared_ptr<Vocab> *vocab) {
138 if (vocab == nullptr) {
139 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFileCpp: input vocab can not be null");
140 }
141 // Validate parameters
142 auto realpath = FileUtils::GetRealPath(path.data());
143 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + path);
144
145 CHECK_FAIL_RETURN_UNEXPECTED(
146 !(vocab_size < 0 && vocab_size != -1),
147 "from_file: vocab_size should be either -1 or positive integer, but got: " + std::to_string(vocab_size));
148
149 std::string duplicate_sp;
150 for (const WordType &sp : special_tokens) {
151 if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
152 if (duplicate_sp.find(sp) == std::string::npos) {
153 duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
154 }
155 }
156 }
157 CHECK_FAIL_RETURN_UNEXPECTED(duplicate_sp.empty(),
158 "from_file: special_tokens contains duplicate word: " + duplicate_sp);
159
160 std::unordered_set<std::string> specials;
161 // used to check that words in file don't contain any special token that already exists
162 for (auto word : special_tokens) {
163 specials.insert(word);
164 }
165 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
166 std::unordered_map<WordType, WordIdType> word2id;
167
168 std::fstream handle(realpath.value(), std::ios::in);
169 CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open: " + realpath.value());
170
171 std::string word;
172 while (std::getline(handle, word)) {
173 if (!delimiter.empty()) {
174 // if delimiter is not found, find_first_of would return std::string::npos which is -1
175 word = word.substr(0, word.find_first_of(delimiter));
176 }
177 CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(),
178 "from_file: word_list contains duplicate word:" + word);
179 CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(),
180 "from_file: special_tokens and word_list contain duplicate word:" + word);
181
182 word2id[word] = word_id++;
183 // break if enough row is read, if vocab_size is smaller than 0
184 if (word2id.size() == vocab_size) break;
185 }
186
187 word_id = prepend_special ? 0 : word2id.size();
188
189 for (auto special_token : special_tokens) {
190 word2id[special_token] = word_id++;
191 }
192
193 *vocab = std::make_shared<Vocab>(std::move(word2id));
194 return Status::OK();
195 }
196
BuildFromFile(const std::string & path,const std::string & delimiter,int32_t vocab_size,const py::list & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)197 Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
198 const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
199 if (vocab == nullptr) {
200 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFile: input vocab can not be null");
201 }
202 // python validator checks special_tokens doesn't contain any duplicate words
203 std::unordered_set<std::string> specials;
204 // used to check that words in file don't contain any special token that already exists
205 for (auto word : special_tokens) {
206 specials.insert(py::str(word));
207 }
208 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
209 std::unordered_map<WordType, WordIdType> word2id;
210
211 auto realpath = FileUtils::GetRealPath(path.data());
212 if (!realpath.has_value()) {
213 RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + path);
214 }
215
216 std::fstream handle(realpath.value(), std::ios::in);
217 CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open:" + path);
218 std::string word;
219 while (std::getline(handle, word)) {
220 if (!delimiter.empty()) {
221 // if delimiter is not found, find_first_of would return std::string::npos which is -1
222 word = word.substr(0, word.find_first_of(delimiter));
223 }
224 CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "from_file: duplicate word:" + word + ".");
225 CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(),
226 "from_file: " + word + " is already in special_tokens.");
227 word2id[word] = word_id++;
228 // break if enough row is read, if vocab_size is smaller than 0
229 if (word2id.size() == vocab_size) break;
230 }
231
232 word_id = prepend_special ? 0 : word2id.size();
233
234 for (auto special_token : special_tokens) {
235 word2id[py::str(special_token)] = word_id++;
236 }
237
238 *vocab = std::make_shared<Vocab>(std::move(word2id));
239 return Status::OK();
240 }
241
242 const WordIdType Vocab::kNoTokenExists = -1;
243
244 } // namespace dataset
245 } // namespace mindspore
246