1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/fb_model/model-provider-from-fb.h"
18
19 #include <string>
20
21 #include "lang_id/common/file/file-utils.h"
22 #include "lang_id/common/file/mmap.h"
23 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
24 #include "lang_id/common/flatbuffers/model-utils.h"
25 #include "lang_id/common/lite_strings/str-split.h"
26
27 namespace libtextclassifier3 {
28 namespace mobile {
29 namespace lang_id {
30
ModelProviderFromFlatbuffer(const std::string & filename)31 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
32 const std::string &filename)
33
34 // Using mmap as a fast way to read the model bytes. As the file is
35 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
36 // stay alive for the entire lifetime of this object.
37 : scoped_mmap_(new ScopedMmap(filename)) {
38 Initialize(scoped_mmap_->handle().to_stringpiece());
39 }
40
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd)41 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
42 FileDescriptorOrHandle fd)
43
44 // Using mmap as a fast way to read the model bytes. As the file is
45 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
46 // stay alive for the entire lifetime of this object.
47 : scoped_mmap_(new ScopedMmap(fd)) {
48 Initialize(scoped_mmap_->handle().to_stringpiece());
49 }
50
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd,std::size_t offset,std::size_t size)51 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
52 FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
53
54 // Using mmap as a fast way to read the model bytes. As the file is
55 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
56 // stay alive for the entire lifetime of this object.
57 : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
58 Initialize(scoped_mmap_->handle().to_stringpiece());
59 }
60
Initialize(StringPiece model_bytes)61 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
62 // Note: valid_ was initialized to false. In the code below, we set valid_ to
63 // true only if all initialization steps completed successfully. Otherwise,
64 // we return early, leaving valid_ to its default value false.
65 model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
66 if (model_ == nullptr) {
67 SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
68 return;
69 }
70
71 // Initialize context_ parameters.
72 if (!saft_fbs::FillParameters(*model_, &context_)) {
73 // FillParameters already performs error logging.
74 return;
75 }
76
77 // Init languages_.
78 const std::string known_languages_str =
79 context_.Get("supported_languages", "");
80 for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
81 languages_.emplace_back(sp);
82 }
83 if (languages_.empty()) {
84 SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
85 return;
86 }
87
88 // Init nn_params_.
89 if (!InitNetworkParams()) {
90 // InitNetworkParams already performs error logging.
91 return;
92 }
93
94 // Everything looks fine.
95 valid_ = true;
96 }
97
InitNetworkParams()98 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
99 const std::string kInputName = "language-identifier-network";
100 StringPiece bytes =
101 saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
102 if ((bytes.data() == nullptr) || bytes.empty()) {
103 SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
104 return false;
105 }
106 std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
107 new EmbeddingNetworkParamsFromFlatbuffer(bytes));
108 if (!nn_params_from_fb->is_valid()) {
109 SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
110 return false;
111 }
112 nn_params_ = std::move(nn_params_from_fb);
113 return true;
114 }
115
116 } // namespace lang_id
117 } // namespace mobile
}  // namespace libtextclassifier3
119