/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <fstream>
#include <iterator>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "minddata/dataset/include/dataset/text.h"
#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/status.h"
#include "utils/file_utils.h"

namespace mindspore {
namespace dataset {
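// Take ownership of an existing word -> id map; the reverse id -> word map is built lazily on first reverse lookup.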
Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }

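// Look up the id of a single word; returns kNoTokenExists (-1) when the word is not in the vocab.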
WordIdType Vocab::TokensToIds(const WordType &word) const {
  auto itr = word2id_.find(word);
  return itr == word2id_.end() ? kNoTokenExists : itr->second;
}

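// Look up the ids of a list of words; any word not in the vocab maps to kNoTokenExists.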
std::vector<WordIdType> Vocab::TokensToIds(const std::vector<WordType> &words) const {
  std::vector<WordIdType> ids;
  (void)std::transform(words.begin(), words.end(), std::back_inserter(ids), [this](auto w) { return TokensToIds(w); });
  return ids;
}

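// Look up the word for a single id; returns kNoIdExists (an empty string) when the id is unknown.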
WordType Vocab::IdsToTokens(const WordIdType &id) {
  // build the reverse map lazily: reverse lookup is uncommon, so the map is not kept unless it is needed
  if (id2word_.empty()) {
    for (const auto &[word_, id_] : word2id_) {
      id2word_[id_] = word_;
    }
  }
  auto itr = id2word_.find(id);
  return itr == id2word_.end() ? kNoIdExists : itr->second;
}

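// Look up the words for a list of ids; any id not in the vocab maps to kNoIdExists.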
std::vector<WordType> Vocab::IdsToTokens(const std::vector<WordIdType> &ids) {
  // build the reverse map lazily: reverse lookup is uncommon, so the map is not kept unless it is needed
  if (id2word_.empty()) {
    for (const auto &[word_, id_] : word2id_) {
      id2word_[id_] = word_;
    }
  }
  std::vector<WordType> words;
  (void)std::transform(ids.begin(), ids.end(), std::back_inserter(words), [this](auto i) { return IdsToTokens(i); });
  return words;
}

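// Append a word to the vocab, assigning it an id equal to the current vocab size; existing words are left unchanged.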
void Vocab::AppendWord(const std::string &word) {
  if (word2id_.find(word) == word2id_.end()) {
    word2id_[word] = static_cast<WordIdType>(word2id_.size());
  }
}

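// Build a vocab from an explicit word -> id map (from_dict); ids must be non-negative but need not be contiguous.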
Status Vocab::BuildFromUnorderedMap(const std::unordered_map<WordType, WordIdType> &words,
                                    std::shared_ptr<Vocab> *vocab) {
  if (vocab == nullptr) {
    RETURN_STATUS_UNEXPECTED("Vocab::BuildFromUnorderedMap: input vocab can not be null");
  }
  // Validate parameters and build map
  std::unordered_map<WordType, WordIdType> word2id;
  for (const auto &p : words) {
    if (p.second < 0) {
      RETURN_STATUS_UNEXPECTED("from_dict: index can not be negative, but got " + std::to_string(p.second));
    }
    word2id[p.first] = p.second;
  }
  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

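// Build a vocab from an ordered word list plus special tokens (from_list); ids follow list order,
// and the special tokens are placed before or after the normal words depending on prepend_special.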
Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vector<WordType> &special_tokens,
                              bool prepend_special, std::shared_ptr<Vocab> *vocab) {
  if (vocab == nullptr) {
    RETURN_STATUS_UNEXPECTED("Vocab::BuildFromVector: input vocab can not be null");
  }
  std::unordered_map<WordType, WordIdType> word2id;

  // if special tokens are prepended, ids of normal words start after the special tokens
  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
  for (const auto &word : words) {
    if (word2id.find(word) != word2id.end()) {
      RETURN_STATUS_UNEXPECTED("from_list: word_list contains duplicate word: " + word + ".");
    }
    word2id[word] = word_id++;
  }

  // special tokens get ids 0..n-1 when prepended, otherwise they follow the normal words
  word_id = prepend_special ? 0 : static_cast<WordIdType>(word2id.size());

  for (const auto &special_token : special_tokens) {
    if (word2id.find(special_token) != word2id.end()) {
      RETURN_STATUS_UNEXPECTED(
        "from_list: "
        "special_tokens and word_list contain duplicate word: " +
        special_token + ".");
    }
    word2id[special_token] = word_id++;
  }

  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

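// Build a vocab from a text file (from_file): each line contributes one word (only the text before the
// first delimiter character is kept when a delimiter is given), up to vocab_size words (-1 reads the whole
// file), and the special tokens are then added before or after the file words depending on prepend_special.
// Illustrative call; the path and tokens below are placeholders, not values from this repository:
//   std::shared_ptr<Vocab> vocab;
//   Status rc = Vocab::BuildFromFile("/path/to/vocab.txt", " ", -1, {"<pad>", "<unk>"}, true, &vocab);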
Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
                            const std::vector<WordType> &special_tokens, bool prepend_special,
                            std::shared_ptr<Vocab> *vocab) {
  if (vocab == nullptr) {
    RETURN_STATUS_UNEXPECTED("Vocab::BuildFromFile: input vocab can not be null");
  }
  // Validate parameters
  auto realpath = FileUtils::GetRealPath(path.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + path);

  CHECK_FAIL_RETURN_UNEXPECTED(
    !(vocab_size < 0 && vocab_size != -1),
    "from_file: vocab_size should be either -1 or positive integer, but got: " + std::to_string(vocab_size));

  std::string duplicate_sp;
  for (const WordType &sp : special_tokens) {
    if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
      if (duplicate_sp.find(sp) == std::string::npos) {
        duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
      }
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(duplicate_sp.empty(),
                               "from_file: special_tokens contains duplicate word: " + duplicate_sp);

  std::unordered_set<std::string> specials;
  // used later to verify that words read from the file do not collide with any special token
  for (const auto &word : special_tokens) {
    (void)specials.insert(word);
  }
  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
  std::unordered_map<WordType, WordIdType> word2id;

  std::fstream handle(realpath.value(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "from_file: fail to open: " + realpath.value());

  std::string word;
  while (std::getline(handle, word)) {
    if (!delimiter.empty()) {
      // keep only the text before the first delimiter character; if no delimiter is found,
      // find_first_of returns std::string::npos and the whole line is kept
      word = word.substr(0, word.find_first_of(delimiter));
    }
    if (word2id.find(word) != word2id.end()) {
      handle.close();
      RETURN_STATUS_UNEXPECTED("from_file: word_list contains duplicate word: " + word);
    }
    if (specials.find(word) != specials.end()) {
      handle.close();
      RETURN_STATUS_UNEXPECTED("from_file: special_tokens and word_list contain duplicate word: " + word);
    }
    word2id[word] = word_id++;
    // stop once vocab_size words have been read; with vocab_size == -1 the sizes never match,
    // so the whole file is read
    if (word2id.size() == static_cast<size_t>(vocab_size)) {
      break;
    }
  }

  handle.close();
  // special tokens get ids 0..n-1 when prepended, otherwise they follow the words read from the file
  word_id = prepend_special ? 0 : static_cast<WordIdType>(word2id.size());

  for (const auto &special_token : special_tokens) {
    word2id[special_token] = word_id++;
  }

  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

const WordIdType Vocab::kNoTokenExists = -1;
const WordType Vocab::kNoIdExists = std::string();

}  // namespace dataset
}  // namespace mindspore