/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/lang-id.h"

#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common/algorithm.h"
#include "common/embedding-network-params-from-proto.h"
#include "common/embedding-network.pb.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/file-utils.h"
#include "common/list-of-strings.pb.h"
#include "common/memory_image/in-memory-model-data.h"
#include "common/mmap.h"
#include "common/softmax.h"
#include "common/task-context.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/lang-id-brain-interface.h"
#include "lang_id/language-identifier-features.h"
#include "lang_id/light-sentence-features.h"
#include "lang_id/light-sentence.h"
#include "lang_id/relevant-script-feature.h"
#include "util/base/logging.h"
#include "util/base/macros.h"

using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory;

namespace libtextclassifier {
namespace nlp_core {
namespace lang_id {

namespace {
// Default value for the probability threshold; see comments for
// LangId::SetProbabilityThreshold().
static const float kDefaultProbabilityThreshold = 0.50;

// Default value for min text size below which our model can't provide a
// meaningful prediction.
static const int kDefaultMinTextSizeInBytes = 20;

// Initial value for the default language for LangId::FindLanguage().  The
// default language can be changed (for an individual LangId object) using
// LangId::SetDefaultLanguage().
static const char kInitialDefaultLanguage[] = "";

// Returns total number of bytes of the words from sentence, without the ^
// (start-of-word) and $ (end-of-word) markers.  Note: "real text" means that
// this ignores whitespace and punctuation characters from the original text.
int GetRealTextSize(const LightSentence &sentence) {
  int total = 0;
  for (int i = 0; i < sentence.num_words(); ++i) {
    // Every word is expected to be wrapped as "^...$"; the two marker bytes
    // are excluded from the count below.
    TC_DCHECK(!sentence.word(i).empty());
    TC_DCHECK_EQ('^', sentence.word(i).front());
    TC_DCHECK_EQ('$', sentence.word(i).back());
    total += sentence.word(i).size() - 2;
  }
  return total;
}

}  // namespace

// Class that performs all work behind LangId.
class LangIdImpl {
 public:
  // Builds the object from the model stored in the file |filename|.  If the
  // model bytes cannot be read, logs an error and leaves the object invalid
  // (see is_valid()).
  explicit LangIdImpl(const std::string &filename) {
    // Using mmap as a fast way to read the model bytes.
    ScopedMmap scoped_mmap(filename);
    MmapHandle mmap_handle = scoped_mmap.handle();
    if (!mmap_handle.ok()) {
      TC_LOG(ERROR) << "Unable to read model bytes.";
      return;
    }

    Initialize(mmap_handle.to_stringpiece());
  }

  // Same as above, but reads the model through the file descriptor |fd|.
  explicit LangIdImpl(int fd) {
    // Using mmap as a fast way to read the model bytes.
    ScopedMmap scoped_mmap(fd);
    MmapHandle mmap_handle = scoped_mmap.handle();
    if (!mmap_handle.ok()) {
      TC_LOG(ERROR) << "Unable to read model bytes.";
      return;
    }

    Initialize(mmap_handle.to_stringpiece());
  }

  // Builds the object from the |length| model bytes that start at |ptr|.
  LangIdImpl(const char *ptr, size_t length) {
    Initialize(StringPiece(ptr, length));
  }

  // Sets up this object from the serialized model in |model_bytes|.  On any
  // failure, returns early and leaves valid_ false.
  void Initialize(StringPiece model_bytes) {
    // Will set valid_ to true only on successful initialization.
    valid_ = false;

    // Make sure all relevant features are registered:
    ContinuousBagOfNgramsFunction::RegisterClass();
    RelevantScriptFeature::RegisterClass();

    // NOTE(salcianu): code below relies on the fact that the current features
    // do not rely on data from a TaskInput.  Otherwise, one would have to use
    // the more complex model registration mechanism, which requires more code.
    InMemoryModelData model_data(model_bytes);
    TaskContext context;
    if (!model_data.GetTaskSpec(context.mutable_spec())) {
      TC_LOG(ERROR) << "Unable to get model TaskSpec";
      return;
    }

    if (!ParseNetworkParams(model_data, &context)) {
      return;
    }
    if (!ParseListOfKnownLanguages(model_data, &context)) {
      return;
    }

    network_.reset(new EmbeddingNetwork(network_params_.get()));
    if (!network_->is_valid()) {
      return;
    }

    // Model-provided settings; fall back to the compiled-in defaults when the
    // context does not specify them.
    probability_threshold_ =
        context.Get("reliability_thresh", kDefaultProbabilityThreshold);
    min_text_size_in_bytes_ =
        context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes);
    version_ = context.Get("version", 0);

