• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/text/sentence_piece_vocab.h"
18 
19 #include <sentencepiece_trainer.h>
20 #include <sentencepiece_processor.h>
21 #include <fstream>
22 
23 #include "utils/file_utils.h"
24 #include "utils/ms_utils.h"
25 #include "utils/utils.h"
26 #include "minddata/dataset/util/path.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 
SentencePieceVocab()31 SentencePieceVocab::SentencePieceVocab() : model_proto_("") {}
32 
BuildFromFile(const std::vector<std::string> & path_list,const int32_t vocab_size,const float character_coverage,const SentencePieceModel model_type,const std::unordered_map<std::string,std::string> & params,std::shared_ptr<SentencePieceVocab> * vocab)33 Status SentencePieceVocab::BuildFromFile(const std::vector<std::string> &path_list, const int32_t vocab_size,
34                                          const float character_coverage, const SentencePieceModel model_type,
35                                          const std::unordered_map<std::string, std::string> &params,
36                                          std::shared_ptr<SentencePieceVocab> *vocab) {
37   if (vocab == nullptr) {
38     RETURN_STATUS_UNEXPECTED("SentencePieceVocab::BuildFromFile: input vocab can not be null");
39   }
40   std::unordered_map<std::string, std::string> unorder_map;
41 
42   // the input of sentence is comma separated string
43   std::string input_str = "";
44   for (auto path : path_list) {
45     input_str += path;
46     input_str += ",";
47   }
48   input_str.pop_back();
49   unorder_map["input"] = input_str;
50   unorder_map["vocab_size"] = std::to_string(vocab_size);
51   unorder_map["model_prefix"] = "";
52   unorder_map["minloglevel"] = "1";
53   unorder_map["character_coverage"] = std::to_string(character_coverage);
54   if (model_type == SentencePieceModel::kWord) {
55     unorder_map["model_type"] = "WORD";
56   } else if (model_type == SentencePieceModel::kBpe) {
57     unorder_map["model_type"] = "BPE";
58   } else if (model_type == SentencePieceModel::kChar) {
59     unorder_map["model_type"] = "CHAR";
60   } else {
61     unorder_map["model_type"] = "UNIGRAM";
62   }
63 
64   // filter some params that set by function param
65   // filter  model_prefix that must be empty
66   for (auto param : params) {
67     std::string key = param.first;
68     if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
69         key == "model_type") {
70       continue;
71     }
72     unorder_map[key] = param.second;
73   }
74 
75   // set sentence lib's log
76   unorder_map["minloglevel"] = "1";
77   *vocab = std::make_shared<SentencePieceVocab>();
78   std::string model_proto;
79   sentencepiece::util::Status s_status = sentencepiece::SentencePieceTrainer::Train(unorder_map, nullptr, &model_proto);
80   if (!s_status.ok()) {
81     std::string err_msg = "SentencePieceVocab: " + std::string(s_status.message());
82     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
83   }
84   vocab->get()->set_model_proto(model_proto);
85 
86   return Status::OK();
87 }
88 
SaveModel(const std::shared_ptr<SentencePieceVocab> * vocab,std::string path,std::string filename)89 Status SentencePieceVocab::SaveModel(const std::shared_ptr<SentencePieceVocab> *vocab, std::string path,
90                                      std::string filename) {
91   if (vocab == nullptr) {
92     RETURN_STATUS_UNEXPECTED("SentencePieceVocab::SaveModel: input vocab can not be null");
93   }
94   auto realpath = FileUtils::GetRealPath(path.data());
95   if (!realpath.has_value()) {
96     RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + path);
97   }
98 
99   std::optional<std::string> whole_path = "";
100   std::optional<std::string> local_file_name = filename;
101   FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
102 
103   std::ofstream os_file(whole_path.value(), std::ios::out);
104   (void)os_file.write(vocab->get()->model_proto().data(), vocab->get()->model_proto().size());
105   os_file.close();
106 
107   ChangeFileMode(whole_path.value(), S_IRUSR | S_IWUSR);
108 
109   return Status::OK();
110 }
111 
model_proto()112 const std::string &SentencePieceVocab::model_proto() { return model_proto_; }
113 
set_model_proto(const std::string model_proto)114 void SentencePieceVocab::set_model_proto(const std::string model_proto) { model_proto_ = model_proto; }
115 
116 }  // namespace dataset
117 }  // namespace mindspore
118