• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/lang-id.h"
18 
19 #include <stdio.h>
20 
21 #include <algorithm>
22 #include <limits>
23 #include <memory>
24 #include <string>
25 #include <vector>
26 
27 #include "common/algorithm.h"
28 #include "common/embedding-network-params-from-proto.h"
29 #include "common/embedding-network.pb.h"
30 #include "common/embedding-network.h"
31 #include "common/feature-extractor.h"
32 #include "common/file-utils.h"
33 #include "common/list-of-strings.pb.h"
34 #include "common/memory_image/in-memory-model-data.h"
35 #include "common/mmap.h"
36 #include "common/softmax.h"
37 #include "common/task-context.h"
38 #include "lang_id/custom-tokenizer.h"
39 #include "lang_id/lang-id-brain-interface.h"
40 #include "lang_id/language-identifier-features.h"
41 #include "lang_id/light-sentence-features.h"
42 #include "lang_id/light-sentence.h"
43 #include "lang_id/relevant-script-feature.h"
44 #include "util/base/logging.h"
45 #include "util/base/macros.h"
46 
47 using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory;
48 
49 namespace libtextclassifier {
50 namespace nlp_core {
51 namespace lang_id {
52 
53 namespace {
54 // Default value for the probability threshold; see comments for
55 // LangId::SetProbabilityThreshold().
56 static const float kDefaultProbabilityThreshold = 0.50;
57 
58 // Default value for min text size below which our model can't provide a
59 // meaningful prediction.
60 static const int kDefaultMinTextSizeInBytes = 20;
61 
62 // Initial value for the default language for LangId::FindLanguage().  The
63 // default language can be changed (for an individual LangId object) using
64 // LangId::SetDefaultLanguage().
65 static const char kInitialDefaultLanguage[] = "";
66 
67 // Returns total number of bytes of the words from sentence, without the ^
68 // (start-of-word) and $ (end-of-word) markers.  Note: "real text" means that
69 // this ignores whitespace and punctuation characters from the original text.
GetRealTextSize(const LightSentence & sentence)70 int GetRealTextSize(const LightSentence &sentence) {
71   int total = 0;
72   for (int i = 0; i < sentence.num_words(); ++i) {
73     TC_DCHECK(!sentence.word(i).empty());
74     TC_DCHECK_EQ('^', sentence.word(i).front());
75     TC_DCHECK_EQ('$', sentence.word(i).back());
76     total += sentence.word(i).size() - 2;
77   }
78   return total;
79 }
80 
81 }  // namespace
82 
83 // Class that performs all work behind LangId.
84 class LangIdImpl {
85  public:
LangIdImpl(const std::string & filename)86   explicit LangIdImpl(const std::string &filename) {
87     // Using mmap as a fast way to read the model bytes.
88     ScopedMmap scoped_mmap(filename);
89     MmapHandle mmap_handle = scoped_mmap.handle();
90     if (!mmap_handle.ok()) {
91       TC_LOG(ERROR) << "Unable to read model bytes.";
92       return;
93     }
94 
95     Initialize(mmap_handle.to_stringpiece());
96   }
97 
LangIdImpl(int fd)98   explicit LangIdImpl(int fd) {
99     // Using mmap as a fast way to read the model bytes.
100     ScopedMmap scoped_mmap(fd);
101     MmapHandle mmap_handle = scoped_mmap.handle();
102     if (!mmap_handle.ok()) {
103       TC_LOG(ERROR) << "Unable to read model bytes.";
104       return;
105     }
106 
107     Initialize(mmap_handle.to_stringpiece());
108   }
109 
LangIdImpl(const char * ptr,size_t length)110   LangIdImpl(const char *ptr, size_t length) {
111     Initialize(StringPiece(ptr, length));
112   }
113 
Initialize(StringPiece model_bytes)114   void Initialize(StringPiece model_bytes) {
115     // Will set valid_ to true only on successful initialization.
116     valid_ = false;
117 
118     // Make sure all relevant features are registered:
119     ContinuousBagOfNgramsFunction::RegisterClass();
120     RelevantScriptFeature::RegisterClass();
121 
122     // NOTE(salcianu): code below relies on the fact that the current features
123     // do not rely on data from a TaskInput.  Otherwise, one would have to use
124     // the more complex model registration mechanism, which requires more code.
125     InMemoryModelData model_data(model_bytes);
126     TaskContext context;
127     if (!model_data.GetTaskSpec(context.mutable_spec())) {
128       TC_LOG(ERROR) << "Unable to get model TaskSpec";
129       return;
130     }
131 
132     if (!ParseNetworkParams(model_data, &context)) {
133       return;
134     }
135     if (!ParseListOfKnownLanguages(model_data, &context)) {
136       return;
137     }
138 
139     network_.reset(new EmbeddingNetwork(network_params_.get()));
140     if (!network_->is_valid()) {
141       return;
142     }
143 
144     probability_threshold_ =
145         context.Get("reliability_thresh", kDefaultProbabilityThreshold);
146     min_text_size_in_bytes_ =
147         context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes);
148     version_ = context.Get("version", 0);
149 
150     if (!lang_id_brain_interface_.Init(&context)) {
151       return;
152     }
153     valid_ = true;
154   }
155 
SetProbabilityThreshold(float threshold)156   void SetProbabilityThreshold(float threshold) {
157     probability_threshold_ = threshold;
158   }
159 
SetDefaultLanguage(const std::string & lang)160   void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; }
161 
FindLanguage(const std::string & text) const162   std::string FindLanguage(const std::string &text) const {
163     std::vector<float> scores = ScoreLanguages(text);
164     if (scores.empty()) {
165       return default_language_;
166     }
167 
168     // Softmax label with max score.
169     int label = GetArgMax(scores);
170     float probability = scores[label];
171     if (probability < probability_threshold_) {
172       return default_language_;
173     }
174     return GetLanguageForSoftmaxLabel(label);
175   }
176 
FindLanguages(const std::string & text) const177   std::vector<std::pair<std::string, float>> FindLanguages(
178       const std::string &text) const {
179     std::vector<float> scores = ScoreLanguages(text);
180 
181     std::vector<std::pair<std::string, float>> result;
182     for (int i = 0; i < scores.size(); i++) {
183       result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]});
184     }
185 
186     // To avoid crashing clients that always expect at least one predicted
187     // language, we promised (see doc for this method) that the result always
188     // contains at least one element.
189     if (result.empty()) {
190       // We use a tiny probability, such that any client that uses a meaningful
191       // probability threshold ignores this prediction.  We don't use 0.0f, to
192       // avoid crashing clients that normalize the probabilities we return here.
193       result.push_back({default_language_, 0.001f});
194     }
195     return result;
196   }
197 
ScoreLanguages(const std::string & text) const198   std::vector<float> ScoreLanguages(const std::string &text) const {
199     if (!is_valid()) {
200       return {};
201     }
202 
203     // Create a Sentence storing the input text.
204     LightSentence sentence;
205     TokenizeTextForLangId(text, &sentence);
206 
207     if (GetRealTextSize(sentence) < min_text_size_in_bytes_) {
208       return {};
209     }
210 
211     // TODO(salcianu): reuse vector<FeatureVector>.
212     std::vector<FeatureVector> features(
213         lang_id_brain_interface_.NumEmbeddings());
214     lang_id_brain_interface_.GetFeatures(&sentence, &features);
215 
216     // Predict language.
217     EmbeddingNetwork::Vector scores;
218     network_->ComputeFinalScores(features, &scores);
219 
220     return ComputeSoftmax(scores);
221   }
222 
is_valid() const223   bool is_valid() const { return valid_; }
224 
version() const225   int version() const { return version_; }
226 
227  private:
228   // Returns name of the (in-memory) file for the indicated TaskInput from
229   // context.
GetInMemoryFileNameForTaskInput(const std::string & input_name,TaskContext * context)230   static std::string GetInMemoryFileNameForTaskInput(
231       const std::string &input_name, TaskContext *context) {
232     TaskInput *task_input = context->GetInput(input_name);
233     if (task_input->part_size() != 1) {
234       TC_LOG(ERROR) << "TaskInput " << input_name << " has "
235                     << task_input->part_size() << " parts";
236       return "";
237     }
238     return task_input->part(0).file_pattern();
239   }
240 
ParseNetworkParams(const InMemoryModelData & model_data,TaskContext * context)241   bool ParseNetworkParams(const InMemoryModelData &model_data,
242                           TaskContext *context) {
243     const std::string input_name = "language-identifier-network";
244     const std::string input_file_name =
245         GetInMemoryFileNameForTaskInput(input_name, context);
246     if (input_file_name.empty()) {
247       TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
248       return false;
249     }
250     StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
251     if (bytes.data() == nullptr) {
252       TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
253       return false;
254     }
255     std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
256     if (!ParseProtoFromMemory(bytes, proto.get())) {
257       TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
258       return false;
259     }
260     network_params_.reset(
261         new EmbeddingNetworkParamsFromProto(std::move(proto)));
262     if (!network_params_->is_valid()) {
263       TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid";
264       return false;
265     }
266     return true;
267   }
268 
269   // Parses dictionary with known languages (i.e., field languages_) from a
270   // TaskInput of context.  Note: that TaskInput should be a ListOfStrings proto
271   // with a single element, the serialized form of a ListOfStrings.
272   //
ParseListOfKnownLanguages(const InMemoryModelData & model_data,TaskContext * context)273   bool ParseListOfKnownLanguages(const InMemoryModelData &model_data,
274                                  TaskContext *context) {
275     const std::string input_name = "language-name-id-map";
276     const std::string input_file_name =
277         GetInMemoryFileNameForTaskInput(input_name, context);
278     if (input_file_name.empty()) {
279       TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
280       return false;
281     }
282     StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
283     if (bytes.data() == nullptr) {
284       TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
285       return false;
286     }
287     ListOfStrings records;
288     if (!ParseProtoFromMemory(bytes, &records)) {
289       TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput "
290                     << input_name;
291       return false;
292     }
293     if (records.element_size() != 1) {
294       TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name
295                     << " : " << records.element_size();
296       return false;
297     }
298     if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) {
299       TC_LOG(ERROR) << "Unable to parse dictionary with known languages";
300       return false;
301     }
302     return true;
303   }
304 
305   // Returns language code for a softmax label.  See comments for languages_
306   // field.  If label is out of range, returns default_language_.
GetLanguageForSoftmaxLabel(int label) const307   std::string GetLanguageForSoftmaxLabel(int label) const {
308     if ((label >= 0) && (label < languages_.element_size())) {
309       return languages_.element(label);
310     } else {
311       TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
312                     << languages_.element_size() << ")";
313       return default_language_;
314     }
315   }
316 
317   LangIdBrainInterface lang_id_brain_interface_;
318 
319   // Parameters for the neural network network_ (see below).
320   std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_;
321 
322   // Neural network to use for scoring.
323   std::unique_ptr<EmbeddingNetwork> network_;
324 
325   // True if this object is ready to perform language predictions.
326   bool valid_;
327 
328   // Only predictions with a probability (confidence) above this threshold are
329   // reported.  Otherwise, we report default_language_.
330   float probability_threshold_ = kDefaultProbabilityThreshold;
331 
332   // Min size of the input text for our predictions to be meaningful.  Below
333   // this threshold, the underlying model may report a wrong language and a high
334   // confidence score.
335   int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes;
336 
337   // Version of the model.
338   int version_ = -1;
339 
340   // Known languages: softmax label i (an integer) means languages_.element(i)
341   // (something like "en", "fr", "ru", etc).
342   ListOfStrings languages_;
343 
344   // Language code to return in case of errors.
345   std::string default_language_ = kInitialDefaultLanguage;
346 
347   TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl);
348 };
349 
LangId(const std::string & filename)350 LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) {
351   if (!pimpl_->is_valid()) {
352     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
353                   << "on the data from " << filename
354                   << "; nothing should crash, but "
355                   << "accuracy will be bad.";
356   }
357 }
358 
LangId(int fd)359 LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) {
360   if (!pimpl_->is_valid()) {
361     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
362                   << "on the data from descriptor " << fd
363                   << "; nothing should crash, "
364                   << "but accuracy will be bad.";
365   }
366 }
367 
LangId(const char * ptr,size_t length)368 LangId::LangId(const char *ptr, size_t length)
369     : pimpl_(new LangIdImpl(ptr, length)) {
370   if (!pimpl_->is_valid()) {
371     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
372                   << "on the memory region; nothing should crash, "
373                   << "but accuracy will be bad.";
374   }
375 }
376 
// Defined out of line, in this file where LangIdImpl is a complete type;
// presumably needed because pimpl_ (declared in lang-id.h) owns a LangIdImpl
// that is only forward-declared there.
LangId::~LangId() = default;
378 
SetProbabilityThreshold(float threshold)379 void LangId::SetProbabilityThreshold(float threshold) {
380   pimpl_->SetProbabilityThreshold(threshold);
381 }
382 
SetDefaultLanguage(const std::string & lang)383 void LangId::SetDefaultLanguage(const std::string &lang) {
384   pimpl_->SetDefaultLanguage(lang);
385 }
386 
FindLanguage(const std::string & text) const387 std::string LangId::FindLanguage(const std::string &text) const {
388   return pimpl_->FindLanguage(text);
389 }
390 
FindLanguages(const std::string & text) const391 std::vector<std::pair<std::string, float>> LangId::FindLanguages(
392     const std::string &text) const {
393   return pimpl_->FindLanguages(text);
394 }
395 
is_valid() const396 bool LangId::is_valid() const { return pimpl_->is_valid(); }
397 
version() const398 int LangId::version() const { return pimpl_->version(); }
399 
400 }  // namespace lang_id
401 }  // namespace nlp_core
402 }  // namespace libtextclassifier
403