• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/fb_model/model-provider-from-fb.h"
18 
19 #include <string>
20 
21 #include "lang_id/common/file/file-utils.h"
22 #include "lang_id/common/file/mmap.h"
23 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
24 #include "lang_id/common/flatbuffers/model-utils.h"
25 #include "lang_id/common/lite_strings/str-split.h"
26 
27 namespace libtextclassifier3 {
28 namespace mobile {
29 namespace lang_id {
30 
ModelProviderFromFlatbuffer(const std::string & filename)31 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
32     const std::string &filename)
33 
34     // Using mmap as a fast way to read the model bytes.  As the file is
35     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
36     // stay alive for the entire lifetime of this object.
37     : scoped_mmap_(new ScopedMmap(filename)) {
38   Initialize(scoped_mmap_->handle().to_stringpiece());
39 }
40 
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd)41 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
42     FileDescriptorOrHandle fd)
43 
44     // Using mmap as a fast way to read the model bytes.  As the file is
45     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
46     // stay alive for the entire lifetime of this object.
47     : scoped_mmap_(new ScopedMmap(fd)) {
48   Initialize(scoped_mmap_->handle().to_stringpiece());
49 }
50 
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd,std::size_t offset,std::size_t size)51 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
52     FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
53 
54     // Using mmap as a fast way to read the model bytes.  As the file is
55     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
56     // stay alive for the entire lifetime of this object.
57     : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
58   Initialize(scoped_mmap_->handle().to_stringpiece());
59 }
60 
Initialize(StringPiece model_bytes)61 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
62   // Note: valid_ was initialized to false.  In the code below, we set valid_ to
63   // true only if all initialization steps completed successfully.  Otherwise,
64   // we return early, leaving valid_ to its default value false.
65   model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
66   if (model_ == nullptr) {
67     SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
68     return;
69   }
70 
71   // Initialize context_ parameters.
72   if (!saft_fbs::FillParameters(*model_, &context_)) {
73     // FillParameters already performs error logging.
74     return;
75   }
76 
77   // Init languages_.
78   const std::string known_languages_str =
79       context_.Get("supported_languages", "");
80   for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
81     languages_.emplace_back(sp);
82   }
83   if (languages_.empty()) {
84     SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
85     return;
86   }
87 
88   // Init nn_params_.
89   if (!InitNetworkParams()) {
90     // InitNetworkParams already performs error logging.
91     return;
92   }
93 
94   // Everything looks fine.
95   valid_ = true;
96 }
97 
InitNetworkParams()98 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
99   const std::string kInputName = "language-identifier-network";
100   StringPiece bytes =
101       saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
102   if ((bytes.data() == nullptr) || bytes.empty()) {
103     SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
104     return false;
105   }
106   std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
107       new EmbeddingNetworkParamsFromFlatbuffer(bytes));
108   if (!nn_params_from_fb->is_valid()) {
109     SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
110     return false;
111   }
112   nn_params_ = std::move(nn_params_from_fb);
113   return true;
114 }
115 
116 }  // namespace lang_id
117 }  // namespace mobile
118 }  // namespace libtextclassifier3
119