    if (!lang_id_brain_interface_.Init(&context)) {
      return;
    }
    valid_ = true;
  }

  // Sets the minimum probability a prediction needs in order to be reported
  // by FindLanguage().
  void SetProbabilityThreshold(float threshold) {
    probability_threshold_ = threshold;
  }

  // Sets the language code reported when no reliable prediction is available.
  void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; }

  // Returns the language code with the highest score for |text|, or
  // default_language_ if there are no scores or if the best score is below
  // probability_threshold_.
  std::string FindLanguage(const std::string &text) const {
    std::vector<float> scores = ScoreLanguages(text);
    if (scores.empty()) {
      return default_language_;
    }

    // Softmax label with max score.
    int label = GetArgMax(scores);
    float probability = scores[label];
    if (probability < probability_threshold_) {
      return default_language_;
    }
    return GetLanguageForSoftmaxLabel(label);
  }

  // Returns one (language code, score) pair for each softmax label, in label
  // order.  The result is never empty: if there are no scores, it contains a
  // single entry for default_language_.
  std::vector<std::pair<std::string, float>> FindLanguages(
      const std::string &text) const {
    std::vector<float> scores = ScoreLanguages(text);

    std::vector<std::pair<std::string, float>> result;
    result.reserve(scores.size());
    // Labels are ints (see GetLanguageForSoftmaxLabel()); convert the size
    // once to avoid a signed/unsigned comparison in the loop condition.
    const int num_scores = static_cast<int>(scores.size());
    for (int label = 0; label < num_scores; ++label) {
      result.push_back({GetLanguageForSoftmaxLabel(label), scores[label]});
    }

    // To avoid crashing clients that always expect at least one predicted
    // language, we promised (see doc for this method) that the result always
    // contains at least one element.
    if (result.empty()) {
      // We use a tiny probability, such that any client that uses a meaningful
      // probability threshold ignores this prediction.  We don't use 0.0f, to
      // avoid crashing clients that normalize the probabilities we return here.
      result.push_back({default_language_, 0.001f});
    }
    return result;
  }

  // Returns the softmax scores for |text|, indexed by softmax label.  Returns
  // an empty vector if this object is not valid or if the real text size (see
  // GetRealTextSize()) is below min_text_size_in_bytes_.
  std::vector<float> ScoreLanguages(const std::string &text) const {
    if (!is_valid()) {
      return {};
    }

    // Create a Sentence storing the input text.
    LightSentence sentence;
    TokenizeTextForLangId(text, &sentence);

    if (GetRealTextSize(sentence) < min_text_size_in_bytes_) {
      return {};
    }

    // TODO(salcianu): reuse vector.
    std::vector<FeatureVector> features(
        lang_id_brain_interface_.NumEmbeddings());
    lang_id_brain_interface_.GetFeatures(&sentence, &features);

    // Predict language.
    EmbeddingNetwork::Vector scores;
    network_->ComputeFinalScores(features, &scores);

    return ComputeSoftmax(scores);
  }

  // Returns true if this object was successfully initialized.
  bool is_valid() const { return valid_; }

  // Returns the model version, or -1 if no version was read from a model.
  int version() const { return version_; }

 private:
  // Returns name of the (in-memory) file for the indicated TaskInput from
  // context.  Logs an error and returns an empty string if that TaskInput
  // does not have exactly one part.
  static std::string GetInMemoryFileNameForTaskInput(
      const std::string &input_name, TaskContext *context) {
    TaskInput *task_input = context->GetInput(input_name);
    if (task_input->part_size() != 1) {
      TC_LOG(ERROR) << "TaskInput " << input_name << " has "
                    << task_input->part_size() << " parts";
      return "";
    }
    return task_input->part(0).file_pattern();
  }

  // Parses the neural network parameters (field network_params_) from the
  // "language-identifier-network" TaskInput of context.  Returns true on
  // success; logs an error and returns false otherwise.
  bool ParseNetworkParams(const InMemoryModelData &model_data,
                          TaskContext *context) {
    const std::string input_name = "language-identifier-network";
    const std::string input_file_name =
        GetInMemoryFileNameForTaskInput(input_name, context);
    if (input_file_name.empty()) {
      TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
      return false;
    }
    StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    if (bytes.data() == nullptr) {
      TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
      return false;
    }
    std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
    if (!ParseProtoFromMemory(bytes, proto.get())) {
      TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
      return false;
    }
    // Ownership of the parsed proto is handed over to the params object.
    network_params_.reset(
        new EmbeddingNetworkParamsFromProto(std::move(proto)));
    if (!network_params_->is_valid()) {
      TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid";
      return false;
    }
    return true;
  }

