1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <fstream>
18 #include <unordered_map>
19 #include <unordered_set>
20
21 #include "minddata/dataset/include/dataset/text.h"
22 #include "minddata/dataset/util/log_adapter.h"
23 #include "minddata/dataset/util/status.h"
24 #include "utils/file_utils.h"
25
26 namespace mindspore {
27 namespace dataset {
Vocab(std::unordered_map<WordType,WordIdType> word2id)28 Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }
29
TokensToIds(const WordType & word) const30 WordIdType Vocab::TokensToIds(const WordType &word) const {
31 auto itr = word2id_.find(word);
32 return itr == word2id_.end() ? kNoTokenExists : itr->second;
33 }
34
TokensToIds(const std::vector<WordType> & words) const35 std::vector<WordIdType> Vocab::TokensToIds(const std::vector<WordType> &words) const {
36 std::vector<WordIdType> ids;
37 (void)std::transform(words.begin(), words.end(), std::back_inserter(ids), [this](auto w) { return TokensToIds(w); });
38 return ids;
39 }
40
IdsToTokens(const WordIdType & id)41 WordType Vocab::IdsToTokens(const WordIdType &id) {
42 // lazy initialization, since I think it's not common use but waste memory
43 if (id2word_.empty()) {
44 for (const auto &[word_, id_] : word2id_) {
45 id2word_[id_] = word_;
46 }
47 }
48 auto itr = id2word_.find(id);
49 return itr == id2word_.end() ? kNoIdExists : itr->second;
50 }
51
IdsToTokens(const std::vector<WordIdType> & ids)52 std::vector<WordType> Vocab::IdsToTokens(const std::vector<WordIdType> &ids) {
53 // lazy initialization, since I think it's not common use but waste memory
54 if (id2word_.empty()) {
55 for (const auto &[word_, id_] : word2id_) {
56 id2word_[id_] = word_;
57 }
58 }
59 std::vector<WordType> words;
60 (void)std::transform(ids.begin(), ids.end(), std::back_inserter(words), [this](auto i) { return IdsToTokens(i); });
61 return words;
62 }
63
AppendWord(const std::string & word)64 void Vocab::AppendWord(const std::string &word) {
65 if (word2id_.find(word) == word2id_.end()) {
66 word2id_[word] = static_cast<WordIdType>(word2id_.size());
67 }
68 }
69
BuildFromUnorderedMap(const std::unordered_map<WordType,WordIdType> & words,std::shared_ptr<Vocab> * vocab)70 Status Vocab::BuildFromUnorderedMap(const std::unordered_map<WordType, WordIdType> &words,
71 std::shared_ptr<Vocab> *vocab) {
72 if (vocab == nullptr) {
73 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromUnorderedMap: input vocab can not be null");
74 }
75 // Validate parameters and build map
76 std::unordered_map<WordType, WordIdType> word2id;
77 for (auto p : words) {
78 if (p.second < 0) {
79 RETURN_STATUS_UNEXPECTED("from_dict: index can not be negetive, but got " + std::to_string(p.second));
80 }
81 word2id[p.first] = p.second;
82 }
83 *vocab = std::make_shared<Vocab>(std::move(word2id));
84 return Status::OK();
85 }
86
BuildFromVector(const std::vector<WordType> & words,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)87 Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vector<WordType> &special_tokens,
88 bool prepend_special, std::shared_ptr<Vocab> *vocab) {
89 if (vocab == nullptr) {
90 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromVector: input vocab can not be null");
91 }
92 std::unordered_map<WordType, WordIdType> word2id;
93
94 // if special is added in front, normal words id will start from number of special tokens
95 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
96 for (auto word : words) {
97 if (word2id.find(word) != word2id.end()) {
98 RETURN_STATUS_UNEXPECTED("from_list: word_list contains duplicate word: " + word + ".");
99 }
100 word2id[word] = word_id++;
101 }
102
103 word_id = prepend_special ? 0 : word2id.size();
104
105 for (auto special_token : special_tokens) {
106 if (word2id.find(special_token) != word2id.end()) {
107 RETURN_STATUS_UNEXPECTED(
108 "from_list: "
109 "special_tokens and word_list contain duplicate word: " +
110 special_token + ".");
111 }
112 word2id[special_token] = word_id++;
113 }
114
115 *vocab = std::make_shared<Vocab>(std::move(word2id));
116 return Status::OK();
117 }
118
BuildFromFile(const std::string & path,const std::string & delimiter,int32_t vocab_size,const std::vector<WordType> & special_tokens,bool prepend_special,std::shared_ptr<Vocab> * vocab)119 Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
120 const std::vector<WordType> &special_tokens, bool prepend_special,
121 std::shared_ptr<Vocab> *vocab) {
122 if (vocab == nullptr) {
123 RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFile: input vocab can not be null");
124 }
125 // Validate parameters
126 auto realpath = FileUtils::GetRealPath(path.c_str());
127 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + path);
128
129 CHECK_FAIL_RETURN_UNEXPECTED(
130 !(vocab_size < 0 && vocab_size != -1),
131 "from_file: vocab_size should be either -1 or positive integer, but got: " + std::to_string(vocab_size));
132
133 std::string duplicate_sp;
134 for (const WordType &sp : special_tokens) {
135 if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
136 if (duplicate_sp.find(sp) == std::string::npos) {
137 duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
138 }
139 }
140 }
141 CHECK_FAIL_RETURN_UNEXPECTED(duplicate_sp.empty(),
142 "from_file: special_tokens contains duplicate word: " + duplicate_sp);
143
144 std::unordered_set<std::string> specials;
145 // used to check that words in file don't contain any special token that already exists
146 for (auto word : special_tokens) {
147 (void)specials.insert(word);
148 }
149 WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
150 std::unordered_map<WordType, WordIdType> word2id;
151
152 std::fstream handle(realpath.value(), std::ios::in);
153 CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open: " + realpath.value());
154
155 std::string word;
156 while (std::getline(handle, word)) {
157 if (!delimiter.empty()) {
158 // if delimiter is not found, find_first_of would return std::string::npos which is -1
159 word = word.substr(0, word.find_first_of(delimiter));
160 }
161 if (word2id.find(word) != word2id.end()) {
162 handle.close();
163 RETURN_STATUS_UNEXPECTED("from_file: word_list contains duplicate word:" + word);
164 }
165 if (specials.find(word) != specials.end()) {
166 handle.close();
167 RETURN_STATUS_UNEXPECTED("from_file: special_tokens and word_list contain duplicate word:" + word);
168 }
169 word2id[word] = word_id++;
170 // break if enough row is read, if vocab_size is smaller than 0
171 if (word2id.size() == vocab_size) {
172 break;
173 }
174 }
175
176 handle.close();
177 word_id = prepend_special ? 0 : word2id.size();
178
179 for (auto special_token : special_tokens) {
180 word2id[special_token] = word_id++;
181 }
182
183 *vocab = std::make_shared<Vocab>(std::move(word2id));
184 return Status::OK();
185 }
186
187 const WordIdType Vocab::kNoTokenExists = -1;
188 const WordType Vocab::kNoIdExists = std::string();
189
190 } // namespace dataset
191 } // namespace mindspore
192