/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/lang-id.h"

#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "lang_id/common/embedding-feature-interface.h"
#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/embedding-network.h"
#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/numbers.h"
#include "lang_id/common/lite_strings/str-split.h"
#include "lang_id/common/lite_strings/stringpiece.h"
#include "lang_id/common/math/algorithm.h"
#include "lang_id/common/math/softmax.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/features/light-sentence-features.h"
// The two features/ headers below are needed only for RegisterClass().
#include "lang_id/features/char-ngram-feature.h"
#include "lang_id/features/relevant-script-feature.h"
#include "lang_id/light-sentence.h"
// The two script/ headers below are needed only for RegisterClass().
#include "lang_id/script/approx-script.h"
#include "lang_id/script/tiny-script-detector.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

namespace {
// Default value for the confidence threshold.  If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode.  Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
// use that value instead.  Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability" and "reliability" equivalently.
static const float kDefaultConfidenceThreshold = 0.50f;
}  // namespace

// Class that performs all work behind LangId.
class LangIdImpl {
 public:
  explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
      : model_provider_(std::move(model_provider)),
        lang_id_brain_interface_("language_identifier") {
    // Note: in the code below, we set valid_ to true only if all
    // initialization steps complete successfully.  Otherwise, we return
    // early, leaving valid_ at its default value, false.
    if (!model_provider_ || !model_provider_->is_valid()) {
      SAFTM_LOG(ERROR) << "Invalid model provider";
      return;
    }

    auto *nn_params = model_provider_->GetNnParams();
    if (!nn_params) {
      SAFTM_LOG(ERROR) << "No NN params";
      return;
    }
    network_.reset(new EmbeddingNetwork(nn_params));

    languages_ = model_provider_->GetLanguages();
    if (languages_.empty()) {
      SAFTM_LOG(ERROR) << "No known languages";
      return;
    }

    TaskContext context = *model_provider_->GetTaskContext();
    if (!Setup(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
      return;
    }
    if (!Init(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Init() LangId";
      return;
    }
    valid_ = true;
  }

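  // Returns the language code of the most likely language for |text|, or
  // LangId::kUnknownLanguageCode if the probability of the top prediction is
  // below the confidence threshold for that language.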
  std::string FindLanguage(StringPiece text) const {
    LangIdResult lang_id_result;
    FindLanguages(text, &lang_id_result, /* max_results = */ 1);
    if (lang_id_result.predictions.empty()) {
      return LangId::kUnknownLanguageCode;
    }

    const std::string &language = lang_id_result.predictions[0].first;
    const float probability = lang_id_result.predictions[0].second;
    SAFTM_DLOG(INFO) << "Predicted " << language
                     << " with prob: " << probability << " for \"" << text
                     << "\"";

    // Find confidence threshold for language.
    float threshold = default_threshold_;
    auto it = per_lang_thresholds_.find(language);
    if (it != per_lang_thresholds_.end()) {
      threshold = it->second;
    }
    if (probability < threshold) {
      SAFTM_DLOG(INFO) << "  below threshold => "
                       << LangId::kUnknownLanguageCode;
      return LangId::kUnknownLanguageCode;
    }
    return language;
  }

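  // Finds the top |max_results| languages for |text| and stores them in
  // |result|, sorted in descending order of probability.  If |max_results|
  // <= 0, predictions for all known languages are returned.  If this object
  // is not valid, or if the (pre-processed) input is too short, a single
  // (LangId::kUnknownLanguageCode, 1) prediction is reported.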
  void FindLanguages(StringPiece text, LangIdResult *result,
                     int max_results) const {
    if (result == nullptr) return;

    if (max_results <= 0) {
      max_results = languages_.size();
    }
    result->predictions.clear();
    if (!is_valid() || (max_results == 0)) {
      result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
      return;
    }

    // Tokenize the input text (this also does some pre-processing, like
    // removing ASCII digits, punctuation, etc).
    LightSentence sentence;
    tokenizer_.Tokenize(text, &sentence);

    // Check the input size here, after pre-processing has removed the
    // irrelevant chars.
    if (IsTooShort(sentence)) {
      result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
      return;
    }

    // Extract features from the tokenized text.
    std::vector<FeatureVector> features =
        lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);

    // Run the feed-forward neural network to compute scores (softmax logits).
    std::vector<float> scores;
    network_->ComputeFinalScores(features, &scores);

    if (max_results == 1) {
      // Optimization for the case when the user wants only the top result:
      // computing the argmax is faster than the general top-k code.
      int prediction_id = GetArgMax(scores);
      const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
      float probability = ComputeSoftmaxProbability(scores, prediction_id);
      result->predictions.emplace_back(language, probability);
    } else {
      // Compute the softmax, sort the labels in descending order of
      // probability, and convert label IDs to language code strings.  When
      // probabilities are equal, we sort by language code string in ascending
      // order.
      const std::vector<float> softmax = ComputeSoftmax(scores);
      const std::vector<int> indices = GetTopKIndices(max_results, softmax);
      for (const int index : indices) {
        result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
                                         softmax[index]);
      }
    }
  }

  bool is_valid() const { return valid_; }

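  // Returns the version of the model in use, or 0 if the version could not be
  // determined.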
  int GetModelVersion() const { return model_version_; }

  // Returns a property stored in the model file.
  template <typename T, typename R>
  R GetProperty(const std::string &property, T default_value) const {
    return model_provider_->GetTaskContext()->Get(property, default_value);
  }

  // Perform any necessary static initialization.
  // This function is thread-safe.
  // It's also safe to call this function multiple times.
  //
  // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
  // the BUILD file, because the build process for some users of this code
  // doesn't support any equivalent to alwayslink=1 (in particular the
  // Firebase C++ SDK build uses a Kokoro-based CMake build).  While it might
  // be possible to add such support, avoiding the need for an equivalent to
  // alwayslink=1 is preferable because it avoids unnecessarily bloating code
  // size in apps that link against this code but don't use it.
  static void RegisterClasses() {
    static bool initialized = []() -> bool {
      libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
      libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
      libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
      libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
      return true;
    }();
    (void)initialized;  // Variable used only for initializer's side effects.
  }

 private:
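  // Reads the settings relevant for this class from |context|: tokenizer and
  // feature-extractor setup, the minimum input size, the default and
  // per-language confidence thresholds, and the model version.  Returns false
  // on error.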
  bool Setup(TaskContext *context) {
    tokenizer_.Setup(context);
    if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;

    min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
    default_threshold_ =
        context->Get("reliability_thresh", kDefaultConfidenceThreshold);

    // Parse task parameter "per_lang_reliability_thresholds", fill
    // per_lang_thresholds_.
    const std::string thresholds_str =
        context->Get("per_lang_reliability_thresholds", "");
    std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
    for (const auto &token : tokens) {
      if (token.empty()) continue;
      std::vector<StringPiece> parts = LiteStrSplit(token, '=');
      float threshold = 0.0f;
      if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
        per_lang_thresholds_[std::string(parts[0])] = threshold;
      } else {
        SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
      }
    }
    model_version_ = context->Get("model_version", model_version_);
    return true;
  }

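  // Initializes the embedding feature interface for processing; returns false
  // on error.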
  bool Init(TaskContext *context) {
    return lang_id_brain_interface_.InitForProcessing(context);
  }

  // Returns language code for a softmax label.  See comments for languages_
  // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
  std::string GetLanguageForSoftmaxLabel(int label) const {
    if ((label >= 0) && (label < languages_.size())) {
      return languages_[label];
    } else {
      SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
                       << languages_.size() << ")";
      return LangId::kUnknownLanguageCode;
    }
  }

  bool IsTooShort(const LightSentence &sentence) const {
    int text_size = 0;
    for (const std::string &token : sentence) {
      // Each token has the form ^...$: we subtract 2 because we want to count
      // only the real text, not the chars added by us.
      text_size += token.size() - 2;
    }
    return text_size < min_text_size_in_bytes_;
  }

  std::unique_ptr<ModelProvider> model_provider_;

  TokenizerForLangId tokenizer_;

  EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
      lang_id_brain_interface_;

  // Neural network to use for scoring.
  std::unique_ptr<EmbeddingNetwork> network_;

  // True if this object is ready to perform language predictions.
  bool valid_ = false;

  // The model returns LangId::kUnknownLanguageCode for input text that has
  // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
  // digits, and punctuation).
  int min_text_size_in_bytes_ = 0;

  // Only predictions with a probability (confidence) above this threshold are
  // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
  float default_threshold_ = kDefaultConfidenceThreshold;

  std::unordered_map<std::string, float> per_lang_thresholds_;

  // Recognized languages: softmax label i means languages_[i] (something like
  // "en", "fr", "ru", etc).
  std::vector<std::string> languages_;

  // Version of the model used by this LangIdImpl object.  Zero means that the
  // model version could not be determined.
  int model_version_ = 0;
};

const char LangId::kUnknownLanguageCode[] = "und";

LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
    : pimpl_(new LangIdImpl(std::move(model_provider))) {
  LangIdImpl::RegisterClasses();
}

LangId::~LangId() = default;
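
// Example usage (a minimal sketch; GetModelProvider() is a stand-in for
// whatever ModelProvider factory the client code uses and is not part of
// this library):
//
//   std::unique_ptr<ModelProvider> provider = GetModelProvider();
//   LangId lang_id(std::move(provider));
//   if (lang_id.is_valid()) {
//     const std::string lang = lang_id.FindLanguage(text.data(), text.size());
//   }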

std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
  StringPiece text(data, num_bytes);
  return pimpl_->FindLanguage(text);
}

void LangId::FindLanguages(const char *data, size_t num_bytes,
                           LangIdResult *result, int max_results) const {
  SAFTM_DCHECK(result) << "LangIdResult must not be null.";
  StringPiece text(data, num_bytes);
  pimpl_->FindLanguages(text, result, max_results);
}

bool LangId::is_valid() const { return pimpl_->is_valid(); }

int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }

float LangId::GetFloatProperty(const std::string &property,
                               float default_value) const {
  return pimpl_->GetProperty<float, float>(property, default_value);
}

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3