/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <unistd.h>
#include "minddata/dataset/text/ir/kernels/text_ir.h"

#ifndef _WIN32
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
#include "minddata/dataset/text/kernels/case_fold_op.h"
#endif
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
#include "minddata/dataset/text/kernels/regex_replace_op.h"
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
#endif
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/util/path.h"

#include "minddata/dataset/text/ir/validators.h"

namespace mindspore {
namespace dataset {
// Transform operations for text.
namespace text {
/* ####################################### Derived TensorOperation classes ################################# */

// (In alphabetical order)

#ifndef _WIN32
// BasicTokenizerOperation
BasicTokenizerOperation::BasicTokenizerOperation(bool lower_case, bool keep_whitespace,
                                                 const NormalizeForm normalize_form, bool preserve_unused_token,
                                                 bool with_offsets)
    : lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

Status BasicTokenizerOperation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BasicTokenizer: Invalid NormalizeForm, check input value of enum.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> BasicTokenizerOperation::Build() {
  std::shared_ptr<BasicTokenizerOp> tensor_op = std::make_shared<BasicTokenizerOp>(
    lower_case_, keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}
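
// Illustrative usage (a sketch, not part of this file): construct the IR node,
// validate it, then build the runtime op.
//   auto basic = std::make_shared<BasicTokenizerOperation>(
//     /*lower_case=*/true, /*keep_whitespace=*/false, NormalizeForm::kNone,
//     /*preserve_unused_token=*/true, /*with_offsets=*/false);
//   if (!basic->ValidateParams().IsError()) {
//     std::shared_ptr<TensorOp> op = basic->Build();
//   }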

// BertTokenizerOperation
BertTokenizerOperation::BertTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
                                               int32_t max_bytes_per_token, const std::string &unknown_token,
                                               bool lower_case, bool keep_whitespace,
                                               const NormalizeForm normalize_form, bool preserve_unused_token,
                                               bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      lower_case_(lower_case),
      keep_whitespace_(keep_whitespace),
      normalize_form_(normalize_form),
      preserve_unused_token_(preserve_unused_token),
      with_offsets_(with_offsets) {}

BertTokenizerOperation::~BertTokenizerOperation() = default;

Status BertTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BertTokenizer: vocab object type is incorrect or null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "BertTokenizer: Invalid NormalizeForm, check input value of enum.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (max_bytes_per_token_ < 0) {
    std::string err_msg = "BertTokenizer: The parameter max_bytes_per_token must be greater than or equal to 0: " +
                          std::to_string(max_bytes_per_token_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> BertTokenizerOperation::Build() {
  std::shared_ptr<BertTokenizerOp> tensor_op =
    std::make_shared<BertTokenizerOp>(vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, lower_case_,
                                      keep_whitespace_, normalize_form_, preserve_unused_token_, with_offsets_);
  return tensor_op;
}
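
// Illustrative usage (a sketch; "##" and "[UNK]" are conventional BERT values,
// not mandated by this file):
//   auto bert = std::make_shared<BertTokenizerOperation>(
//     vocab, "##", /*max_bytes_per_token=*/100, "[UNK]", /*lower_case=*/true,
//     /*keep_whitespace=*/false, NormalizeForm::kNone,
//     /*preserve_unused_token=*/true, /*with_offsets=*/false);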

// CaseFoldOperation
Status CaseFoldOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> CaseFoldOperation::Build() {
  std::shared_ptr<CaseFoldOp> tensor_op = std::make_shared<CaseFoldOp>();
  return tensor_op;
}
#endif

// JiebaTokenizerOperation
JiebaTokenizerOperation::JiebaTokenizerOperation(const std::string &hmm_path, const std::string &mp_path,
                                                 const JiebaMode &mode, bool with_offsets)
    : hmm_path_(hmm_path), mp_path_(mp_path), mode_(mode), with_offsets_(with_offsets) {}

Status JiebaTokenizerOperation::ValidateParams() {
  if (hmm_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of HMMSegment in cppjieba is not provided.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mp_path_.empty()) {
    std::string err_msg = "JiebaTokenizer: The dict of MPSegment in cppjieba is not provided.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (mode_ != JiebaMode::kMix && mode_ != JiebaMode::kMp && mode_ != JiebaMode::kHmm) {
    std::string err_msg = "JiebaTokenizer: Invalid JiebaMode, check input value of enum.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", hmm_path_));
  RETURN_IF_NOT_OK(ValidateTokenizerDirParam("JiebaTokenizer", mp_path_));
  return Status::OK();
}

std::shared_ptr<TensorOp> JiebaTokenizerOperation::Build() {
  std::shared_ptr<JiebaTokenizerOp> tensor_op =
    std::make_shared<JiebaTokenizerOp>(hmm_path_, mp_path_, mode_, with_offsets_);
  for (auto &word : words_list_) {
    Status rc = tensor_op->AddWord(word.first, word.second);
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc;
      return {};
    }
  }
  return tensor_op;
}

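// Note: words added here are only buffered; they are applied to the runtime op
// in Build(), so an invalid word surfaces as a Build() failure rather than here.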
Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) {
  words_list_.emplace_back(word, freq);
  return Status::OK();
}

// LookupOperation
// DataType data_type - required for C++ API
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const DataType &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}

// std::string data_type - required for Pybind
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                 const std::string &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

LookupOperation::~LookupOperation() = default;

Status LookupOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "Lookup: vocab object type is incorrect or null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (unknown_token_ != std::nullopt) {
    default_id_ = vocab_->Lookup(*unknown_token_);
    if (default_id_ == Vocab::kNoTokenExists) {
      std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  if (!data_type_.IsNumeric()) {
    // Note: For DEType, Bool is counted as numeric, and is a valid type for Lookup
    std::string err_msg = "Lookup: The parameter data_type must be numeric, including bool.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> LookupOperation::Build() {
  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, DataType(data_type_));
  return tensor_op;
}
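
// Illustrative usage (a sketch; assumes "<unk>" is present in the vocab and that
// "int32" names a valid DataType):
//   auto lookup = std::make_shared<LookupOperation>(vocab, std::optional<std::string>("<unk>"), "int32");
//   Status rc = lookup->ValidateParams();  // resolves default_id_ from "<unk>"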

// NgramOperation
NgramOperation::NgramOperation(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad,
                               const std::pair<std::string, int32_t> &right_pad, const std::string &separator)
    : ngrams_(ngrams), left_pad_(left_pad), right_pad_(right_pad), separator_(separator) {}

Status NgramOperation::ValidateParams() {
  if (ngrams_.empty()) {
    std::string err_msg = "Ngram: Container cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  } else {
    for (size_t i = 0; i < ngrams_.size(); ++i) {
      if (ngrams_[i] <= 0) {
        std::string err_msg =
          "Ngram: The value of ngrams vector must be greater than 0: " + std::to_string(ngrams_[i]);
        MS_LOG(ERROR) << err_msg;
        RETURN_STATUS_SYNTAX_ERROR(err_msg);
      }
    }
  }

  if (left_pad_.second < 0) {
    std::string err_msg =
      "Ngram: The second parameter pad_width in left_pad vector must be greater than or equal to 0: " +
      std::to_string(left_pad_.second);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (right_pad_.second < 0) {
    std::string err_msg =
      "Ngram: The second parameter pad_width in right_pad vector must be greater than or equal to 0: " +
      std::to_string(right_pad_.second);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NgramOperation::Build() {
  int32_t l_len = left_pad_.second;
  int32_t r_len = right_pad_.second;
  std::string l_pad = left_pad_.first;
  std::string r_pad = right_pad_.first;
  std::shared_ptr<NgramOp> tensor_op = std::make_shared<NgramOp>(ngrams_, l_len, l_pad, r_len, r_pad, separator_);
  return tensor_op;
}
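
// Illustrative usage (a sketch): bigrams and trigrams, one "_" pad token on each
// side, tokens joined by a single space:
//   auto ngram = std::make_shared<NgramOperation>(std::vector<int32_t>{2, 3},
//                                                 {"_", 1}, {"_", 1}, " ");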

#ifndef _WIN32
// NormalizeUTF8Operation
NormalizeUTF8Operation::NormalizeUTF8Operation(NormalizeForm normalize_form) : normalize_form_(normalize_form) {}

Status NormalizeUTF8Operation::ValidateParams() {
  if (normalize_form_ != NormalizeForm::kNone && normalize_form_ != NormalizeForm::kNfc &&
      normalize_form_ != NormalizeForm::kNfkc && normalize_form_ != NormalizeForm::kNfd &&
      normalize_form_ != NormalizeForm::kNfkd) {
    std::string err_msg = "NormalizeUTF8: Invalid NormalizeForm, check input value of enum.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> NormalizeUTF8Operation::Build() {
  std::shared_ptr<NormalizeUTF8Op> tensor_op = std::make_shared<NormalizeUTF8Op>(normalize_form_);
  return tensor_op;
}

// RegexReplaceOperation
RegexReplaceOperation::RegexReplaceOperation(std::string pattern, std::string replace, bool replace_all)
    : pattern_(pattern), replace_(replace), replace_all_(replace_all) {}

Status RegexReplaceOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexReplaceOperation::Build() {
  std::shared_ptr<RegexReplaceOp> tensor_op = std::make_shared<RegexReplaceOp>(pattern_, replace_, replace_all_);
  return tensor_op;
}

// RegexTokenizerOperation
RegexTokenizerOperation::RegexTokenizerOperation(std::string delim_pattern, std::string keep_delim_pattern,
                                                 bool with_offsets)
    : delim_pattern_(delim_pattern), keep_delim_pattern_(keep_delim_pattern), with_offsets_(with_offsets) {}

Status RegexTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RegexTokenizerOperation::Build() {
  std::shared_ptr<RegexTokenizerOp> tensor_op =
    std::make_shared<RegexTokenizerOp>(delim_pattern_, keep_delim_pattern_, with_offsets_);
  return tensor_op;
}
#endif

// SentencePieceTokenizerOperation
SentencePieceTokenizerOperation::~SentencePieceTokenizerOperation() = default;

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(vocab), vocab_path_(std::string()), load_type_(SPieceTokenizerLoadType::kModel), out_type_(out_type) {}

SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::string &vocab_path,
                                                                 SPieceTokenizerOutType out_type)
    : vocab_(nullptr), vocab_path_(vocab_path), load_type_(SPieceTokenizerLoadType::kFile), out_type_(out_type) {}

Status SentencePieceTokenizerOperation::ValidateParams() {
  if (out_type_ != SPieceTokenizerOutType::kString && out_type_ != SPieceTokenizerOutType::kInt) {
    std::string err_msg = "SentencePieceTokenizer: Invalid SPieceTokenizerOutType, check input value of enum.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    if (vocab_ == nullptr) {
      std::string err_msg = "SentencePieceTokenizer: vocab object type is incorrect or null.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  } else {
    std::string real_vocab_path;
    RETURN_IF_NOT_OK(Path::RealPath(vocab_path_, real_vocab_path));
    Path vocab_file(real_vocab_path);
    if (!vocab_file.Exists() || vocab_file.IsDirectory()) {
      std::string err_msg = "SentencePieceTokenizer: vocab file: [" + vocab_path_ + "] is invalid or does not exist.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (access(vocab_file.ToString().c_str(), R_OK) == -1) {
      std::string err_msg = "SentencePieceTokenizer: no access to specified dataset file: " + vocab_path_;
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
  std::shared_ptr<SentencePieceTokenizerOp> tensor_op;
  if (load_type_ == SPieceTokenizerLoadType::kModel) {
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(vocab_, load_type_, out_type_);
  } else {
    Path vocab_file(vocab_path_);
    std::string model_path = vocab_file.ParentPath();
    std::string model_filename = vocab_file.Basename();
    tensor_op = std::make_shared<SentencePieceTokenizerOp>(model_path, model_filename, load_type_, out_type_);
  }
  return tensor_op;
}
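
// Illustrative usage (a sketch; the model path is hypothetical): load a trained
// SentencePiece model from disk and emit subword strings:
//   auto spm = std::make_shared<SentencePieceTokenizerOperation>(
//     "/path/to/tokenizer.model", SPieceTokenizerOutType::kString);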

// SlidingWindowOperation
SlidingWindowOperation::SlidingWindowOperation(const int32_t width, const int32_t axis) : width_(width), axis_(axis) {}

Status SlidingWindowOperation::ValidateParams() {
  if (width_ < 1) {
    std::string err_msg =
      "SlidingWindow: The parameter width must be greater than or equal to 1: " + std::to_string(width_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> SlidingWindowOperation::Build() {
  std::shared_ptr<SlidingWindowOp> tensor_op = std::make_shared<SlidingWindowOp>(static_cast<uint32_t>(width_), axis_);
  return tensor_op;
}
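
// Illustrative usage (a sketch): windows of 3 adjacent elements along axis 0:
//   auto sliding = std::make_shared<SlidingWindowOperation>(/*width=*/3, /*axis=*/0);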

// ToNumberOperation
// DataType data_type - required for C++ API
ToNumberOperation::ToNumberOperation(const DataType &data_type) : data_type_(data_type) {}

// std::string data_type - required for Pybind
ToNumberOperation::ToNumberOperation(const std::string &data_type) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

Status ToNumberOperation::ValidateParams() {
  if (!data_type_.IsNumeric() || data_type_.IsBool()) {
    // Note: For DEType, Bool is counted as numeric, but is not a valid type for ToNumber.
    std::string err_msg = "ToNumber: The parameter data_type must be numeric, excluding bool.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> ToNumberOperation::Build() {
  std::shared_ptr<ToNumberOp> tensor_op = std::make_shared<ToNumberOp>(data_type_);
  return tensor_op;
}

Status ToNumberOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["data_type"] = data_type_.ToString();
  *out_json = args;
  return Status::OK();
}

Status ToNumberOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("data_type") != op_params.end(), "Failed to find data_type");
  std::string data_type = op_params["data_type"];
  *operation = std::make_shared<text::ToNumberOperation>(data_type);
  return Status::OK();
}
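
// Illustrative round trip (a sketch; assumes from_json is callable as a static
// member, as its signature suggests, and that "int32" names a valid DataType):
//   nlohmann::json j;
//   ToNumberOperation("int32").to_json(&j);
//   std::shared_ptr<TensorOperation> rebuilt;
//   Status rc = ToNumberOperation::from_json(j, &rebuilt);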

// TruncateSequencePairOperation
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}

Status TruncateSequencePairOperation::ValidateParams() {
  if (max_length_ < 0) {
    std::string err_msg = "TruncateSequencePair: The parameter max_length must be greater than or equal to 0: " +
                          std::to_string(max_length_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> TruncateSequencePairOperation::Build() {
  std::shared_ptr<TruncateSequencePairOp> tensor_op = std::make_shared<TruncateSequencePairOp>(max_length_);
  return tensor_op;
}

// UnicodeCharTokenizerOperation
UnicodeCharTokenizerOperation::UnicodeCharTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status UnicodeCharTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeCharTokenizerOperation::Build() {
  std::shared_ptr<UnicodeCharTokenizerOp> tensor_op = std::make_shared<UnicodeCharTokenizerOp>(with_offsets_);
  return tensor_op;
}

// WordpieceTokenizerOperation
WordpieceTokenizerOperation::WordpieceTokenizerOperation(const std::shared_ptr<Vocab> &vocab,
                                                         const std::string &suffix_indicator,
                                                         int32_t max_bytes_per_token, const std::string &unknown_token,
                                                         bool with_offsets)
    : vocab_(vocab),
      suffix_indicator_(suffix_indicator),
      max_bytes_per_token_(max_bytes_per_token),
      unknown_token_(unknown_token),
      with_offsets_(with_offsets) {}

Status WordpieceTokenizerOperation::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "WordpieceTokenizer: vocab object type is incorrect or null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (max_bytes_per_token_ < 0) {
    std::string err_msg =
      "WordpieceTokenizer: The parameter max_bytes_per_token must be greater than or equal to 0: " +
      std::to_string(max_bytes_per_token_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> WordpieceTokenizerOperation::Build() {
  std::shared_ptr<WordpieceTokenizerOp> tensor_op = std::make_shared<WordpieceTokenizerOp>(
    vocab_, suffix_indicator_, max_bytes_per_token_, unknown_token_, with_offsets_);
  return tensor_op;
}

#ifndef _WIN32
// UnicodeScriptTokenizerOperation
UnicodeScriptTokenizerOperation::UnicodeScriptTokenizerOperation(bool keep_whitespace, bool with_offsets)
    : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {}

Status UnicodeScriptTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UnicodeScriptTokenizerOperation::Build() {
  std::shared_ptr<UnicodeScriptTokenizerOp> tensor_op =
    std::make_shared<UnicodeScriptTokenizerOp>(keep_whitespace_, with_offsets_);
  return tensor_op;
}

// WhitespaceTokenizerOperation
WhitespaceTokenizerOperation::WhitespaceTokenizerOperation(bool with_offsets) : with_offsets_(with_offsets) {}

Status WhitespaceTokenizerOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> WhitespaceTokenizerOperation::Build() {
  std::shared_ptr<WhitespaceTokenizerOp> tensor_op = std::make_shared<WhitespaceTokenizerOp>(with_offsets_);
  return tensor_op;
}
#endif
}  // namespace text
}  // namespace dataset
}  // namespace mindspore