• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/text/vocab.h"
18 
19 #include <fstream>
20 #include <unordered_set>
21 #include <unordered_map>
22 #include <utility>
23 #include <algorithm>
24 
25 #include "utils/file_utils.h"
26 #ifndef ENABLE_ANDROID
27 #include "utils/log_adapter.h"
28 #else
29 #include "mindspore/lite/src/common/log_adapter.h"
30 #endif
31 
32 namespace mindspore {
33 namespace dataset {
Vocab(std::unordered_map<WordType,WordIdType> word2id)34 Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }
35 
Lookup(const WordType & word) const36 WordIdType Vocab::Lookup(const WordType &word) const {
37   auto itr = word2id_.find(word);
38   return itr == word2id_.end() ? kNoTokenExists : itr->second;
39 }
40 
41 #ifdef ENABLE_PYTHON
BuildFromPyList(const py::list & words,const py::list & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)42 Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
43                               std::shared_ptr<Vocab> *vocab) {
44   if (vocab == nullptr) {
45     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromPyList: input vocab can not be null");
46   }
47   // check of duplication on both words and special_tokens will be performed in python
48   // special_tokens and words both need to be unique, and shouldn't overlap
49   std::unordered_map<WordType, WordIdType> word2id;
50   // if special is added in front, normal words id will start from number of special tokens
51   WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
52 
53   for (auto word : words) {
54     word2id[py::str(word)] = word_id++;
55   }
56 
57   word_id = prepend_special ? 0 : word2id.size();
58 
59   for (auto special_token : special_tokens) {
60     word2id[py::str(special_token)] = word_id++;
61   }
62 
63   *vocab = std::make_shared<Vocab>(std::move(word2id));
64   return Status::OK();
65 }
66 
BuildFromPyDict(const py::dict & words,std::shared_ptr<Vocab> * vocab)67 Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
68   if (vocab == nullptr) {
69     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromPyDict: input vocab can not be null");
70   }
71   std::unordered_map<WordType, WordIdType> word2id;
72   for (auto p : words) {
73     word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
74   }
75   *vocab = std::make_shared<Vocab>(std::move(word2id));
76   return Status::OK();
77 }
78 #endif
79 
append_word(const std::string & word)80 void Vocab::append_word(const std::string &word) {
81   if (word2id_.find(word) == word2id_.end()) {
82     word2id_[word] = word2id_.size();
83   }
84 }
85 
BuildFromUnorderedMap(const std::unordered_map<WordType,WordIdType> & words,std::shared_ptr<Vocab> * vocab)86 Status Vocab::BuildFromUnorderedMap(const std::unordered_map<WordType, WordIdType> &words,
87                                     std::shared_ptr<Vocab> *vocab) {
88   if (vocab == nullptr) {
89     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromUnorderedMap: input vocab can not be null");
90   }
91   // Validate parameters and build map
92   std::unordered_map<WordType, WordIdType> word2id;
93   for (auto p : words) {
94     if (p.second < 0) {
95       RETURN_STATUS_UNEXPECTED("from_dict: index can not be negetive, but got " + std::to_string(p.second));
96     }
97     word2id[p.first] = p.second;
98   }
99   *vocab = std::make_shared<Vocab>(std::move(word2id));
100   return Status::OK();
101 }
102 
BuildFromVector(const std::vector<WordType> & words,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)103 Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vector<WordType> &special_tokens,
104                               bool prepend_special, std::shared_ptr<Vocab> *vocab) {
105   if (vocab == nullptr) {
106     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromVector: input vocab can not be null");
107   }
108   std::unordered_map<WordType, WordIdType> word2id;
109 
110   // if special is added in front, normal words id will start from number of special tokens
111   WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
112   for (auto word : words) {
113     if (word2id.find(word) != word2id.end()) {
114       RETURN_STATUS_UNEXPECTED("from_list: word_list contains duplicate word: " + word + ".");
115     }
116     word2id[word] = word_id++;
117   }
118 
119   word_id = prepend_special ? 0 : word2id.size();
120 
121   for (auto special_token : special_tokens) {
122     if (word2id.find(special_token) != word2id.end()) {
123       RETURN_STATUS_UNEXPECTED(
124         "from_list: "
125         "special_tokens and word_list contain duplicate word: " +
126         special_token + ".");
127     }
128     word2id[special_token] = word_id++;
129   }
130 
131   *vocab = std::make_shared<Vocab>(std::move(word2id));
132   return Status::OK();
133 }
134 
BuildFromFileCpp(const std::string & path,const std::string & delimiter,int32_t vocab_size,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)135 Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delimiter, int32_t vocab_size,
136                                const std::vector<WordType> &special_tokens, bool prepend_special,
137                                std::shared_ptr<Vocab> *vocab) {
138   if (vocab == nullptr) {
139     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFileCpp: input vocab can not be null");
140   }
141   // Validate parameters
142   auto realpath = FileUtils::GetRealPath(path.data());
143   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + path);
144 
145   CHECK_FAIL_RETURN_UNEXPECTED(
146     !(vocab_size < 0 && vocab_size != -1),
147     "from_file: vocab_size should be either -1 or positive integer, but got: " + std::to_string(vocab_size));
148 
149   std::string duplicate_sp;
150   for (const WordType &sp : special_tokens) {
151     if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
152       if (duplicate_sp.find(sp) == std::string::npos) {
153         duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
154       }
155     }
156   }
157   CHECK_FAIL_RETURN_UNEXPECTED(duplicate_sp.empty(),
158                                "from_file: special_tokens contains duplicate word: " + duplicate_sp);
159 
160   std::unordered_set<std::string> specials;
161   // used to check that words in file don't contain any special token that already exists
162   for (auto word : special_tokens) {
163     specials.insert(word);
164   }
165   WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
166   std::unordered_map<WordType, WordIdType> word2id;
167 
168   std::fstream handle(realpath.value(), std::ios::in);
169   CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open: " + realpath.value());
170 
171   std::string word;
172   while (std::getline(handle, word)) {
173     if (!delimiter.empty()) {
174       // if delimiter is not found, find_first_of would return std::string::npos which is -1
175       word = word.substr(0, word.find_first_of(delimiter));
176     }
177     CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(),
178                                  "from_file: word_list contains duplicate word:" + word);
179     CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(),
180                                  "from_file: special_tokens and word_list contain duplicate word:" + word);
181 
182     word2id[word] = word_id++;
183     // break if enough row is read, if vocab_size is smaller than 0
184     if (word2id.size() == vocab_size) break;
185   }
186 
187   word_id = prepend_special ? 0 : word2id.size();
188 
189   for (auto special_token : special_tokens) {
190     word2id[special_token] = word_id++;
191   }
192 
193   *vocab = std::make_shared<Vocab>(std::move(word2id));
194   return Status::OK();
195 }
196 
BuildFromFile(const std::string & path,const std::string & delimiter,int32_t vocab_size,const py::list & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)197 Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
198                             const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
199   if (vocab == nullptr) {
200     RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFile: input vocab can not be null");
201   }
202   // python validator checks special_tokens doesn't contain any duplicate words
203   std::unordered_set<std::string> specials;
204   // used to check that words in file don't contain any special token that already exists
205   for (auto word : special_tokens) {
206     specials.insert(py::str(word));
207   }
208   WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
209   std::unordered_map<WordType, WordIdType> word2id;
210 
211   auto realpath = FileUtils::GetRealPath(path.data());
212   if (!realpath.has_value()) {
213     RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + path);
214   }
215 
216   std::fstream handle(realpath.value(), std::ios::in);
217   CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open:" + path);
218   std::string word;
219   while (std::getline(handle, word)) {
220     if (!delimiter.empty()) {
221       // if delimiter is not found, find_first_of would return std::string::npos which is -1
222       word = word.substr(0, word.find_first_of(delimiter));
223     }
224     CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "from_file: duplicate word:" + word + ".");
225     CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(),
226                                  "from_file: " + word + " is already in special_tokens.");
227     word2id[word] = word_id++;
228     // break if enough row is read, if vocab_size is smaller than 0
229     if (word2id.size() == vocab_size) break;
230   }
231 
232   word_id = prepend_special ? 0 : word2id.size();
233 
234   for (auto special_token : special_tokens) {
235     word2id[py::str(special_token)] = word_id++;
236   }
237 
238   *vocab = std::make_shared<Vocab>(std::move(word2id));
239   return Status::OK();
240 }
241 
// Sentinel id returned by Vocab::Lookup() when the queried word is absent.
const WordIdType Vocab::kNoTokenExists = -1;
243 
244 }  // namespace dataset
245 }  // namespace mindspore
246