1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ 18 #define DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ 19 20 #include <sentencepiece_trainer.h> 21 #include <sentencepiece_processor.h> 22 #include <vector> 23 #include <memory> 24 #include <unordered_map> 25 #include <string> 26 #include <utility> 27 28 #include "minddata/dataset/core/tensor.h" 29 #include "minddata/dataset/engine/dataset_iterator.h" 30 #include "minddata/dataset/engine/datasetops/pipeline_op.h" 31 #include "minddata/dataset/util/status.h" 32 #include "minddata/dataset/util/queue.h" 33 #include "minddata/dataset/text/sentence_piece_vocab.h" 34 #include "pybind11/pybind11.h" 35 36 namespace mindspore { 37 namespace dataset { 38 namespace py = pybind11; 39 40 class BuildSentencePieceVocabOp : public PipelineOp { 41 public: 42 class DatasetSentenceIterator : public sentencepiece::SentenceIterator { 43 public: 44 explicit DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr); ~DatasetSentenceIterator()45 ~DatasetSentenceIterator() {} 46 47 bool done() const override; 48 void Next() override; value()49 const std::string &value() const override { return value_; } status()50 sentencepiece::util::Status status() const override { return sentencepiece::util::Status(); } 51 52 private: 53 std::string value_; 54 BuildSentencePieceVocabOp *s_p_vocab_ptr_; 55 }; 56 57 BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab, std::vector<std::string> col_names, 58 int32_t vocab_size, float character_coverage, SentencePieceModel model_type, 59 const std::unordered_map<std::string, std::string> ¶ms, int32_t op_conn_size); 60 61 ~BuildSentencePieceVocabOp() = default; 62 63 // the thread for sentence train 64 Status SentenceThread(); 65 EofReceived(int32_t)66 Status EofReceived(int32_t) override { return Status::OK(); } 67 EoeReceived(int32_t)68 Status EoeReceived(int32_t) override { return Status::OK(); } 69 70 Status operator()() override; 71 72 // Getter 73 // @return the number of workers NumProducers()74 int32_t NumProducers() const override { return 1; } 75 76 // Getter 77 // @return the number of threads consuming from the previous Connector NumConsumers()78 int32_t NumConsumers() const override { return 1; } 79 Reset()80 Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); } 81 Name()82 std::string Name() const override { return kBuildSentencePieceVocabOp; } 83 84 // build the input params for sentence api 85 std::unordered_map<std::string, std::string> BuildParams(); 86 87 bool Done(); 88 void Next(std::string *sentence); 89 90 private: 91 bool read_done_; 92 Status ret_status_; 93 int32_t vocab_size_; 94 float character_coverage_; 95 SentencePieceModel model_type_; 96 std::unordered_map<std::string, std::string> params_; 97 std::shared_ptr<SentencePieceVocab> vocab_; 98 std::vector<std::string> col_names_; 99 uint32_t col_id_; 100 std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1 101 std::unique_ptr<Queue<TensorRow>> sentence_queue_; // master thread assigns each worker TensorRow via this 102 }; 103 } // namespace dataset 104 } // namespace mindspore 105 #endif // DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ 106