/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/text/ir/kernels/text_ir.h"

#include <fstream>

#include "minddata/dataset/text/kernels/add_token_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
#include "minddata/dataset/text/kernels/case_fold_op.h"
#include "minddata/dataset/text/kernels/filter_wikipedia_xml_op.h"
#endif
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
#include "minddata/dataset/text/kernels/regex_replace_op.h"
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
#endif
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/to_vectors_op.h"
#include "minddata/dataset/text/kernels/truncate_op.h"
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/validators.h"

#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/text/ir/validators.h"
namespace mindspore {
namespace dataset {
// Transform operations for text.
namespace text {
/* ####################################### Derived TensorOperation classes ################################# */

// (In alphabetical order)

// AddToken
AddTokenOperation::AddTokenOperation(const std::string &token, bool begin) : token_(token), begin_(begin) {}

AddTokenOperation::~AddTokenOperation() = default;

std::shared_ptr<TensorOp> AddTokenOperation::Build() {
  std::shared_ptr<AddTokenOp> tensor_op = std::make_shared<AddTokenOp>(token_, begin_);
  return tensor_op;
}

Status AddTokenOperation::ValidateParams() {
  if (token_.empty()) {
    std::string err_msg = "AddToken: Parameter token is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::string AddTokenOperation::Name() const { return kAddTokenOperation; }

Status AddTokenOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["token"] = token_;
  args["begin"] = begin_;
  *out_json = args;
  return Status::OK();
}
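// For reference, the JSON produced by to_json above has this shape
// (illustrative values, assuming token="[CLS]" and begin=true):
//   {"begin": true, "token": "[CLS]"}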

#ifndef _WIN32
// BasicTokenizerOperation
BasicTokenizerOperation::BasicTokenizerOperation(bool lower_case, bool keep_whitespace,
                                                 const NormalizeForm normalize_form, bool preserve_unused_token,
                                                 bool with_offsets)
    : lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

Status BasicTokenizerOperation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BasicTokenizer: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> BasicTokenizerOperation::Build() {
  std::shared_ptr<BasicTokenizerOp> tensor_op = std::make_shared<BasicTokenizerOp>(
    lower_case_, keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}

// BertTokenizerOperation
BertTokenizerOperation::BertTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
                                               int32_t max_bytes_per_token, const std::string &unknown_token,
                                               bool lower_case, bool keep_whitespace,
                                               const NormalizeForm normalize_form, bool preserve_unused_token,
                                               bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

BertTokenizerOperation::~BertTokenizerOperation() = default;

Status BertTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BertTokenizer: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BertTokenizer: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (max_bytes_per_token_ < 0) {
    std::string err_msg = "BertTokenizer: The parameter max_bytes_per_token must be greater than or equal to 0: " +
                          std::to_string(max_bytes_per_token_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> BertTokenizerOperation::Build() {
  std::shared_ptr<BertTokenizerOp> tensor_op =
    std::make_shared<BertTokenizerOp>(vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, lower_case_,
                                      keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}

// CaseFoldOperation
Status CaseFoldOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> CaseFoldOperation::Build() {
  std::shared_ptr<CaseFoldOp> tensor_op = std::make_shared<CaseFoldOp>();
  return tensor_op;
}

// FilterWikipediaXMLOperation
FilterWikipediaXMLOperation::FilterWikipediaXMLOperation() {}

Status FilterWikipediaXMLOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> FilterWikipediaXMLOperation::Build() {
  std::shared_ptr<FilterWikipediaXMLOp> tensor_op = std::make_shared<FilterWikipediaXMLOp>();
  return tensor_op;
}
#endif

// JiebaTokenizerOperation
JiebaTokenizerOperation::JiebaTokenizerOperation(const std::string &hmm_path, const std::string &mp_path,
                                                 const JiebaMode &mode, bool with_offsets)
    : hmm_path_(hmm_path), mp_path_(mp_path), mode_(mode), with_offsets_(with_offsets) {}

bool MakeNodeInfo(cppjieba::DictUnit &node_info, const std::string &word, double weight, const std::string &tag) {
  if (!cppjieba::DecodeRunesInString(word, node_info.word)) {
    return false;
  }
  node_info.weight = weight;
  node_info.tag = tag;
  return true;
}

double CalcFreqSum(const std::vector<cppjieba::DictUnit> &node_infos) {
  double sum = 0.0;
  for (size_t i = 0; i < node_infos.size(); i++) {
    sum += node_infos[i].weight;
  }
  return sum;
}

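// Sanity-checks the MPSegment dictionary file. Each line is expected to hold
// exactly cppjieba::DICT_COLUMN_NUM space-separated columns: the word, its
// frequency weight, and a POS tag. A hypothetical line, for illustration only:
//   北京 32906 ns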
Status ValidateMPPPath(const std::string &dict_path) {
  double freq_sum = 0.0;
  std::vector<cppjieba::DictUnit> static_node_infos;
  std::ifstream ifs(dict_path.c_str(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(ifs.is_open(), "JiebaTokenizer: Failed to open file: " + dict_path);
  std::string line;
  std::vector<std::string> buf;

  cppjieba::DictUnit node_info;
  for (size_t lineno = 0; std::getline(ifs, line); lineno++) {
    cppjieba::Split(line, buf, " ");
    if (buf.size() != cppjieba::DICT_COLUMN_NUM) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED("JiebaTokenizer: Split result illegal, line: " + line);
    }
    if (!(MakeNodeInfo(node_info, buf[0], std::atof(buf[1].c_str()), buf[2]))) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED("JiebaTokenizer: Failed to make node info.");
    }
    static_node_infos.push_back(node_info);
  }
  freq_sum = CalcFreqSum(static_node_infos);
  if (freq_sum <= 0) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED("JiebaTokenizer: MPSegment algorithm file format is incorrect.");
  }
  ifs.close();
  return Status::OK();
}

bool GetLine(std::ifstream &ifs, std::string *line) {
  while (std::getline(ifs, *line)) {
    cppjieba::Trim(*line);
    if (line->empty()) {
      continue;
    }
    if (cppjieba::StartsWith(*line, "#")) {
      continue;
    }
    return true;
  }
  return false;
}

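// Sanity-checks the HMMSegment model file. Once GetLine has skipped blank and
// "#"-comment lines, the loop below assumes the following layout (a rough
// sketch of the cppjieba model format): one line of kStatusSum start
// probabilities, then kStatusSum lines of transition probabilities (one row of
// the transition matrix per hidden state), followed by at least one further
// non-empty line for the emission-probability section.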
Status ValidateHMMPath(const std::string &dict_path) {
  std::ifstream ifs(dict_path.c_str(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(ifs.is_open(), "JiebaTokenizer: Failed to open file: " + dict_path);
  std::string line;
  std::vector<std::string> buf;

  // Load startProb
  if (!GetLine(ifs, &line)) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED(
      "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
      "content fails to be obtained when startProb is loaded.");
  }
  cppjieba::Split(line, buf, " ");
  if (buf.size() != kStatusSum) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED(
      "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
      "content fails to be obtained when startProb is loaded.");
  }

  // Load transProb
  for (size_t i = 0; i < kStatusSum; i++) {
    if (!GetLine(ifs, &line)) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED(
        "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
        "content fails to be obtained when transProb is loaded.");
    }
    cppjieba::Split(line, buf, " ");
    if (buf.size() != kStatusSum) {
      ifs.close();
      RETURN_STATUS_UNEXPECTED(
        "JiebaTokenizer: The file format of the HMMSegment algorithm is incorrect, and the "
        "content fails to be obtained when transProb is loaded.");
    }
  }
  if (!GetLine(ifs, &line)) {
    ifs.close();
    RETURN_STATUS_UNEXPECTED("JiebaTokenizer: HMMSegment algorithm file format is incorrect.");
  }
  ifs.close();
  return Status::OK();
}

Status JiebaTokenizerOperation::ValidateParams() {
  if (hmm_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of HMMSegment in cppjieba is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mp_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of MPSegment in cppjieba is not provided.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mode_ != JiebaMode::kMix && mode_ != JiebaMode::kMp && mode_ != JiebaMode::kHmm) {
    std::string err_msg = "JiebaTokenizer: Invalid JiebaMode, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", hmm_path_));
  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", mp_path_));
  RETURN_IF_NOT_OK(ValidateHMMPath(hmm_path_));
  RETURN_IF_NOT_OK(ValidateMPPPath(mp_path_));
  return Status::OK();
}

std::shared_ptr<TensorOp> JiebaTokenizerOperation::Build() {
  std::shared_ptr<JiebaTokenizerOp> tensor_op =
    std::make_shared<JiebaTokenizerOp>(hmm_path_, mp_path_, mode_, with_offsets_);
  for (auto &word : words_list_) {
    Status rc = tensor_op->AddWord(word.first, word.second);
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc;
      return {};
    }
  }
  return tensor_op;
}

Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) {
  words_list_.emplace_back(word, freq);
  return Status::OK();
}
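// AddWord only records the (word, frequency) pair; the words are actually
// registered on the underlying tokenizer when Build() runs. A minimal usage
// sketch (hypothetical paths and word):
//   auto jieba = std::make_shared<text::JiebaTokenizerOperation>("hmm_model.utf8", "jieba.dict.utf8",
//                                                                JiebaMode::kMp, false);
//   RETURN_IF_NOT_OK(jieba->AddWord("mindspore", 100));
//   std::shared_ptr<TensorOp> op = jieba->Build();  // words_list_ is applied here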

// LookupOperation
// DataType data_type - required for C++ API
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const DataType &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}

// std::string data_type - required for Pybind
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const std::string &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

LookupOperation::~LookupOperation() = default;

Status LookupOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "Lookup: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (unknown_token_ != std::nullopt) {
    default_id_ = vocab_->TokensToIds(*unknown_token_);
    if (default_id_ == Vocab::kNoTokenExists) {
      std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  if (!data_type_.IsNumeric()) {
    // Note: For DEType, Bool is counted as numeric, and is a valid type for Lookup
    std::string err_msg = "Lookup: The parameter data_type must be numeric including bool.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
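// In other words, when unknown_token is given, every out-of-vocabulary token is
// mapped to that token's id. For example (hypothetical vocab): with
// {"<unk>": 0, "hello": 1} and unknown_token="<unk>", looking up "world" would
// yield id 0.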

std::shared_ptr<TensorOp> LookupOperation::Build() {
  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, DataType(data_type_));
  return tensor_op;
}

// NgramOperation
NgramOperation::NgramOperation(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad,
                               const std::pair<std::string, int32_t> &right_pad, const std::string &separator)
    : ngrams_(ngrams), left_pad_(left_pad), right_pad_(right_pad), separator_(separator) {}

Status NgramOperation::ValidateParams() {
  if (ngrams_.size() == 0) {
    std::string err_msg = "Ngram: The size of the parameter 'ngrams' must not be 0.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  } else {
    for (size_t i = 0; i < ngrams_.size(); ++i) {
      if (ngrams_[i] <= 0) {
        std::string err_msg =
          "Ngram: The value of ngrams vector must be greater than 0: " + std::to_string(ngrams_[i]);
        LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
      }
    }
  }

  if (left_pad_.second < 0) {
    std::string err_msg =
      "Ngram: The second parameter pad_width in left_pad vector must be greater than or equal to 0: " +
      std::to_string(left_pad_.second);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (right_pad_.second < 0) {
    std::string err_msg =
      "Ngram: The second parameter pad_width in right_pad vector must be greater than or equal to 0: " +
      std::to_string(right_pad_.second);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NgramOperation::Build() {
  int32_t l_len = left_pad_.second;
  int32_t r_len = right_pad_.second;
  std::string l_pad = left_pad_.first;
  std::string r_pad = right_pad_.first;
  std::shared_ptr<NgramOp> tensor_op = std::make_shared<NgramOp>(ngrams_, l_len, l_pad, r_len, r_pad, separator_);
  return tensor_op;
}
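// A worked example of the padding semantics (illustrative): with ngrams={3},
// left_pad={"_", 2}, right_pad={"-", 1} and separator " ", the input
// ["a", "b"] is first padded to ["_", "_", "a", "b", "-"] and then yields the
// 3-grams "_ _ a", "_ a b", "a b -".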

#ifndef _WIN32
// NormalizeUTF8Operation
NormalizeUTF8Operation::NormalizeUTF8Operation(NormalizeForm normalize_form) : normalize_form_(normalize_form) {}

Status NormalizeUTF8Operation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "NormalizeUTF8: Invalid NormalizeForm, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NormalizeUTF8Operation::Build() {
  std::shared_ptr<NormalizeUTF8Op> tensor_op = std::make_shared<NormalizeUTF8Op>(normalize_form_);
  return tensor_op;
}

// RegexReplaceOperation
RegexReplaceOperation::RegexReplaceOperation(std::string pattern, std::string replace, bool replace_all)
    : pattern_(pattern), replace_(replace), replace_all_(replace_all) {}

Status RegexReplaceOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexReplaceOperation::Build() {
  std::shared_ptr<RegexReplaceOp> tensor_op = std::make_shared<RegexReplaceOp>(pattern_, replace_, replace_all_);
  return tensor_op;
}

// RegexTokenizerOperation
RegexTokenizerOperation::RegexTokenizerOperation(std::string delim_pattern, std::string keep_delim_pattern,
                                                 bool with_offsets)
    : delim_pattern_(delim_pattern), keep_delim_pattern_(keep_delim_pattern), with_offsets_(with_offsets) {}

Status RegexTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexTokenizerOperation::Build() {
  std::shared_ptr<RegexTokenizerOp> tensor_op =
    std::make_shared<RegexTokenizerOp>(delim_pattern_, keep_delim_pattern_, with_offsets_);
  return tensor_op;
}
#endif

// SentencePieceTokenizerOperation
SentencePieceTokenizerOperation::~SentencePieceTokenizerOperation() = default;

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(vocab), vocab_path_(std::string()), load_type_(SPieceTokenizerLoadType::kModel), out_type_(out_type) {}

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::string &vocab_path,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(nullptr), vocab_path_(vocab_path), load_type_(SPieceTokenizerLoadType::kFile), out_type_(out_type) {}

Status SentencePieceTokenizerOperation::ValidateParams() {
  if (out_type_ != SPieceTokenizerOutType::kString && out_type_ != SPieceTokenizerOutType::kInt) {
    std::string err_msg = "SentencePieceTokenizer: Invalid SPieceTokenizerOutType, check input value of enum.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    if (vocab_ == nullptr) {
      std::string err_msg = "SentencePieceTokenizer: vocab object type is incorrect or null.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  } else {
    std::string real_vocab_path;
    RETURN_IF_NOT_OK(Path::RealPath(vocab_path_, real_vocab_path));
    Path vocab_file(real_vocab_path);
    if (!vocab_file.Exists() || vocab_file.IsDirectory()) {
      std::string err_msg = "SentencePieceTokenizer: vocab file: [" + vocab_path_ + "] is invalid or does not exist.";
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (access(vocab_file.ToString().c_str(), R_OK) == -1) {
      std::string err_msg = "SentencePieceTokenizer: no access to specified dataset file: " + vocab_path_;
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
  std::shared_ptr<SentencePieceTokenizerOp> tensor_op;
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(vocab_, load_type_, out_type_);
  } else {
    Path vocab_file(vocab_path_);
    std::string model_path = vocab_file.ParentPath();
    std::string model_filename = vocab_file.Basename();
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(model_path, model_filename, load_type_, out_type_);
  }
  return tensor_op;
}
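// In the kFile branch above, a vocab_path such as "/models/m.model"
// (hypothetical) is split into model_path "/models" and model_filename
// "m.model" before the op is constructed; the two constructors let callers
// pass either an in-memory SentencePieceVocab or a model file on disk.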

// SlidingWindowOperation
SlidingWindowOperation::SlidingWindowOperation(const int32_t width, const int32_t axis) : width_(width), axis_(axis) {}

Status SlidingWindowOperation::ValidateParams() {
  if (width_ < 1) {
    std::string err_msg =
      "SlidingWindow: The parameter width must be greater than or equal to 1: " + std::to_string(width_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SlidingWindowOperation::Build() {
  std::shared_ptr<SlidingWindowOp> tensor_op = std::make_shared<SlidingWindowOp>(static_cast<uint32_t>(width_), axis_);
  return tensor_op;
}
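// For intuition (illustrative): with width=3, a 1-D input ["a", "b", "c", "d"]
// becomes the 2-D output [["a", "b", "c"], ["b", "c", "d"]], one row per
// window position. width_ is validated as >= 1 above, so the cast to uint32_t
// in Build() is safe.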

// ToNumberOperation
// DataType data_type - required for C++ API
ToNumberOperation::ToNumberOperation(const DataType &data_type) : data_type_(data_type) {}

// std::string data_type - required for Pybind
ToNumberOperation::ToNumberOperation(const std::string &data_type) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

Status ToNumberOperation::ValidateParams() {
  if (!data_type_.IsNumeric() || data_type_.IsBool()) {
    // Note: For DEType, Bool is counted as numeric, but is not a valid type for ToNumber.
    std::string err_msg = "ToNumber: The parameter data_type must be numeric and excludes bool.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> ToNumberOperation::Build() {
  std::shared_ptr<ToNumberOp> tensor_op = std::make_shared<ToNumberOp>(data_type_);
  return tensor_op;
}

Status ToNumberOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["data_type"] = data_type_.ToString();
  *out_json = args;
  return Status::OK();
}

Status ToNumberOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "data_type", kToNumberOperation));
  std::string data_type = op_params["data_type"];
  *operation = std::make_shared<text::ToNumberOperation>(data_type);
  return Status::OK();
}
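// to_json/from_json round-trip through the string form of the type, so the
// serialized payload is simply (illustrative, assuming an int32 target):
//   {"data_type": "int32"}
// from_json then rebuilds the operation via the std::string constructor above.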

// ToVectorsOperation
ToVectorsOperation::ToVectorsOperation(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init,
                                       bool lower_case_backup)
    : vectors_(vectors), unk_init_(unk_init), lower_case_backup_(lower_case_backup) {}

ToVectorsOperation::~ToVectorsOperation() = default;

Status ToVectorsOperation::ValidateParams() {
  if (vectors_ == nullptr) {
    std::string err_msg = "ToVectors: vectors can't be nullptr.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> ToVectorsOperation::Build() {
  std::shared_ptr<ToVectorsOp> tensor_op = std::make_shared<ToVectorsOp>(vectors_, unk_init_, lower_case_backup_);
  return tensor_op;
}

// TruncateOperation
TruncateOperation::TruncateOperation(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}

Status TruncateOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Truncate", "max_seq_len", max_seq_len_));

  return Status::OK();
}

std::shared_ptr<TensorOp> TruncateOperation::Build() {
  std::shared_ptr<TruncateOp> tensor_op = std::make_shared<TruncateOp>(max_seq_len_);
  return tensor_op;
}

Status TruncateOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["max_seq_len"] = max_seq_len_;
  *out_json = args;
  return Status::OK();
}

// TruncateSequencePairOperation
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}

Status TruncateSequencePairOperation::ValidateParams() {
  if (max_length_ < 0) {
    std::string err_msg = "TruncateSequencePair: The parameter max_length must be greater than or equal to 0: " +
                          std::to_string(max_length_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> TruncateSequencePairOperation::Build() {
  std::shared_ptr<TruncateSequencePairOp> tensor_op = std::make_shared<TruncateSequencePairOp>(max_length_);
  return tensor_op;
}
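// A note on intent (hedged -- the exact policy lives in TruncateSequencePairOp):
// the pair is trimmed until the combined length of the two sequences fits in
// max_length, in the style of BERT preprocessing where the longer sequence is
// truncated first. E.g. max_length=4 on (["a", "b", "c"], ["d", "e"]) would
// keep (["a", "b"], ["d", "e"]).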

// UnicodeCharTokenizerOperation
UnicodeCharTokenizerOperation::UnicodeCharTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status UnicodeCharTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeCharTokenizerOperation::Build() {
  std::shared_ptr<UnicodeCharTokenizerOp> tensor_op = std::make_shared<UnicodeCharTokenizerOp>(with_offsets_);
  return tensor_op;
}

// WordpieceTokenizerOperation
WordpieceTokenizerOperation::WordpieceTokenizerOperation(const std::shared_ptr<Vocab> &vocab,
                                                         const std::string &suffix_indicator,
                                                         int32_t max_bytes_per_token, const std::string &unknown_token,
                                                         bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      with_offsets_(with_offsets) {}

Status WordpieceTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "WordpieceTokenizer: vocab object type is incorrect or null.";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (max_bytes_per_token_ < 0) {
    std::string err_msg =
      "WordpieceTokenizer: The parameter max_bytes_per_token must be greater than or equal to 0: " +
      std::to_string(max_bytes_per_token_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> WordpieceTokenizerOperation::Build() {
  std::shared_ptr<WordpieceTokenizerOp> tensor_op = std::make_shared<WordpieceTokenizerOp>(
    vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, with_offsets_);
  return tensor_op;
}
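// Illustrative WordPiece behavior (hypothetical vocab): with suffix_indicator
// "##" and a vocab containing {"un", "##want", "##ed"}, the word "unwanted"
// is split into ["un", "##want", "##ed"]; a word that cannot be covered by
// vocab pieces is emitted as unknown_token instead.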

#ifndef _WIN32
// UnicodeScriptTokenizerOperation
UnicodeScriptTokenizerOperation::UnicodeScriptTokenizerOperation(bool keep_whitespace, bool with_offsets)
    : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {}

Status UnicodeScriptTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeScriptTokenizerOperation::Build() {
  std::shared_ptr<UnicodeScriptTokenizerOp> tensor_op =
    std::make_shared<UnicodeScriptTokenizerOp>(keep_whitespace_, with_offsets_);
  return tensor_op;
}

// WhitespaceTokenizerOperation
WhitespaceTokenizerOperation::WhitespaceTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status WhitespaceTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> WhitespaceTokenizerOperation::Build() {
  std::shared_ptr<WhitespaceTokenizerOp> tensor_op = std::make_shared<WhitespaceTokenizerOp>(with_offsets_);
  return tensor_op;
}
#endif
}  // namespace text
}  // namespace dataset
}  // namespace mindspore