  // Parses dictionary with known languages (i.e., field languages_) from a
  // TaskInput of context.  Note: that TaskInput should be a ListOfStrings proto
  // with a single element, the serialized form of a ListOfStrings.
  //
  // Returns true on success; logs an error and returns false otherwise.
  bool ParseListOfKnownLanguages(const InMemoryModelData &model_data,
                                 TaskContext *context) {
    const std::string input_name = "language-name-id-map";
    const std::string input_file_name =
        GetInMemoryFileNameForTaskInput(input_name, context);
    if (input_file_name.empty()) {
      TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
      return false;
    }
    StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    if (bytes.data() == nullptr) {
      TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
      return false;
    }
    ListOfStrings records;
    if (!ParseProtoFromMemory(bytes, &records)) {
      TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput "
                    << input_name;
      return false;
    }
    if (records.element_size() != 1) {
      TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name
                    << " : " << records.element_size();
      return false;
    }
    // The single record is itself a serialized ListOfStrings: the languages.
    if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) {
      TC_LOG(ERROR) << "Unable to parse dictionary with known languages";
      return false;
    }
    return true;
  }

  // Returns language code for a softmax label.  See comments for languages_
  // field.  If label is out of range, returns default_language_.
  std::string GetLanguageForSoftmaxLabel(int label) const {
    if ((label >= 0) && (label < languages_.element_size())) {
      return languages_.element(label);
    } else {
      TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
                    << languages_.element_size() << ")";
      return default_language_;
    }
  }

  LangIdBrainInterface lang_id_brain_interface_;

  // Parameters for the neural network network_ (see below).
  std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_;

  // Neural network to use for scoring.
  std::unique_ptr<EmbeddingNetwork> network_;

  // True if this object is ready to perform language predictions.
  //
  // Starts as false: the filename / file descriptor constructors return
  // before calling Initialize() when the model bytes cannot be read, and
  // is_valid() must still return a well-defined value in that case.
  bool valid_ = false;

  // Only predictions with a probability (confidence) above this threshold are
  // reported.  Otherwise, we report default_language_.
  float probability_threshold_ = kDefaultProbabilityThreshold;

  // Min size of the input text for our predictions to be meaningful.  Below
  // this threshold, the underlying model may report a wrong language and a high
  // confidence score.
  int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes;

  // Version of the model.
  int version_ = -1;

  // Known languages: softmax label i (an integer) means languages_.element(i)
  // (something like "en", "fr", "ru", etc).
  ListOfStrings languages_;

  // Language code to return in case of errors.
  std::string default_language_ = kInitialDefaultLanguage;

  TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl);
};

LangId::LangId(const std::string &filename)
    : pimpl_(new LangIdImpl(filename)) {
  if (!pimpl_->is_valid()) {
    TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                  << "on the data from " << filename
                  << "; nothing should crash, but "
                  << "accuracy will be bad.";
  }
}

LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) {
  if (!pimpl_->is_valid()) {
    TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                  << "on the data from descriptor " << fd
                  << "; nothing should crash, "
                  << "but accuracy will be bad.";
  }
}

LangId::LangId(const char *ptr, size_t length)
    : pimpl_(new LangIdImpl(ptr, length)) {
  if (!pimpl_->is_valid()) {
    TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                  << "on the memory region; nothing should crash, "
                  << "but accuracy will be bad.";
  }
}

LangId::~LangId() = default;

void LangId::SetProbabilityThreshold(float threshold) {
  pimpl_->SetProbabilityThreshold(threshold);
}

void LangId::SetDefaultLanguage(const std::string &lang) {
  pimpl_->SetDefaultLanguage(lang);
}

std::string LangId::FindLanguage(const std::string &text) const {
  return pimpl_->FindLanguage(text);
}

std::vector<std::pair<std::string, float>> LangId::FindLanguages(
    const std::string &text) const {
  return pimpl_->FindLanguages(text);
}

bool LangId::is_valid() const { return pimpl_->is_valid(); }

int LangId::version() const { return pimpl_->version(); }

}  // namespace lang_id
}  // namespace nlp_core
}  // namespace libtextclassifier