• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
18 
19 #include <iomanip>
20 #include "minddata/dataset/core/config_manager.h"
21 
22 namespace mindspore {
23 namespace dataset {
BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab,std::vector<std::string> col_names,int32_t vocab_size,float character_coverage,SentencePieceModel model_type,const std::unordered_map<std::string,std::string> & params,int32_t op_conn_size)24 BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab,
25                                                      std::vector<std::string> col_names, int32_t vocab_size,
26                                                      float character_coverage, SentencePieceModel model_type,
27                                                      const std::unordered_map<std::string, std::string> &params,
28                                                      int32_t op_conn_size)
29     : PipelineOp(op_conn_size),
30       vocab_size_(vocab_size),
31       vocab_(vocab),
32       col_names_(col_names),
33       character_coverage_(character_coverage),
34       model_type_(model_type),
35       params_(params),
36       col_id_(0) {
37   sentence_queue_ = std::make_unique<Queue<TensorRow>>(op_conn_size);
38   read_done_ = false;
39   ret_status_ = Status::OK();
40 }
41 
operator ()()42 Status BuildSentencePieceVocabOp::operator()() {
43   if (tree_ == nullptr) {
44     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
45   }
46   RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
47   RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
48     "sentenceTask", std::bind(&BuildSentencePieceVocabOp::SentenceThread, this), nullptr, id()));
49   TaskManager::FindMe()->Post();
50   child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
51   TensorRow new_row;
52   RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
53 
54   bool eoe_warning = false;  // give out warning if receive more than 1 eoe
55   while (child_iterator_->EofHandled() == false) {
56     while (new_row.empty() == false) {
57       RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(new_row));
58       RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
59     }
60     RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
61     CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no operator should be after from_dataset (repeat detected)");
62     eoe_warning = true;
63   }
64   // add empty tensorRow for quit
65   TensorRow empty_row = {};
66   RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(empty_row));
67   return Status::OK();
68 }
69 
SentenceThread()70 Status BuildSentencePieceVocabOp::SentenceThread() {
71   TaskManager::FindMe()->Post();
72   if (col_names_.empty() == true) {
73     auto itr = column_name_id_map_.find("text");
74     CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
75                                  "Invalid data, 'text' column does not exist in dataset.");
76     col_id_ = itr->second;
77   } else {
78     auto itr = column_name_id_map_.find(col_names_[0]);
79     CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
80                                  "Invalid parameter, column name: " + col_names_[0] + " does not exist in dataset.");
81     col_id_ = itr->second;
82   }
83   std::unique_ptr<DatasetSentenceIterator> sentence_iter = std::make_unique<DatasetSentenceIterator>(this);
84   std::string model_proto;
85   sentencepiece::util::Status s_status =
86     sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
87   if (!s_status.ok()) {
88     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, s_status.message());
89   } else {
90     if (vocab_ == nullptr) {
91       return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
92                     "Invalid parameter, SentencePiece vocab not set.");
93     }
94     vocab_->set_model_proto(model_proto);
95   }
96   RETURN_IF_NOT_OK(out_connector_->SendEOE());
97   RETURN_IF_NOT_OK(out_connector_->SendEOF());
98   return Status::OK();
99 }
100 
BuildParams()101 std::unordered_map<std::string, std::string> BuildSentencePieceVocabOp::BuildParams() {
102   std::unordered_map<std::string, std::string> ret_params;
103   ret_params["vocab_size"] = std::to_string(vocab_size_);
104   ret_params["character_coverage"] = std::to_string(character_coverage_);
105   if (model_type_ == SentencePieceModel::kBpe) {
106     ret_params["model_type"] = "BPE";
107   } else if (model_type_ == SentencePieceModel::kChar) {
108     ret_params["model_type"] = "CHAR";
109   } else if (model_type_ == SentencePieceModel::kWord) {
110     ret_params["model_type"] = "WORD";
111   } else {
112     ret_params["model_type"] = "UNIGRAM";
113   }
114   // filter some params that set by function param
115   // filter  model_prefix that must be empty
116   for (auto param : params_) {
117     std::string key = param.first;
118     if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
119         key == "model_type") {
120       continue;
121     }
122     ret_params[key] = param.second;
123   }
124 
125   ret_params["model_prefix"] = "";
126   ret_params["minloglevel"] = "1";
127   return ret_params;
128 }
129 
Done()130 bool BuildSentencePieceVocabOp::Done() { return read_done_; }
131 
Next(std::string * sentence)132 void BuildSentencePieceVocabOp::Next(std::string *sentence) {
133   TensorRow new_row;
134   Status s = sentence_queue_->PopFront(&new_row);
135 
136   if (s.IsError()) {
137     read_done_ = true;
138     ret_status_ = s;
139     return;
140   }
141   if (new_row.empty() == true) {
142     read_done_ = true;
143     ret_status_ = Status::OK();
144     return;
145   }
146 
147   if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
148     ret_status_ =
149       Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
150              "Invalid data, build_sentence_piece_vocab only works on string data with rank equal to 1, got type: " +
151                new_row[col_id_]->type().ToString() + "and rank: " + std::to_string(new_row[col_id_]->Rank()));
152     read_done_ = true;
153     return;
154   }
155 
156   std::string_view sentence_v;
157   ret_status_ = new_row[col_id_]->GetItemAt(&sentence_v, {});
158   if (ret_status_.IsError()) {
159     read_done_ = true;
160     return;
161   }
162 
163   std::string st{sentence_v};
164   *sentence = st;
165   ret_status_ = Status::OK();
166 }
167 
DatasetSentenceIterator(BuildSentencePieceVocabOp * s_p_vocab_ptr)168 BuildSentencePieceVocabOp::DatasetSentenceIterator::DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr)
169     : s_p_vocab_ptr_(s_p_vocab_ptr) {}
170 
done() const171 bool BuildSentencePieceVocabOp::DatasetSentenceIterator::done() const {
172   if (s_p_vocab_ptr_ == nullptr) {
173     return true;
174   }
175   return s_p_vocab_ptr_->Done();
176 }
177 
Next()178 void BuildSentencePieceVocabOp::DatasetSentenceIterator::Next() {
179   if (s_p_vocab_ptr_ == nullptr) {
180     return;
181   }
182   s_p_vocab_ptr_->Next(&value_);
183 }
184 }  // namespace dataset
185 }  // namespace mindspore
186