/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/text/ir/kernels/text_ir.h"

#include <fstream>

#include "minddata/dataset/text/kernels/add_token_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
#include "minddata/dataset/text/kernels/case_fold_op.h"
#include "minddata/dataset/text/kernels/filter_wikipedia_xml_op.h"
#endif
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
#include "minddata/dataset/text/kernels/regex_replace_op.h"
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
#endif
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/to_vectors_op.h"
#include "minddata/dataset/text/kernels/truncate_op.h"
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/validators.h"

#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/text/ir/validators.h"

namespace mindspore {
namespace dataset {
// Transform operations for text.
namespace text {
/* ####################################### Derived TensorOperation classes ################################# */
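// Each *Operation below is the IR node for one text transform: ValidateParams() checks the
// user-supplied arguments and Build() creates the corresponding TensorOp kernel.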

// (In alphabetical order)

// AddToken
AddTokenOperation::AddTokenOperation(const std::string &token, bool begin) : token_(token), begin_(begin) {}

AddTokenOperation::~AddTokenOperation() = default;

std::shared_ptr<TensorOp> AddTokenOperation::Build() {
  std::shared_ptr<AddTokenOp> tensor_op = std::make_shared<AddTokenOp>(token_, begin_);
  return tensor_op;
}

Status AddTokenOperation::ValidateParams() {
  if (token_.empty()) {
    std::string err_msg = "AddToken: Parameter token is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::string AddTokenOperation::Name() const { return kAddTokenOperation; }

Status AddTokenOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["token"] = token_;
  args["begin"] = begin_;
  *out_json = args;
  return Status::OK();
}

#ifndef _WIN32
// BasicTokenizerOperation
BasicTokenizerOperation::BasicTokenizerOperation(bool lower_case, bool keep_whitespace,
                                                 const NormalizeForm normalize_form, bool preserve_unused_token,
                                                 bool with_offsets)
    : lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

Status BasicTokenizerOperation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BasicTokenizer: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> BasicTokenizerOperation::Build() {
  std::shared_ptr<BasicTokenizerOp> tensor_op = std::make_shared<BasicTokenizerOp>(
    lower_case_, keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}

// BertTokenizerOperation
BertTokenizerOperation::BertTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
                                               int32_t max_bytes_per_token, const std::string &unknown_token,
                                               bool lower_case, bool keep_whitespace,
                                               const NormalizeForm normalize_form, bool preserve_unused_token,
                                               bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

BertTokenizerOperation::~BertTokenizerOperation() = default;

Status BertTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BertTokenizer: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BertTokenizer: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (max_bytes_per_token_ < 0) {
    std::string err_msg = "BertTokenizer : The parameter max_bytes_per_token must be greater than or equal to 0: " +
                          std::to_string(max_bytes_per_token_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> BertTokenizerOperation::Build() {
  std::shared_ptr<BertTokenizerOp> tensor_op =
    std::make_shared<BertTokenizerOp>(vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, lower_case_,
                                      keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}

// CaseFoldOperation
Status CaseFoldOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> CaseFoldOperation::Build() {
  std::shared_ptr<CaseFoldOp> tensor_op = std::make_shared<CaseFoldOp>();
  return tensor_op;
}

// FilterWikipediaXMLOperation
FilterWikipediaXMLOperation::FilterWikipediaXMLOperation() {}

Status FilterWikipediaXMLOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> FilterWikipediaXMLOperation::Build() {
  std::shared_ptr<FilterWikipediaXMLOp> tensor_op = std::make_shared<FilterWikipediaXMLOp>();
  return tensor_op;
}
#endif

// JiebaTokenizerOperation
JiebaTokenizerOperation::JiebaTokenizerOperation(const std::string &hmm_path, const std::string &mp_path,
                                                 const JiebaMode &mode, bool with_offsets)
    : hmm_path_(hmm_path), mp_path_(mp_path), mode_(mode), with_offsets_(with_offsets) {}

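// Helper functions used by JiebaTokenizerOperation::ValidateParams to sanity-check the cppjieba
// dictionary files: ValidateMPPPath parses the MPSegment dictionary and checks that the word
// frequencies sum to a positive value, while ValidateHMMPath checks that the HMM model file
// contains the expected startProb and transProb sections.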
bool MakeNodeInfo(cppjieba::DictUnit &node_info, const std::string &word, double weight, const std::string &tag) {
  if (!cppjieba::DecodeRunesInString(word, node_info.word)) {
    return false;
  }
  node_info.weight = weight;
  node_info.tag = tag;
  return true;
}

double CalcFreqSum(const std::vector<cppjieba::DictUnit> &node_infos) {
  double sum = 0.0;
  for (size_t i = 0; i < node_infos.size(); i++) {
    sum += node_infos[i].weight;
  }
  return sum;
}

Status ValidateMPPPath(const std::string &dict_path) {
  double freq_sum = 0.0;
  std::vector<cppjieba::DictUnit> static_node_infos;
  std::ifstream ifs(dict_path.c_str(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(ifs.is_open(), "JiebaTokenizer: Failed to open file: " + dict_path);
  std::string line;
  std::vector<std::string> buf;

  cppjieba::DictUnit node_info;
  for (size_t lineno = 0; std::getline(ifs, line); lineno++) {
    cppjieba::Split(line, buf, " ");
    if (buf.size() != cppjieba::DICT_COLUMN_NUM) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED("JiebaTokenizer: Split result illegal, line: " + line);
    }
    if (!(MakeNodeInfo(node_info, buf[0], std::atof(buf[1].c_str()), buf[2]))) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED("JiebaTokenizer: Failed to make node info.");
    }
    static_node_infos.push_back(node_info);
  }
  freq_sum = CalcFreqSum(static_node_infos);
  if (freq_sum <= 0) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED("JiebaTokenizer: MPSegment algorithm file format is incorrect.");
  }
  ifs.close();
  return Status::OK();
}

bool GetLine(std::ifstream &ifs, std::string *line) {
  while (std::getline(ifs, *line)) {
    cppjieba::Trim(*line);
    if (line->empty()) {
      continue;
    }
    if (cppjieba::StartsWith(*line, "#")) {
      continue;
    }
    return true;
  }
  return false;
}

Status ValidateHMMPath(const std::string &dict_path) {
  std::ifstream ifs(dict_path.c_str(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(ifs.is_open(), "JiebaTokenizer: Failed to open file: " + dict_path);
  std::string line;
  std::vector<std::string> buf;

  // Load startProb
  if (!GetLine(ifs, &line)) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED(
      "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
      "content fails to be obtained when startProb is loaded.");
  }
  cppjieba::Split(line, buf, " ");
  if (buf.size() != kStatusSum) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED(
      "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
      "content fails to be obtained when startProb is loaded.");
  }

  // Load transProb
  for (size_t i = 0; i < kStatusSum; i++) {
    if (!GetLine(ifs, &line)) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED(
        "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
        "content fails to be obtained when transProb is loaded.");
    }
    cppjieba::Split(line, buf, " ");
    if (buf.size() != kStatusSum) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED(
        "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
        "content fails to be obtained when transProb is loaded.");
    }
  }
  if (!GetLine(ifs, &line)) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED("JiebaTokenizer: HMMSegment algorithm file format is incorrect.");
  }
  ifs.close();
  return Status::OK();
}

Status JiebaTokenizerOperation::ValidateParams() {
  if (hmm_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of HMMSegment in cppjieba is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mp_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of MPSegment in cppjieba is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mode_ != JiebaMode::kMix && mode_ != JiebaMode::kMp && mode_ != JiebaMode::kHmm) {
    std::string err_msg = "JiebaTokenizer: Invalid JiebaMode, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", hmm_path_));
  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", mp_path_));
  RETURN_IF_NOT_OK(ValidateHMMPath(hmm_path_));
  RETURN_IF_NOT_OK(ValidateMPPPath(mp_path_));
  return Status::OK();
}

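// Words registered through AddWord() are only recorded in words_list_; they are applied to the
// underlying JiebaTokenizerOp when Build() is called.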
std::shared_ptr<TensorOp> JiebaTokenizerOperation::Build() {
  std::shared_ptr<JiebaTokenizerOp> tensor_op =
    std::make_shared<JiebaTokenizerOp>(hmm_path_, mp_path_, mode_, with_offsets_);
  for (auto &word : words_list_) {
    Status rc = tensor_op->AddWord(word.first, word.second);
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc;
      return {};
    }
  }
  return tensor_op;
}

Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) {
  words_list_.emplace_back(word, freq);
  return Status::OK();
}

// LookupOperation
// DataType data_type - required for C++ API
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const DataType &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}

// std::string data_type - required for Pybind
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const std::string &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

LookupOperation::~LookupOperation() = default;

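// If an unknown_token was supplied, ValidateParams() resolves it to default_id_, which Build()
// passes to LookupOp as the id used for out-of-vocabulary tokens.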
Status LookupOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "Lookup: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (unknown_token_ != std::nullopt) {
    default_id_ = vocab_->TokensToIds(*unknown_token_);
    if (default_id_ == Vocab::kNoTokenExists) {
      std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  if (!data_type_.IsNumeric()) {
    // Note: For DEType, Bool is counted as numeric, and is a valid type for Lookup
    std::string err_msg = "Lookup : The parameter data_type must be numeric including bool.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> LookupOperation::Build() {
  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, DataType(data_type_));
  return tensor_op;
}

// NgramOperation
NgramOperation::NgramOperation(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad,
                               const std::pair<std::string, int32_t> &right_pad, const std::string &separator)
    : ngrams_(ngrams), left_pad_(left_pad), right_pad_(right_pad), separator_(separator) {}

Status NgramOperation::ValidateParams() {
  if (ngrams_.size() == 0) {
    std::string err_msg = "Ngram : The size of the parameter 'ngrams' must not be 0.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  } else {
    for (int32_t i = 0; i < ngrams_.size(); ++i) {
      if (ngrams_[i] <= 0) {
        std::string err_msg =
          "Ngram : The value of ngrams vector must be greater than 0: " + std::to_string(ngrams_[i]);
        LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
      }
    }
  }

  if (left_pad_.second < 0) {
    std::string err_msg =
      "Ngram : The second parameter pad_width in left_pad vector must be greater than or equal to 0: " +
      std::to_string(left_pad_.second);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (right_pad_.second < 0) {
    std::string err_msg =
      "Ngram : The second parameter pad_width in right_pad vector must be greater than or equal to 0: " +
      std::to_string(right_pad_.second);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NgramOperation::Build() {
  int32_t l_len = left_pad_.second;
  int32_t r_len = right_pad_.second;
  std::string l_pad = left_pad_.first;
  std::string r_pad = right_pad_.first;
  std::shared_ptr<NgramOp> tensor_op = std::make_shared<NgramOp>(ngrams_, l_len, l_pad, r_len, r_pad, separator_);
  return tensor_op;
}

#ifndef _WIN32
// NormalizeUTF8Operation
NormalizeUTF8Operation::NormalizeUTF8Operation(NormalizeForm normalize_form) : normalize_form_(normalize_form) {}

Status NormalizeUTF8Operation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "NormalizeUTF8: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NormalizeUTF8Operation::Build() {
  std::shared_ptr<NormalizeUTF8Op> tensor_op = std::make_shared<NormalizeUTF8Op>(normalize_form_);
  return tensor_op;
}

// RegexReplaceOperation
RegexReplaceOperation::RegexReplaceOperation(std::string pattern, std::string replace, bool replace_all)
    : pattern_(pattern), replace_(replace), replace_all_(replace_all) {}

Status RegexReplaceOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexReplaceOperation::Build() {
  std::shared_ptr<RegexReplaceOp> tensor_op = std::make_shared<RegexReplaceOp>(pattern_, replace_, replace_all_);
  return tensor_op;
}

// RegexTokenizerOperation
RegexTokenizerOperation::RegexTokenizerOperation(std::string delim_pattern, std::string keep_delim_pattern,
                                                 bool with_offsets)
    : delim_pattern_(delim_pattern), keep_delim_pattern_(keep_delim_pattern), with_offsets_(with_offsets) {}

Status RegexTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexTokenizerOperation::Build() {
  std::shared_ptr<RegexTokenizerOp> tensor_op =
    std::make_shared<RegexTokenizerOp>(delim_pattern_, keep_delim_pattern_, with_offsets_);
  return tensor_op;
}
#endif

// SentencePieceTokenizerOperation
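// SentencePieceTokenizer can be constructed either from an in-memory SentencePieceVocab (kModel)
// or from a model file on disk (kFile); load_type_ records which path Build() should take.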
SentencePieceTokenizerOperation::~SentencePieceTokenizerOperation() = default;

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(vocab), vocab_path_(std::string()), load_type_(SPieceTokenizerLoadType::kModel), out_type_(out_type) {}

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::string &vocab_path,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(nullptr), vocab_path_(vocab_path), load_type_(SPieceTokenizerLoadType::kFile), out_type_(out_type) {}

Status SentencePieceTokenizerOperation::ValidateParams() {
  if (out_type_ != SPieceTokenizerOutType::kString && out_type_ != SPieceTokenizerOutType::kInt) {
    std::string err_msg = "SentencePieceTokenizer: Invalid SPieceTokenizerOutType, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    if (vocab_ == nullptr) {
      std::string err_msg = "SentencePieceTokenizer: vocab object type is incorrect or null.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  } else {
    std::string real_vocab_path;
    RETURN_IF_NOT_OK(Path::RealPath(vocab_path_, real_vocab_path));
    Path vocab_file(real_vocab_path);
    if (!vocab_file.Exists() || vocab_file.IsDirectory()) {
      std::string err_msg = "SentencePieceTokenizer : vocab file: [" + vocab_path_ + "] is invalid or does not exist.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (access(vocab_file.ToString().c_str(), R_OK) == -1) {
      std::string err_msg = "SentencePieceTokenizer : no access to specified dataset file: " + vocab_path_;
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
  std::shared_ptr<SentencePieceTokenizerOp> tensor_op;
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(vocab_, load_type_, out_type_);
  } else {
    Path vocab_file(vocab_path_);
    std::string model_path = vocab_file.ParentPath();
    std::string model_filename = vocab_file.Basename();
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(model_path, model_filename, load_type_, out_type_);
  }
  return tensor_op;
}

// SlidingWindowOperation
SlidingWindowOperation::SlidingWindowOperation(const int32_t width, const int32_t axis) : width_(width), axis_(axis) {}

Status SlidingWindowOperation::ValidateParams() {
  if (width_ < 1) {
    std::string err_msg =
      "SlidingWindow : The parameter width must be greater than or equal to 1: " + std::to_string(width_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SlidingWindowOperation::Build() {
  std::shared_ptr<SlidingWindowOp> tensor_op = std::make_shared<SlidingWindowOp>(static_cast<uint32_t>(width_), axis_);
  return tensor_op;
}

// ToNumberOperation
// DataType data_type - required for C++ API
ToNumberOperation::ToNumberOperation(const DataType &data_type) : data_type_(data_type) {}

// std::string data_type - required for Pybind
ToNumberOperation::ToNumberOperation(const std::string &data_type) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

Status ToNumberOperation::ValidateParams() {
  if (!data_type_.IsNumeric() || data_type_.IsBool()) {
    // Note: For DEType, Bool is counted as numeric, but is not a valid type for ToNumber.
    std::string err_msg = "ToNumber : The parameter data_type must be numeric and excludes bool.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> ToNumberOperation::Build() {
  std::shared_ptr<ToNumberOp> tensor_op = std::make_shared<ToNumberOp>(data_type_);
  return tensor_op;
}

Status ToNumberOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["data_type"] = data_type_.ToString();
  *out_json = args;
  return Status::OK();
}

Status ToNumberOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "data_type", kToNumberOperation));
  std::string data_type = op_params["data_type"];
  *operation = std::make_shared<text::ToNumberOperation>(data_type);
  return Status::OK();
}

// ToVectorsOperation
ToVectorsOperation::ToVectorsOperation(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init,
                                       bool lower_case_backup)
    : vectors_(vectors), unk_init_(unk_init), lower_case_backup_(lower_case_backup) {}

ToVectorsOperation::~ToVectorsOperation() = default;

Status ToVectorsOperation::ValidateParams() {
  if (vectors_ == nullptr) {
    std::string err_msg = "ToVectors: vectors can't be nullptr.";
    MS_LOG(ERROR) << err_msg;
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> ToVectorsOperation::Build() {
  std::shared_ptr<ToVectorsOp> tensor_op = std::make_shared<ToVectorsOp>(vectors_, unk_init_, lower_case_backup_);
  return tensor_op;
}

// TruncateOperation
TruncateOperation::TruncateOperation(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}

Status TruncateOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Truncate", "max_seq_len", max_seq_len_));

  return Status::OK();
}

std::shared_ptr<TensorOp> TruncateOperation::Build() {
  std::shared_ptr<TruncateOp> tensor_op = std::make_shared<TruncateOp>(max_seq_len_);
  return tensor_op;
}

Status TruncateOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["max_seq_len"] = max_seq_len_;
  *out_json = args;
  return Status::OK();
}

// TruncateSequencePairOperation
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}

Status TruncateSequencePairOperation::ValidateParams() {
  if (max_length_ < 0) {
    std::string err_msg = "TruncateSequencePair : The parameter max_length must be greater than or equal to 0: " +
                          std::to_string(max_length_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> TruncateSequencePairOperation::Build() {
  std::shared_ptr<TruncateSequencePairOp> tensor_op = std::make_shared<TruncateSequencePairOp>(max_length_);
  return tensor_op;
}

// UnicodeCharTokenizerOperation
UnicodeCharTokenizerOperation::UnicodeCharTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status UnicodeCharTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeCharTokenizerOperation::Build() {
  std::shared_ptr<UnicodeCharTokenizerOp> tensor_op = std::make_shared<UnicodeCharTokenizerOp>(with_offsets_);
  return tensor_op;
}

// WordpieceTokenizerOperation
WordpieceTokenizerOperation::WordpieceTokenizerOperation(const std::shared_ptr<Vocab> &vocab,
                                                         const std::string &suffix_indicator,
                                                         int32_t max_bytes_per_token, const std::string &unknown_token,
                                                         bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      with_offsets_(with_offsets) {}

Status WordpieceTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "WordpieceTokenizer: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (max_bytes_per_token_ < 0) {
    std::string err_msg =
      "WordpieceTokenizer : The parameter max_bytes_per_token must be greater than or equal to 0: " +
      std::to_string(max_bytes_per_token_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> WordpieceTokenizerOperation::Build() {
  std::shared_ptr<WordpieceTokenizerOp> tensor_op = std::make_shared<WordpieceTokenizerOp>(
    vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, with_offsets_);
  return tensor_op;
}

#ifndef _WIN32
// UnicodeScriptTokenizerOperation
UnicodeScriptTokenizerOperation::UnicodeScriptTokenizerOperation(bool keep_whitespace, bool with_offsets)
    : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {}

Status UnicodeScriptTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeScriptTokenizerOperation::Build() {
  std::shared_ptr<UnicodeScriptTokenizerOp> tensor_op =
    std::make_shared<UnicodeScriptTokenizerOp>(keep_whitespace_, with_offsets_);
  return tensor_op;
}

// WhitespaceTokenizerOperation
WhitespaceTokenizerOperation::WhitespaceTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status WhitespaceTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> WhitespaceTokenizerOperation::Build() {
  std::shared_ptr<WhitespaceTokenizerOp> tensor_op = std::make_shared<WhitespaceTokenizerOp>(with_offsets_);
  return tensor_op;
}
#endif
}  // namespace text
}  // namespace dataset
}  // namespace mindspore