• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_
19 
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "minddata/dataset/kernels/ir/tensor_operation.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 class SentencePieceVocab;
31 class Vectors;
32 class Vocab;
33 
34 // Transform operations for text
35 namespace text {
// NOTE(review): the purpose of this value is not documented in this header - confirm against its users.
constexpr int kStatusSum = 4;

// Name strings for the text operation classes (kept in alphabetical order).
// The inline Name() overrides declared in this header return these constants.
constexpr char kAddTokenOperation[] = "AddToken";
constexpr char kBasicTokenizerOperation[] = "BasicTokenizer";
constexpr char kBertTokenizerOperation[] = "BertTokenizer";
constexpr char kCaseFoldOperation[] = "CaseFold";
constexpr char kFilterWikipediaXMLOperation[] = "FilterWikipediaXML";
constexpr char kJiebaTokenizerOperation[] = "JiebaTokenizer";
constexpr char kLookupOperation[] = "Lookup";
constexpr char kNgramOperation[] = "Ngram";
constexpr char kNormalizeUTF8Operation[] = "NormalizeUTF8";
constexpr char kRegexReplaceOperation[] = "RegexReplace";
constexpr char kRegexTokenizerOperation[] = "RegexTokenizer";
constexpr char kSentencepieceTokenizerOperation[] = "SentencepieceTokenizer";
constexpr char kSlidingWindowOperation[] = "SlidingWindow";
constexpr char kToNumberOperation[] = "ToNumber";
constexpr char kToVectorsOperation[] = "ToVectors";
constexpr char kTruncateOperation[] = "Truncate";
constexpr char kTruncateSequencePairOperation[] = "TruncateSequencePair";
constexpr char kUnicodeCharTokenizerOperation[] = "UnicodeCharTokenizer";
constexpr char kUnicodeScriptTokenizerOperation[] = "UnicodeScriptTokenizer";
constexpr char kWhitespaceTokenizerOperation[] = "WhitespaceTokenizer";
constexpr char kWordpieceTokenizerOperation[] = "WordpieceTokenizer";
59 
60 /* ####################################### Derived TensorOperation classes ################################# */
61 
62 class AddTokenOperation : public TensorOperation {
63  public:
64   /// \brief Constructor.
65   /// \param[in] token The token to be added.
66   /// \param[in] begin Whether to insert token at start or end of sequence.
67   AddTokenOperation(const std::string &token, bool begin);
68 
69   ~AddTokenOperation();
70 
71   std::shared_ptr<TensorOp> Build() override;
72 
73   Status ValidateParams() override;
74 
75   std::string Name() const override;
76 
77   Status to_json(nlohmann::json *out_json) override;
78 
79  private:
80   std::string token_;
81   bool begin_;
82 };
83 
84 #ifndef _WIN32
85 class BasicTokenizerOperation : public TensorOperation {
86  public:
87   BasicTokenizerOperation(bool lower_case, bool keep_whitespace, const NormalizeForm normalize_form,
88                           bool preserve_unused_token, bool with_offsets);
89 
90   ~BasicTokenizerOperation() = default;
91 
92   std::shared_ptr<TensorOp> Build() override;
93 
94   Status ValidateParams() override;
95 
Name()96   std::string Name() const override { return kBasicTokenizerOperation; }
97 
98  private:
99   bool lower_case_;
100   bool keep_whitespace_;
101   NormalizeForm normalize_form_;
102   bool preserve_unused_token_;
103   bool with_offsets_;
104 };
105 
106 class BertTokenizerOperation : public TensorOperation {
107  public:
108   BertTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
109                          int32_t max_bytes_per_token, const std::string &unknown_token, bool lower_case,
110                          bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token,
111                          bool with_offsets);
112 
113   ~BertTokenizerOperation();
114 
115   std::shared_ptr<TensorOp> Build() override;
116 
117   Status ValidateParams() override;
118 
Name()119   std::string Name() const override { return kBertTokenizerOperation; }
120 
121  private:
122   std::shared_ptr<Vocab> vocab_;
123   std::string suffix_indicator_;
124   int32_t max_bytes_per_token_;
125   std::string unknown_token_;
126   bool lower_case_;
127   bool keep_whitespace_;
128   NormalizeForm normalize_form_;
129   bool preserve_unused_token_;
130   bool with_offsets_;
131 };
132 
133 class CaseFoldOperation : public TensorOperation {
134  public:
135   CaseFoldOperation() = default;
136 
137   ~CaseFoldOperation() = default;
138 
139   std::shared_ptr<TensorOp> Build() override;
140 
141   Status ValidateParams() override;
142 
Name()143   std::string Name() const override { return kCaseFoldOperation; }
144 };
145 
146 class FilterWikipediaXMLOperation : public TensorOperation {
147  public:
148   FilterWikipediaXMLOperation();
149 
150   ~FilterWikipediaXMLOperation() = default;
151 
152   std::shared_ptr<TensorOp> Build() override;
153 
154   Status ValidateParams() override;
155 
Name()156   std::string Name() const override { return kFilterWikipediaXMLOperation; }
157 };
158 #endif
159 
160 class JiebaTokenizerOperation : public TensorOperation {
161  public:
162   explicit JiebaTokenizerOperation(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode,
163                                    bool with_offsets);
164 
165   ~JiebaTokenizerOperation() = default;
166 
167   std::shared_ptr<TensorOp> Build() override;
168 
169   Status ValidateParams() override;
170 
Name()171   std::string Name() const override { return kJiebaTokenizerOperation; }
172 
173   Status AddWord(const std::string &word, int64_t freq = 0);
174 
175  private:
176   std::string hmm_path_;
177   std::string mp_path_;
178   JiebaMode mode_;
179   bool with_offsets_;
180   std::vector<std::pair<std::string, int64_t>> words_list_;
181 };
182 
183 class LookupOperation : public TensorOperation {
184  public:
185   explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
186                            const DataType &data_type);  // Used for C++ API
187   explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
188                            const std::string &data_type);  // Used for Pybind
189 
190   ~LookupOperation();
191 
192   std::shared_ptr<TensorOp> Build() override;
193 
194   Status ValidateParams() override;
195 
Name()196   std::string Name() const override { return kLookupOperation; }
197 
198  private:
199   std::shared_ptr<Vocab> vocab_;
200   std::optional<std::string> unknown_token_;
201   int32_t default_id_;
202   DataType data_type_;
203 };
204 
205 class NgramOperation : public TensorOperation {
206  public:
207   explicit NgramOperation(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad,
208                           const std::pair<std::string, int32_t> &right_pad, const std::string &separator);
209 
210   ~NgramOperation() = default;
211 
212   std::shared_ptr<TensorOp> Build() override;
213 
214   Status ValidateParams() override;
215 
Name()216   std::string Name() const override { return kNgramOperation; }
217 
218  private:
219   std::vector<int32_t> ngrams_;
220   std::pair<std::string, int32_t> left_pad_;
221   std::pair<std::string, int32_t> right_pad_;
222   std::string separator_;
223 };
224 
225 #ifndef _WIN32
226 class NormalizeUTF8Operation : public TensorOperation {
227  public:
228   explicit NormalizeUTF8Operation(NormalizeForm normalize_form);
229 
230   ~NormalizeUTF8Operation() = default;
231 
232   std::shared_ptr<TensorOp> Build() override;
233 
234   Status ValidateParams() override;
235 
Name()236   std::string Name() const override { return kNormalizeUTF8Operation; }
237 
238  private:
239   NormalizeForm normalize_form_;
240 };
241 
242 class RegexReplaceOperation : public TensorOperation {
243  public:
244   RegexReplaceOperation(std::string pattern, std::string replace, bool replace_all);
245 
246   ~RegexReplaceOperation() = default;
247 
248   std::shared_ptr<TensorOp> Build() override;
249 
250   Status ValidateParams() override;
251 
Name()252   std::string Name() const override { return kRegexReplaceOperation; }
253 
254  private:
255   std::string pattern_;
256   std::string replace_;
257   bool replace_all_;
258 };
259 
260 class RegexTokenizerOperation : public TensorOperation {
261  public:
262   explicit RegexTokenizerOperation(std::string delim_pattern, std::string keep_delim_pattern, bool with_offsets);
263 
264   ~RegexTokenizerOperation() = default;
265 
266   std::shared_ptr<TensorOp> Build() override;
267 
268   Status ValidateParams() override;
269 
Name()270   std::string Name() const override { return kRegexTokenizerOperation; }
271 
272  private:
273   std::string delim_pattern_;
274   std::string keep_delim_pattern_;
275   bool with_offsets_;
276 };
277 #endif
278 
279 class SentencePieceTokenizerOperation : public TensorOperation {
280  public:
281   SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab, SPieceTokenizerOutType out_type);
282 
283   SentencePieceTokenizerOperation(const std::string &vocab_path, SPieceTokenizerOutType out_type);
284 
285   ~SentencePieceTokenizerOperation();
286 
287   std::shared_ptr<TensorOp> Build() override;
288 
289   Status ValidateParams() override;
290 
Name()291   std::string Name() const override { return kSentencepieceTokenizerOperation; }
292 
293  private:
294   std::shared_ptr<SentencePieceVocab> vocab_;
295   std::string vocab_path_;
296   SPieceTokenizerLoadType load_type_;
297   SPieceTokenizerOutType out_type_;
298 };
299 
300 class SlidingWindowOperation : public TensorOperation {
301  public:
302   explicit SlidingWindowOperation(const int32_t width, const int32_t axis);
303 
304   ~SlidingWindowOperation() = default;
305 
306   std::shared_ptr<TensorOp> Build() override;
307 
308   Status ValidateParams() override;
309 
Name()310   std::string Name() const override { return kSlidingWindowOperation; }
311 
312  private:
313   int32_t width_;
314   int32_t axis_;
315 };
316 
317 class ToNumberOperation : public TensorOperation {
318  public:
319   explicit ToNumberOperation(const DataType &data_type);     // Used for C++ API
320   explicit ToNumberOperation(const std::string &data_type);  // Used for Pybind
321 
322   ~ToNumberOperation() = default;
323 
324   std::shared_ptr<TensorOp> Build() override;
325 
326   Status ValidateParams() override;
327 
Name()328   std::string Name() const override { return kToNumberOperation; }
329 
330   Status to_json(nlohmann::json *out_json) override;
331 
332   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
333 
334  private:
335   DataType data_type_;
336 };
337 
338 class ToVectorsOperation : public TensorOperation {
339  public:
340   ToVectorsOperation(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init,
341                      bool lower_case_backup);
342 
343   ~ToVectorsOperation();
344 
345   std::shared_ptr<TensorOp> Build() override;
346 
347   Status ValidateParams() override;
348 
Name()349   std::string Name() const override { return kToVectorsOperation; }
350 
351  private:
352   std::shared_ptr<Vectors> vectors_;
353   std::vector<float> unk_init_;
354   bool lower_case_backup_;
355 };
356 
357 class TruncateOperation : public TensorOperation {
358  public:
359   explicit TruncateOperation(int32_t max_seq_len);
360 
361   ~TruncateOperation() = default;
362 
363   std::shared_ptr<TensorOp> Build() override;
364 
365   Status ValidateParams() override;
366 
Name()367   std::string Name() const override { return kTruncateOperation; }
368 
369   Status to_json(nlohmann::json *out_json) override;
370 
371  private:
372   int32_t max_seq_len_;
373 };
374 
375 class TruncateSequencePairOperation : public TensorOperation {
376  public:
377   explicit TruncateSequencePairOperation(int32_t max_length);
378 
379   ~TruncateSequencePairOperation() = default;
380 
381   std::shared_ptr<TensorOp> Build() override;
382 
383   Status ValidateParams() override;
384 
Name()385   std::string Name() const override { return kTruncateSequencePairOperation; }
386 
387  private:
388   int32_t max_length_;
389 };
390 
391 class UnicodeCharTokenizerOperation : public TensorOperation {
392  public:
393   explicit UnicodeCharTokenizerOperation(bool with_offsets);
394 
395   ~UnicodeCharTokenizerOperation() = default;
396 
397   std::shared_ptr<TensorOp> Build() override;
398 
399   Status ValidateParams() override;
400 
Name()401   std::string Name() const override { return kUnicodeCharTokenizerOperation; }
402 
403  private:
404   bool with_offsets_;
405 };
406 
407 class WordpieceTokenizerOperation : public TensorOperation {
408  public:
409   explicit WordpieceTokenizerOperation(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
410                                        int32_t max_bytes_per_token, const std::string &unknown_token,
411                                        bool with_offsets);
412 
413   ~WordpieceTokenizerOperation() = default;
414 
415   std::shared_ptr<TensorOp> Build() override;
416 
417   Status ValidateParams() override;
418 
Name()419   std::string Name() const override { return kWordpieceTokenizerOperation; }
420 
421  private:
422   std::shared_ptr<Vocab> vocab_;
423   std::string suffix_indicator_;
424   int32_t max_bytes_per_token_;
425   std::string unknown_token_;
426   bool with_offsets_;
427 };
428 
429 #ifndef _WIN32
430 class UnicodeScriptTokenizerOperation : public TensorOperation {
431  public:
432   explicit UnicodeScriptTokenizerOperation(bool keep_whitespace, bool with_offsets);
433 
434   ~UnicodeScriptTokenizerOperation() = default;
435 
436   std::shared_ptr<TensorOp> Build() override;
437 
438   Status ValidateParams() override;
439 
Name()440   std::string Name() const override { return kUnicodeScriptTokenizerOperation; }
441 
442  private:
443   bool keep_whitespace_;
444   bool with_offsets_;
445 };
446 
447 class WhitespaceTokenizerOperation : public TensorOperation {
448  public:
449   explicit WhitespaceTokenizerOperation(bool with_offsets);
450 
451   ~WhitespaceTokenizerOperation() = default;
452 
453   std::shared_ptr<TensorOp> Build() override;
454 
455   Status ValidateParams() override;
456 
Name()457   std::string Name() const override { return kWhitespaceTokenizerOperation; }
458 
459  private:
460   bool with_offsets_;
461 };
462 #endif
463 }  // namespace text
464 }  // namespace dataset
465 }  // namespace mindspore
466 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_IR_KERNELS_TEXT_IR_H_
467