1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_ 19 20 #include <memory> 21 #include <optional> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #include "minddata/dataset/kernels/ir/tensor_operation.h" 27 28 namespace mindspore { 29 namespace dataset { 30 class SentencePieceVocab; 31 class Vectors; 32 class Vocab; 33 34 // Transform operations for text 35 namespace text { 36 constexpr int kStatusSum = 4; 37 // Char arrays storing name of corresponding classes (in alphabetical order) 38 constexpr char kAddTokenOperation[] = "AddToken"; 39 constexpr char kBasicTokenizerOperation[] = "BasicTokenizer"; 40 constexpr char kBertTokenizerOperation[] = "BertTokenizer"; 41 constexpr char kCaseFoldOperation[] = "CaseFold"; 42 constexpr char kFilterWikipediaXMLOperation[] = "FilterWikipediaXML"; 43 constexpr char kJiebaTokenizerOperation[] = "JiebaTokenizer"; 44 constexpr char kLookupOperation[] = "Lookup"; 45 constexpr char kNgramOperation[] = "Ngram"; 46 constexpr char kNormalizeUTF8Operation[] = "NormalizeUTF8"; 47 constexpr char kRegexReplaceOperation[] = "RegexReplace"; 48 constexpr char kRegexTokenizerOperation[] = "RegexTokenizer"; 49 constexpr char kSentencepieceTokenizerOperation[] = "SentencepieceTokenizer"; 50 constexpr char kSlidingWindowOperation[] = "SlidingWindow"; 51 constexpr char kToNumberOperation[] = "ToNumber"; 52 constexpr char kToVectorsOperation[] = "ToVectors"; 53 constexpr char kTruncateOperation[] = "Truncate"; 54 constexpr char kTruncateSequencePairOperation[] = "TruncateSequencePair"; 55 constexpr char kUnicodeCharTokenizerOperation[] = "UnicodeCharTokenizer"; 56 constexpr char kUnicodeScriptTokenizerOperation[] = "UnicodeScriptTokenizer"; 57 constexpr char kWhitespaceTokenizerOperation[] = "WhitespaceTokenizer"; 58 constexpr char kWordpieceTokenizerOperation[] = "WordpieceTokenizer"; 59 60 /* ####################################### Derived TensorOperation classes ################################# */ 61 62 class AddTokenOperation : public TensorOperation { 63 public: 64 /// \brief Constructor. 65 /// \param[in] token The token to be added. 66 /// \param[in] begin Whether to insert token at start or end of sequence. 67 AddTokenOperation(const std::string &token, bool begin); 68 69 ~AddTokenOperation(); 70 71 std::shared_ptr<TensorOp> Build() override; 72 73 Status ValidateParams() override; 74 75 std::string Name() const override; 76 77 Status to_json(nlohmann::json *out_json) override; 78 79 private: 80 std::string token_; 81 bool begin_; 82 }; 83 84 #ifndef _WIN32 85 class BasicTokenizerOperation : public TensorOperation { 86 public: 87 BasicTokenizerOperation(bool lower_case, bool keep_whitespace, const NormalizeForm normalize_form, 88 bool preserve_unused_token, bool with_offsets); 89 90 ~BasicTokenizerOperation() = default; 91 92 std::shared_ptr<TensorOp> Build() override; 93 94 Status ValidateParams() override; 95 Name()96 std::string Name() const override { return kBasicTokenizerOperation; } 97 98 private: 99 bool lower_case_; 100 bool keep_whitespace_; 101 NormalizeForm normalize_form_; 102 bool preserve_unused_token_; 103 bool with_offsets_; 104 }; 105 106 class BertTokenizerOperation : public TensorOperation { 107 public: 108 BertTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator, 109 int32_t max_bytes_per_token, const std::string &unknown_token, bool lower_case, 110 bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token, 111 bool with_offsets); 112 113 ~BertTokenizerOperation(); 114 115 std::shared_ptr<TensorOp> Build() override; 116 117 Status ValidateParams() override; 118 Name()119 std::string Name() const override { return kBertTokenizerOperation; } 120 121 private: 122 std::shared_ptr<Vocab> vocab_; 123 std::string suffix_indicator_; 124 int32_t max_bytes_per_token_; 125 std::string unknown_token_; 126 bool lower_case_; 127 bool keep_whitespace_; 128 NormalizeForm normalize_form_; 129 bool preserve_unused_token_; 130 bool with_offsets_; 131 }; 132 133 class CaseFoldOperation : public TensorOperation { 134 public: 135 CaseFoldOperation() = default; 136 137 ~CaseFoldOperation() = default; 138 139 std::shared_ptr<TensorOp> Build() override; 140 141 Status ValidateParams() override; 142 Name()143 std::string Name() const override { return kCaseFoldOperation; } 144 }; 145 146 class FilterWikipediaXMLOperation : public TensorOperation { 147 public: 148 FilterWikipediaXMLOperation(); 149 150 ~FilterWikipediaXMLOperation() = default; 151 152 std::shared_ptr<TensorOp> Build() override; 153 154 Status ValidateParams() override; 155 Name()156 std::string Name() const override { return kFilterWikipediaXMLOperation; } 157 }; 158 #endif 159 160 class JiebaTokenizerOperation : public TensorOperation { 161 public: 162 explicit JiebaTokenizerOperation(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode, 163 bool with_offsets); 164 165 ~JiebaTokenizerOperation() = default; 166 167 std::shared_ptr<TensorOp> Build() override; 168 169 Status ValidateParams() override; 170 Name()171 std::string Name() const override { return kJiebaTokenizerOperation; } 172 173 Status AddWord(const std::string &word, int64_t freq = 0); 174 175 private: 176 std::string hmm_path_; 177 std::string mp_path_; 178 JiebaMode mode_; 179 bool with_offsets_; 180 std::vector<std::pair<std::string, int64_t>> words_list_; 181 }; 182 183 class LookupOperation : public TensorOperation { 184 public: 185 explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, 186 const DataType &data_type); // Used for C++ API 187 explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, 188 const std::string &data_type); // Used for Pybind 189 190 ~LookupOperation(); 191 192 std::shared_ptr<TensorOp> Build() override; 193 194 Status ValidateParams() override; 195 Name()196 std::string Name() const override { return kLookupOperation; } 197 198 private: 199 std::shared_ptr<Vocab> vocab_; 200 std::optional<std::string> unknown_token_; 201 int32_t default_id_; 202 DataType data_type_; 203 }; 204 205 class NgramOperation : public TensorOperation { 206 public: 207 explicit NgramOperation(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad, 208 const std::pair<std::string, int32_t> &right_pad, const std::string &separator); 209 210 ~NgramOperation() = default; 211 212 std::shared_ptr<TensorOp> Build() override; 213 214 Status ValidateParams() override; 215 Name()216 std::string Name() const override { return kNgramOperation; } 217 218 private: 219 std::vector<int32_t> ngrams_; 220 std::pair<std::string, int32_t> left_pad_; 221 std::pair<std::string, int32_t> right_pad_; 222 std::string separator_; 223 }; 224 225 #ifndef _WIN32 226 class NormalizeUTF8Operation : public TensorOperation { 227 public: 228 explicit NormalizeUTF8Operation(NormalizeForm normalize_form); 229 230 ~NormalizeUTF8Operation() = default; 231 232 std::shared_ptr<TensorOp> Build() override; 233 234 Status ValidateParams() override; 235 Name()236 std::string Name() const override { return kNormalizeUTF8Operation; } 237 238 private: 239 NormalizeForm normalize_form_; 240 }; 241 242 class RegexReplaceOperation : public TensorOperation { 243 public: 244 RegexReplaceOperation(std::string pattern, std::string replace, bool replace_all); 245 246 ~RegexReplaceOperation() = default; 247 248 std::shared_ptr<TensorOp> Build() override; 249 250 Status ValidateParams() override; 251 Name()252 std::string Name() const override { return kRegexReplaceOperation; } 253 254 private: 255 std::string pattern_; 256 std::string replace_; 257 bool replace_all_; 258 }; 259 260 class RegexTokenizerOperation : public TensorOperation { 261 public: 262 explicit RegexTokenizerOperation(std::string delim_pattern, std::string keep_delim_pattern, bool with_offsets); 263 264 ~RegexTokenizerOperation() = default; 265 266 std::shared_ptr<TensorOp> Build() override; 267 268 Status ValidateParams() override; 269 Name()270 std::string Name() const override { return kRegexTokenizerOperation; } 271 272 private: 273 std::string delim_pattern_; 274 std::string keep_delim_pattern_; 275 bool with_offsets_; 276 }; 277 #endif 278 279 class SentencePieceTokenizerOperation : public TensorOperation { 280 public: 281 SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab, SPieceTokenizerOutType out_type); 282 283 SentencePieceTokenizerOperation(const std::string &vocab_path, SPieceTokenizerOutType out_type); 284 285 ~SentencePieceTokenizerOperation(); 286 287 std::shared_ptr<TensorOp> Build() override; 288 289 Status ValidateParams() override; 290 Name()291 std::string Name() const override { return kSentencepieceTokenizerOperation; } 292 293 private: 294 std::shared_ptr<SentencePieceVocab> vocab_; 295 std::string vocab_path_; 296 SPieceTokenizerLoadType load_type_; 297 SPieceTokenizerOutType out_type_; 298 }; 299 300 class SlidingWindowOperation : public TensorOperation { 301 public: 302 explicit SlidingWindowOperation(const int32_t width, const int32_t axis); 303 304 ~SlidingWindowOperation() = default; 305 306 std::shared_ptr<TensorOp> Build() override; 307 308 Status ValidateParams() override; 309 Name()310 std::string Name() const override { return kSlidingWindowOperation; } 311 312 private: 313 int32_t width_; 314 int32_t axis_; 315 }; 316 317 class ToNumberOperation : public TensorOperation { 318 public: 319 explicit ToNumberOperation(const DataType &data_type); // Used for C++ API 320 explicit ToNumberOperation(const std::string &data_type); // Used for Pybind 321 322 ~ToNumberOperation() = default; 323 324 std::shared_ptr<TensorOp> Build() override; 325 326 Status ValidateParams() override; 327 Name()328 std::string Name() const override { return kToNumberOperation; } 329 330 Status to_json(nlohmann::json *out_json) override; 331 332 static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); 333 334 private: 335 DataType data_type_; 336 }; 337 338 class ToVectorsOperation : public TensorOperation { 339 public: 340 ToVectorsOperation(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init, 341 bool lower_case_backup); 342 343 ~ToVectorsOperation(); 344 345 std::shared_ptr<TensorOp> Build() override; 346 347 Status ValidateParams() override; 348 Name()349 std::string Name() const override { return kToVectorsOperation; } 350 351 private: 352 std::shared_ptr<Vectors> vectors_; 353 std::vector<float> unk_init_; 354 bool lower_case_backup_; 355 }; 356 357 class TruncateOperation : public TensorOperation { 358 public: 359 explicit TruncateOperation(int32_t max_seq_len); 360 361 ~TruncateOperation() = default; 362 363 std::shared_ptr<TensorOp> Build() override; 364 365 Status ValidateParams() override; 366 Name()367 std::string Name() const override { return kTruncateOperation; } 368 369 Status to_json(nlohmann::json *out_json) override; 370 371 private: 372 int32_t max_seq_len_; 373 }; 374 375 class TruncateSequencePairOperation : public TensorOperation { 376 public: 377 explicit TruncateSequencePairOperation(int32_t max_length); 378 379 ~TruncateSequencePairOperation() = default; 380 381 std::shared_ptr<TensorOp> Build() override; 382 383 Status ValidateParams() override; 384 Name()385 std::string Name() const override { return kTruncateSequencePairOperation; } 386 387 private: 388 int32_t max_length_; 389 }; 390 391 class UnicodeCharTokenizerOperation : public TensorOperation { 392 public: 393 explicit UnicodeCharTokenizerOperation(bool with_offsets); 394 395 ~UnicodeCharTokenizerOperation() = default; 396 397 std::shared_ptr<TensorOp> Build() override; 398 399 Status ValidateParams() override; 400 Name()401 std::string Name() const override { return kUnicodeCharTokenizerOperation; } 402 403 private: 404 bool with_offsets_; 405 }; 406 407 class WordpieceTokenizerOperation : public TensorOperation { 408 public: 409 explicit WordpieceTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator, 410 int32_t max_bytes_per_token, const std::string &unknown_token, 411 bool with_offsets); 412 413 ~WordpieceTokenizerOperation() = default; 414 415 std::shared_ptr<TensorOp> Build() override; 416 417 Status ValidateParams() override; 418 Name()419 std::string Name() const override { return kWordpieceTokenizerOperation; } 420 421 private: 422 std::shared_ptr<Vocab> vocab_; 423 std::string suffix_indicator_; 424 int32_t max_bytes_per_token_; 425 std::string unknown_token_; 426 bool with_offsets_; 427 }; 428 429 #ifndef _WIN32 430 class UnicodeScriptTokenizerOperation : public TensorOperation { 431 public: 432 explicit UnicodeScriptTokenizerOperation(bool keep_whitespace, bool with_offsets); 433 434 ~UnicodeScriptTokenizerOperation() = default; 435 436 std::shared_ptr<TensorOp> Build() override; 437 438 Status ValidateParams() override; 439 Name()440 std::string Name() const override { return kUnicodeScriptTokenizerOperation; } 441 442 private: 443 bool keep_whitespace_; 444 bool with_offsets_; 445 }; 446 447 class WhitespaceTokenizerOperation : public TensorOperation { 448 public: 449 explicit WhitespaceTokenizerOperation(bool with_offsets); 450 451 ~WhitespaceTokenizerOperation() = default; 452 453 std::shared_ptr<TensorOp> Build() override; 454 455 Status ValidateParams() override; 456 Name()457 std::string Name() const override { return kWhitespaceTokenizerOperation; } 458 459 private: 460 bool with_offsets_; 461 }; 462 #endif 463 } // namespace text 464 } // namespace dataset 465 } // namespace mindspore 466 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_ 467