1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 18 #define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 19 20 #include <cstddef> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "lang_id/common/fel/task-context.h" 26 #include "lang_id/common/file/mmap.h" 27 #include "lang_id/common/flatbuffers/model_generated.h" 28 #include "lang_id/common/lite_strings/stringpiece.h" 29 #include "lang_id/model-provider.h" 30 31 namespace libtextclassifier3 { 32 namespace mobile { 33 namespace lang_id { 34 35 // ModelProvider for LangId, based on a SAFT model in flatbuffer format. 36 class ModelProviderFromFlatbuffer : public ModelProvider { 37 public: 38 // Constructs a model provider based on a flatbuffer-format SAFT model from 39 // |filename|. 40 explicit ModelProviderFromFlatbuffer(const string &filename); 41 42 // Constructs a model provider based on a flatbuffer-format SAFT model from 43 // file descriptor |fd|. 44 explicit ModelProviderFromFlatbuffer(int fd); 45 46 // Constructs a model provider from a flatbuffer-format SAFT model the bytes 47 // of which are already in RAM (size bytes starting from address data). 48 // Useful if you "transport" these bytes otherwise than via a normal file 49 // (e.g., if you embed them somehow in your binary). 50 // 51 // IMPORTANT: |data| should be alive during the lifetime of the 52 // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure 53 // for data that's statically embedded in your binary, but more complex in 54 // other cases. To avoid overhead (e.g., heap allocation), this method does 55 // not make a private copy of the data. In general, the ownership of the 56 // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a 57 // LangId object (which doesn't pass it further); hence, one needs to make 58 // sure |data| is alive during the lifetime of that LangId object. ModelProviderFromFlatbuffer(const char * data,std::size_t size)59 ModelProviderFromFlatbuffer(const char *data, std::size_t size) { 60 StringPiece model_bytes(data, size); 61 Initialize(model_bytes); 62 } 63 64 ~ModelProviderFromFlatbuffer() override = default; 65 GetTaskContext()66 const TaskContext *GetTaskContext() const override { 67 return &context_; 68 } 69 GetNnParams()70 const EmbeddingNetworkParams *GetNnParams() const override { 71 return nn_params_.get(); 72 } 73 GetLanguages()74 std::vector<string> GetLanguages() const override { 75 return languages_; 76 } 77 78 private: 79 // Initializes the fields of this class based on the flatbuffer from 80 // |model_bytes|. These bytes are supposed to be the representation of a 81 // Model flatbuffer and should be alive during the lifetime of this object. 82 void Initialize(StringPiece model_bytes); 83 84 // Initializes nn_params_ based on model_. 85 bool InitNetworkParams(); 86 87 // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped 88 // during the lifetime of this object, such that references inside the Model 89 // flatbuffer from those bytes remain valid. 90 const std::unique_ptr<ScopedMmap> scoped_mmap_; 91 92 // Pointer to the flatbuffer from 93 // 94 // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_ 95 // (for safety considerations, see comment for that field), or 96 // 97 // (b) [of (data, size) constructor was used:] the bytes from [data, 98 // data+size). Please read carefully the doc for that constructor. 99 const saft_fbs::Model *model_; 100 101 // Context returned by this model provider. We set its parameters based on 102 // model_, at construction time. 103 TaskContext context_; 104 105 // List of supported languages, see GetLanguages(). We expect this list to be 106 // specified by the ModelParameter named "supported_languages" from model_. 107 std::vector<string> languages_; 108 109 // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput 110 // named "language-identifier-network" from model_. 111 std::unique_ptr<EmbeddingNetworkParams> nn_params_; 112 }; 113 114 } // namespace lang_id 115 } // namespace mobile 116 } // namespace nlp_saft 117 118 #endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 119