1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"

#include <iomanip>
#include <utility>

#include "minddata/dataset/core/config_manager.h"
21
22 namespace mindspore {
23 namespace dataset {
BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab,std::vector<std::string> col_names,int32_t vocab_size,float character_coverage,SentencePieceModel model_type,const std::unordered_map<std::string,std::string> & params,int32_t op_conn_size)24 BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab,
25 std::vector<std::string> col_names, int32_t vocab_size,
26 float character_coverage, SentencePieceModel model_type,
27 const std::unordered_map<std::string, std::string> ¶ms,
28 int32_t op_conn_size)
29 : PipelineOp(op_conn_size),
30 vocab_size_(vocab_size),
31 vocab_(vocab),
32 col_names_(col_names),
33 character_coverage_(character_coverage),
34 model_type_(model_type),
35 params_(params),
36 col_id_(0) {
37 sentence_queue_ = std::make_unique<Queue<TensorRow>>(op_conn_size);
38 read_done_ = false;
39 ret_status_ = Status::OK();
40 }
41
operator ()()42 Status BuildSentencePieceVocabOp::operator()() {
43 if (tree_ == nullptr) {
44 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
45 }
46 RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
47 RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
48 "sentenceTask", std::bind(&BuildSentencePieceVocabOp::SentenceThread, this), nullptr, id()));
49 TaskManager::FindMe()->Post();
50 child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
51 TensorRow new_row;
52 RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
53
54 bool eoe_warning = false; // give out warning if receive more than 1 eoe
55 while (child_iterator_->EofHandled() == false) {
56 while (new_row.empty() == false) {
57 RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(new_row));
58 RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
59 }
60 RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
61 CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no operator should be after from_dataset (repeat detected)");
62 eoe_warning = true;
63 }
64 // add empty tensorRow for quit
65 TensorRow empty_row = {};
66 RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(empty_row));
67 return Status::OK();
68 }
69
SentenceThread()70 Status BuildSentencePieceVocabOp::SentenceThread() {
71 TaskManager::FindMe()->Post();
72 if (col_names_.empty() == true) {
73 auto itr = column_name_id_map_.find("text");
74 CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
75 "Invalid data, 'text' column does not exist in dataset.");
76 col_id_ = itr->second;
77 } else {
78 auto itr = column_name_id_map_.find(col_names_[0]);
79 CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
80 "Invalid parameter, column name: " + col_names_[0] + " does not exist in dataset.");
81 col_id_ = itr->second;
82 }
83 std::unique_ptr<DatasetSentenceIterator> sentence_iter = std::make_unique<DatasetSentenceIterator>(this);
84 std::string model_proto;
85 sentencepiece::util::Status s_status =
86 sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
87 if (!s_status.ok()) {
88 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, s_status.message());
89 } else {
90 if (vocab_ == nullptr) {
91 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
92 "Invalid parameter, SentencePiece vocab not set.");
93 }
94 vocab_->set_model_proto(model_proto);
95 }
96 RETURN_IF_NOT_OK(out_connector_->SendEOE());
97 RETURN_IF_NOT_OK(out_connector_->SendEOF());
98 return Status::OK();
99 }
100
BuildParams()101 std::unordered_map<std::string, std::string> BuildSentencePieceVocabOp::BuildParams() {
102 std::unordered_map<std::string, std::string> ret_params;
103 ret_params["vocab_size"] = std::to_string(vocab_size_);
104 ret_params["character_coverage"] = std::to_string(character_coverage_);
105 if (model_type_ == SentencePieceModel::kBpe) {
106 ret_params["model_type"] = "BPE";
107 } else if (model_type_ == SentencePieceModel::kChar) {
108 ret_params["model_type"] = "CHAR";
109 } else if (model_type_ == SentencePieceModel::kWord) {
110 ret_params["model_type"] = "WORD";
111 } else {
112 ret_params["model_type"] = "UNIGRAM";
113 }
114 // filter some params that set by function param
115 // filter model_prefix that must be empty
116 for (auto param : params_) {
117 std::string key = param.first;
118 if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
119 key == "model_type") {
120 continue;
121 }
122 ret_params[key] = param.second;
123 }
124
125 ret_params["model_prefix"] = "";
126 ret_params["minloglevel"] = "1";
127 return ret_params;
128 }
129
Done()130 bool BuildSentencePieceVocabOp::Done() { return read_done_; }
131
Next(std::string * sentence)132 void BuildSentencePieceVocabOp::Next(std::string *sentence) {
133 TensorRow new_row;
134 Status s = sentence_queue_->PopFront(&new_row);
135
136 if (s.IsError()) {
137 read_done_ = true;
138 ret_status_ = s;
139 return;
140 }
141 if (new_row.empty() == true) {
142 read_done_ = true;
143 ret_status_ = Status::OK();
144 return;
145 }
146
147 if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
148 ret_status_ =
149 Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
150 "Invalid data, build_sentence_piece_vocab only works on string data with rank equal to 1, got type: " +
151 new_row[col_id_]->type().ToString() + "and rank: " + std::to_string(new_row[col_id_]->Rank()));
152 read_done_ = true;
153 return;
154 }
155
156 std::string_view sentence_v;
157 ret_status_ = new_row[col_id_]->GetItemAt(&sentence_v, {});
158 if (ret_status_.IsError()) {
159 read_done_ = true;
160 return;
161 }
162
163 std::string st{sentence_v};
164 *sentence = st;
165 ret_status_ = Status::OK();
166 }
167
DatasetSentenceIterator(BuildSentencePieceVocabOp * s_p_vocab_ptr)168 BuildSentencePieceVocabOp::DatasetSentenceIterator::DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr)
169 : s_p_vocab_ptr_(s_p_vocab_ptr) {}
170
done() const171 bool BuildSentencePieceVocabOp::DatasetSentenceIterator::done() const {
172 if (s_p_vocab_ptr_ == nullptr) {
173 return true;
174 }
175 return s_p_vocab_ptr_->Done();
176 }
177
Next()178 void BuildSentencePieceVocabOp::DatasetSentenceIterator::Next() {
179 if (s_p_vocab_ptr_ == nullptr) {
180 return;
181 }
182 s_p_vocab_ptr_->Next(&value_);
183 }
184 } // namespace dataset
185 } // namespace mindspore
186