1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ 18 #define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ 19 20 // Clients who want to perform language identification should use this header. 21 // 22 // Note for lang id implementors: keep this header as linght as possible. E.g., 23 // any macro defined here (or in a transitively #included file) is a potential 24 // name conflict with our clients. 25 26 #include <memory> 27 #include <string> 28 #include <vector> 29 30 #include "util/base/macros.h" 31 32 namespace libtextclassifier { 33 namespace nlp_core { 34 namespace lang_id { 35 36 // Forward-declaration of the class that performs all underlying work. 37 class LangIdImpl; 38 39 // Class for detecting the language of a document. 40 // 41 // NOTE: this class is thread-unsafe. 42 class LangId { 43 public: 44 // Constructs a LangId object, loading an EmbeddingNetworkProto model from the 45 // indicated file. 46 // 47 // Note: we don't crash if we detect a problem at construction time (e.g., 48 // file doesn't exist, or its content is corrupted). Instead, we mark the 49 // newly-constructed object as invalid; clients can invoke FindLanguage() on 50 // an invalid object: nothing crashes, but accuracy will be bad. 51 explicit LangId(const std::string &filename); 52 53 // Same as above but uses a file descriptor. 54 explicit LangId(int fd); 55 56 // Same as above but uses already mapped memory region 57 explicit LangId(const char *ptr, size_t length); 58 59 virtual ~LangId(); 60 61 // Sets probability threshold for predictions. If our likeliest prediction is 62 // below this threshold, we report the default language (see 63 // SetDefaultLanguage()). Othewise, we report the likelist language. 64 // 65 // By default (if this method is not called) we use the probability threshold 66 // stored in the model, as the task parameter "reliability_thresh". If that 67 // task parameter is not specified, we use 0.5. A client can use this method 68 // to get a different precision / recall trade-off. The higher the threshold, 69 // the higher the precision and lower the recall rate. 70 void SetProbabilityThreshold(float threshold); 71 72 // Sets default language to report if errors prevent running the real 73 // inference code or if prediction confidence is too small. 74 void SetDefaultLanguage(const std::string &lang); 75 76 // Returns language code for the most likely language that text is written in. 77 // Note: if this LangId object is not valid (see 78 // is_valid()), this method returns the default language specified via 79 // SetDefaultLanguage() or (if that method was never invoked), the empty 80 // std::string. 81 std::string FindLanguage(const std::string &text) const; 82 83 // Returns a vector of language codes along with the probability for each 84 // language. The result contains at least one element. The sum of 85 // probabilities may be less than 1.0. 86 std::vector<std::pair<std::string, float>> FindLanguages( 87 const std::string &text) const; 88 89 // Returns true if this object has been correctly initialized and is ready to 90 // perform predictions. For more info, see doc for LangId 91 // constructor above. 92 bool is_valid() const; 93 94 // Returns version number for the model. 95 int version() const; 96 97 private: 98 // Returns a vector of probabilities of languages of the text. 99 std::vector<float> ScoreLanguages(const std::string &text) const; 100 101 // Pimpl ("pointer to implementation") pattern, to hide all internals from our 102 // clients. 103 std::unique_ptr<LangIdImpl> pimpl_; 104 105 TC_DISALLOW_COPY_AND_ASSIGN(LangId); 106 }; 107 108 } // namespace lang_id 109 } // namespace nlp_core 110 } // namespace libtextclassifier 111 112 #endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ 113