1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 18 #define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 19 20 #include <cstddef> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "lang_id/common/fel/task-context.h" 26 #include "lang_id/common/file/mmap.h" 27 #include "lang_id/common/flatbuffers/model_generated.h" 28 #include "lang_id/common/lite_strings/stringpiece.h" 29 #include "lang_id/model-provider.h" 30 31 namespace libtextclassifier3 { 32 namespace mobile { 33 namespace lang_id { 34 35 // ModelProvider for LangId, based on a SAFT model in flatbuffer format. 36 class ModelProviderFromFlatbuffer : public ModelProvider { 37 public: 38 // Constructs a model provider based on a flatbuffer-format SAFT model from 39 // |filename|. 40 explicit ModelProviderFromFlatbuffer(const std::string &filename); 41 42 // Constructs a model provider based on a flatbuffer-format SAFT model from 43 // file descriptor |fd|. 44 explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd); 45 46 // Constructs a model provider based on a flatbuffer-format SAFT model from 47 // file descriptor |fd|. 48 ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd, std::size_t offset, 49 std::size_t size); 50 51 // Constructs a model provider from a flatbuffer-format SAFT model the bytes 52 // of which are already in RAM (size bytes starting from address data). 53 // Useful if you "transport" these bytes otherwise than via a normal file 54 // (e.g., if you embed them somehow in your binary). 55 // 56 // IMPORTANT: |data| should be alive during the lifetime of the 57 // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure 58 // for data that's statically embedded in your binary, but more complex in 59 // other cases. To avoid overhead (e.g., heap allocation), this method does 60 // not make a private copy of the data. In general, the ownership of the 61 // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a 62 // LangId object (which doesn't pass it further); hence, one needs to make 63 // sure |data| is alive during the lifetime of that LangId object. ModelProviderFromFlatbuffer(const char * data,std::size_t size)64 ModelProviderFromFlatbuffer(const char *data, std::size_t size) { 65 StringPiece model_bytes(data, size); 66 Initialize(model_bytes); 67 } 68 69 ~ModelProviderFromFlatbuffer() override = default; 70 GetTaskContext()71 const TaskContext *GetTaskContext() const override { 72 return &context_; 73 } 74 GetNnParams()75 const EmbeddingNetworkParams *GetNnParams() const override { 76 return nn_params_.get(); 77 } 78 GetLanguages()79 std::vector<std::string> GetLanguages() const override { return languages_; } 80 81 private: 82 // Initializes the fields of this class based on the flatbuffer from 83 // |model_bytes|. These bytes are supposed to be the representation of a 84 // Model flatbuffer and should be alive during the lifetime of this object. 85 void Initialize(StringPiece model_bytes); 86 87 // Initializes nn_params_ based on model_. 88 bool InitNetworkParams(); 89 90 // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped 91 // during the lifetime of this object, such that references inside the Model 92 // flatbuffer from those bytes remain valid. 93 const std::unique_ptr<ScopedMmap> scoped_mmap_; 94 95 // Pointer to the flatbuffer from 96 // 97 // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_ 98 // (for safety considerations, see comment for that field), or 99 // 100 // (b) [of (data, size) constructor was used:] the bytes from [data, 101 // data+size). Please read carefully the doc for that constructor. 102 const saft_fbs::Model *model_; 103 104 // Context returned by this model provider. We set its parameters based on 105 // model_, at construction time. 106 TaskContext context_; 107 108 // List of supported languages, see GetLanguages(). We expect this list to be 109 // specified by the ModelParameter named "supported_languages" from model_. 110 std::vector<std::string> languages_; 111 112 // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput 113 // named "language-identifier-network" from model_. 114 std::unique_ptr<EmbeddingNetworkParams> nn_params_; 115 }; 116 117 } // namespace lang_id 118 } // namespace mobile 119 } // namespace nlp_saft 120 121 #endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_ 122