• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/lang-id.h"
18 
19 #include <stdio.h>
20 
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "lang_id/common/embedding-feature-interface.h"
27 #include "lang_id/common/embedding-network-params.h"
28 #include "lang_id/common/embedding-network.h"
29 #include "lang_id/common/fel/feature-extractor.h"
30 #include "lang_id/common/lite_base/logging.h"
31 #include "lang_id/common/lite_strings/numbers.h"
32 #include "lang_id/common/lite_strings/str-split.h"
33 #include "lang_id/common/lite_strings/stringpiece.h"
34 #include "lang_id/common/math/algorithm.h"
35 #include "lang_id/common/math/softmax.h"
36 #include "lang_id/custom-tokenizer.h"
37 #include "lang_id/features/light-sentence-features.h"
38 // The two features/ headers below are needed only for RegisterClass().
39 #include "lang_id/features/char-ngram-feature.h"
40 #include "lang_id/features/relevant-script-feature.h"
41 #include "lang_id/light-sentence.h"
42 // The two script/ headers below are needed only for RegisterClass().
43 #include "lang_id/script/approx-script.h"
44 #include "lang_id/script/tiny-script-detector.h"
45 
46 namespace libtextclassifier3 {
47 namespace mobile {
48 namespace lang_id {
49 
namespace {
// Fallback confidence threshold.  FindLanguage() reports
// LangId::kUnknownLanguageCode whenever the top prediction's confidence is
// lower than the threshold in effect.  This constant is only the default: a
// "reliability_thresh" parameter in the model's TaskSpec overrides it.
// Note: for legacy reasons, our code and comments use the terms "confidence",
// "probability" and "reliability" equivalently.
constexpr float kDefaultConfidenceThreshold = 0.50f;
}  // namespace
59 
60 // Class that performs all work behind LangId.
61 class LangIdImpl {
62  public:
LangIdImpl(std::unique_ptr<ModelProvider> model_provider)63   explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
64       : model_provider_(std::move(model_provider)),
65         lang_id_brain_interface_("language_identifier") {
66     // Note: in the code below, we set valid_ to true only if all initialization
67     // steps completed successfully.  Otherwise, we return early, leaving valid_
68     // to its default value false.
69     if (!model_provider_ || !model_provider_->is_valid()) {
70       SAFTM_LOG(ERROR) << "Invalid model provider";
71       return;
72     }
73 
74     auto *nn_params = model_provider_->GetNnParams();
75     if (!nn_params) {
76       SAFTM_LOG(ERROR) << "No NN params";
77       return;
78     }
79     network_.reset(new EmbeddingNetwork(nn_params));
80 
81     languages_ = model_provider_->GetLanguages();
82     if (languages_.empty()) {
83       SAFTM_LOG(ERROR) << "No known languages";
84       return;
85     }
86 
87     TaskContext context = *model_provider_->GetTaskContext();
88     if (!Setup(&context)) {
89       SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
90       return;
91     }
92     if (!Init(&context)) {
93       SAFTM_LOG(ERROR) << "Unable to Init() LangId";
94       return;
95     }
96     valid_ = true;
97   }
98 
FindLanguage(StringPiece text) const99   std::string FindLanguage(StringPiece text) const {
100     LangIdResult lang_id_result;
101     FindLanguages(text, &lang_id_result, /* max_results = */ 1);
102     if (lang_id_result.predictions.empty()) {
103       return LangId::kUnknownLanguageCode;
104     }
105 
106     const std::string &language = lang_id_result.predictions[0].first;
107     const float probability = lang_id_result.predictions[0].second;
108     SAFTM_DLOG(INFO) << "Predicted " << language
109                      << " with prob: " << probability << " for \"" << text
110                      << "\"";
111 
112     // Find confidence threshold for language.
113     float threshold = default_threshold_;
114     auto it = per_lang_thresholds_.find(language);
115     if (it != per_lang_thresholds_.end()) {
116       threshold = it->second;
117     }
118     if (probability < threshold) {
119       SAFTM_DLOG(INFO) << "  below threshold => "
120                        << LangId::kUnknownLanguageCode;
121       return LangId::kUnknownLanguageCode;
122     }
123     return language;
124   }
125 
FindLanguages(StringPiece text,LangIdResult * result,int max_results) const126   void FindLanguages(StringPiece text, LangIdResult *result,
127                      int max_results) const {
128     if (result == nullptr) return;
129 
130     if (max_results <= 0) {
131       max_results = languages_.size();
132     }
133     result->predictions.clear();
134     if (!is_valid() || (max_results == 0)) {
135       result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
136       return;
137     }
138 
139     // Tokenize the input text (this also does some pre-processing, like
140     // removing ASCII digits, punctuation, etc).
141     LightSentence sentence;
142     tokenizer_.Tokenize(text, &sentence);
143 
144     // Test input size here, after pre-processing removed irrelevant chars.
145     if (IsTooShort(sentence)) {
146       result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
147       return;
148     }
149 
150     // Extract features from the tokenized text.
151     std::vector<FeatureVector> features =
152         lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
153 
154     // Run feed-forward neural network to compute scores (softmax logits).
155     std::vector<float> scores;
156     network_->ComputeFinalScores(features, &scores);
157 
158     if (max_results == 1) {
159       // Optimization for the case when the user wants only the top result.
160       // Computing argmax is faster than the general top-k code.
161       int prediction_id = GetArgMax(scores);
162       const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
163       float probability = ComputeSoftmaxProbability(scores, prediction_id);
164       result->predictions.emplace_back(language, probability);
165     } else {
166       // Compute and sort softmax in descending order by probability and convert
167       // IDs to language code strings.  When probabilities are equal, we sort by
168       // language code string in ascending order.
169       const std::vector<float> softmax = ComputeSoftmax(scores);
170       const std::vector<int> indices = GetTopKIndices(max_results, softmax);
171       for (const int index : indices) {
172         result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
173                                          softmax[index]);
174       }
175     }
176   }
177 
is_valid() const178   bool is_valid() const { return valid_; }
179 
GetModelVersion() const180   int GetModelVersion() const { return model_version_; }
181 
182   // Returns a property stored in the model file.
183   template <typename T, typename R>
GetProperty(const std::string & property,T default_value) const184   R GetProperty(const std::string &property, T default_value) const {
185     return model_provider_->GetTaskContext()->Get(property, default_value);
186   }
187 
188   // Perform any necessary static initialization.
189   // This function is thread-safe.
190   // It's also safe to call this function multiple times.
191   //
192   // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
193   // the BUILD file, because the build process for some users of this code
194   // doesn't support any equivalent to alwayslink=1 (in particular the
195   // Firebase C++ SDK build uses a Kokoro-based CMake build).  While it might
196   // be possible to add such support, avoiding the need for an equivalent to
197   // alwayslink=1 is preferable because it avoids unnecessarily bloating code
198   // size in apps that link against this code but don't use it.
RegisterClasses()199   static void RegisterClasses() {
200     static bool initialized = []() -> bool {
201       libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
202       libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
203       libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
204       libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
205       return true;
206     }();
207     (void)initialized;  // Variable used only for initializer's side effects.
208   }
209 
210  private:
Setup(TaskContext * context)211   bool Setup(TaskContext *context) {
212     tokenizer_.Setup(context);
213     if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
214 
215     min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
216     default_threshold_ =
217         context->Get("reliability_thresh", kDefaultConfidenceThreshold);
218 
219     // Parse task parameter "per_lang_reliability_thresholds", fill
220     // per_lang_thresholds_.
221     const std::string thresholds_str =
222         context->Get("per_lang_reliability_thresholds", "");
223     std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
224     for (const auto &token : tokens) {
225       if (token.empty()) continue;
226       std::vector<StringPiece> parts = LiteStrSplit(token, '=');
227       float threshold = 0.0f;
228       if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
229         per_lang_thresholds_[std::string(parts[0])] = threshold;
230       } else {
231         SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
232       }
233     }
234     model_version_ = context->Get("model_version", model_version_);
235     return true;
236   }
237 
Init(TaskContext * context)238   bool Init(TaskContext *context) {
239     return lang_id_brain_interface_.InitForProcessing(context);
240   }
241 
242   // Returns language code for a softmax label.  See comments for languages_
243   // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
GetLanguageForSoftmaxLabel(int label) const244   std::string GetLanguageForSoftmaxLabel(int label) const {
245     if ((label >= 0) && (label < languages_.size())) {
246       return languages_[label];
247     } else {
248       SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
249                        << languages_.size() << ")";
250       return LangId::kUnknownLanguageCode;
251     }
252   }
253 
IsTooShort(const LightSentence & sentence) const254   bool IsTooShort(const LightSentence &sentence) const {
255     int text_size = 0;
256     for (const std::string &token : sentence) {
257       // Each token has the form ^...$: we subtract 2 because we want to count
258       // only the real text, not the chars added by us.
259       text_size += token.size() - 2;
260     }
261     return text_size < min_text_size_in_bytes_;
262   }
263 
264   std::unique_ptr<ModelProvider> model_provider_;
265 
266   TokenizerForLangId tokenizer_;
267 
268   EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
269       lang_id_brain_interface_;
270 
271   // Neural network to use for scoring.
272   std::unique_ptr<EmbeddingNetwork> network_;
273 
274   // True if this object is ready to perform language predictions.
275   bool valid_ = false;
276 
277   // The model returns LangId::kUnknownLanguageCode for input text that has
278   // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
279   // digits, and punctuation).
280   int min_text_size_in_bytes_ = 0;
281 
282   // Only predictions with a probability (confidence) above this threshold are
283   // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
284   float default_threshold_ = kDefaultConfidenceThreshold;
285 
286   std::unordered_map<std::string, float> per_lang_thresholds_;
287 
288   // Recognized languages: softmax label i means languages_[i] (something like
289   // "en", "fr", "ru", etc).
290   std::vector<std::string> languages_;
291 
292   // Version of the model used by this LangIdImpl object.  Zero means that the
293   // model version could not be determined.
294   int model_version_ = 0;
295 };
296 
297 const char LangId::kUnknownLanguageCode[] = "und";
298 
LangId(std::unique_ptr<ModelProvider> model_provider)299 LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
300     : pimpl_(new LangIdImpl(std::move(model_provider))) {
301   LangIdImpl::RegisterClasses();
302 }
303 
304 LangId::~LangId() = default;
305 
FindLanguage(const char * data,size_t num_bytes) const306 std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
307   StringPiece text(data, num_bytes);
308   return pimpl_->FindLanguage(text);
309 }
310 
FindLanguages(const char * data,size_t num_bytes,LangIdResult * result,int max_results) const311 void LangId::FindLanguages(const char *data, size_t num_bytes,
312                            LangIdResult *result, int max_results) const {
313   SAFTM_DCHECK(result) << "LangIdResult must not be null.";
314   StringPiece text(data, num_bytes);
315   pimpl_->FindLanguages(text, result, max_results);
316 }
317 
is_valid() const318 bool LangId::is_valid() const { return pimpl_->is_valid(); }
319 
GetModelVersion() const320 int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
321 
GetFloatProperty(const std::string & property,float default_value) const322 float LangId::GetFloatProperty(const std::string &property,
323                                float default_value) const {
324   return pimpl_->GetProperty<float, float>(property, default_value);
325 }
326 
327 }  // namespace lang_id
328 }  // namespace mobile
}  // namespace libtextclassifier3
330