1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/text/sentence_piece_vocab.h"
18
19 #include <sentencepiece_trainer.h>
20 #include <sentencepiece_processor.h>
21 #include <fstream>
22
23 #include "utils/file_utils.h"
24 #include "utils/ms_utils.h"
25 #include "utils/utils.h"
26 #include "minddata/dataset/util/path.h"
27
28 namespace mindspore {
29 namespace dataset {
30
SentencePieceVocab()31 SentencePieceVocab::SentencePieceVocab() : model_proto_("") {}
32
BuildFromFile(const std::vector<std::string> & path_list,const int32_t vocab_size,const float character_coverage,const SentencePieceModel model_type,const std::unordered_map<std::string,std::string> & params,std::shared_ptr<SentencePieceVocab> * vocab)33 Status SentencePieceVocab::BuildFromFile(const std::vector<std::string> &path_list, const int32_t vocab_size,
34 const float character_coverage, const SentencePieceModel model_type,
35 const std::unordered_map<std::string, std::string> ¶ms,
36 std::shared_ptr<SentencePieceVocab> *vocab) {
37 if (vocab == nullptr) {
38 RETURN_STATUS_UNEXPECTED("SentencePieceVocab::BuildFromFile: input vocab can not be null");
39 }
40 std::unordered_map<std::string, std::string> unorder_map;
41
42 // the input of sentence is comma separated string
43 std::string input_str = "";
44 for (auto path : path_list) {
45 input_str += path;
46 input_str += ",";
47 }
48 input_str.pop_back();
49 unorder_map["input"] = input_str;
50 unorder_map["vocab_size"] = std::to_string(vocab_size);
51 unorder_map["model_prefix"] = "";
52 unorder_map["minloglevel"] = "1";
53 unorder_map["character_coverage"] = std::to_string(character_coverage);
54 if (model_type == SentencePieceModel::kWord) {
55 unorder_map["model_type"] = "WORD";
56 } else if (model_type == SentencePieceModel::kBpe) {
57 unorder_map["model_type"] = "BPE";
58 } else if (model_type == SentencePieceModel::kChar) {
59 unorder_map["model_type"] = "CHAR";
60 } else {
61 unorder_map["model_type"] = "UNIGRAM";
62 }
63
64 // filter some params that set by function param
65 // filter model_prefix that must be empty
66 for (auto param : params) {
67 std::string key = param.first;
68 if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
69 key == "model_type") {
70 continue;
71 }
72 unorder_map[key] = param.second;
73 }
74
75 // set sentence lib's log
76 unorder_map["minloglevel"] = "1";
77 *vocab = std::make_shared<SentencePieceVocab>();
78 std::string model_proto;
79 sentencepiece::util::Status s_status = sentencepiece::SentencePieceTrainer::Train(unorder_map, nullptr, &model_proto);
80 if (!s_status.ok()) {
81 std::string err_msg = "SentencePieceVocab: " + std::string(s_status.message());
82 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
83 }
84 vocab->get()->set_model_proto(model_proto);
85
86 return Status::OK();
87 }
88
SaveModel(const std::shared_ptr<SentencePieceVocab> * vocab,std::string path,std::string filename)89 Status SentencePieceVocab::SaveModel(const std::shared_ptr<SentencePieceVocab> *vocab, std::string path,
90 std::string filename) {
91 if (vocab == nullptr) {
92 RETURN_STATUS_UNEXPECTED("SentencePieceVocab::SaveModel: input vocab can not be null");
93 }
94 auto realpath = FileUtils::GetRealPath(path.data());
95 if (!realpath.has_value()) {
96 RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + path);
97 }
98
99 std::optional<std::string> whole_path = "";
100 std::optional<std::string> local_file_name = filename;
101 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
102
103 std::ofstream os_file(whole_path.value(), std::ios::out);
104 (void)os_file.write(vocab->get()->model_proto().data(), vocab->get()->model_proto().size());
105 os_file.close();
106
107 ChangeFileMode(whole_path.value(), S_IRUSR | S_IWUSR);
108
109 return Status::OK();
110 }
111
model_proto()112 const std::string &SentencePieceVocab::model_proto() { return model_proto_; }
113
set_model_proto(const std::string model_proto)114 void SentencePieceVocab::set_model_proto(const std::string model_proto) { model_proto_ = model_proto; }
115
116 } // namespace dataset
117 } // namespace mindspore
118