1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/vocab/vocab-level-table.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include "annotator/model_generated.h"
#include "utils/base/endian.h"
#include "utils/container/bit-vector.h"
#include "utils/optional.h"
#include "marisa/trie.h"
27
28 namespace libtextclassifier3 {
29
Create(const VocabModel * model)30 std::unique_ptr<VocabLevelTable> VocabLevelTable::Create(
31 const VocabModel* model) {
32 if (!LittleEndian::IsLittleEndian()) {
33 // TODO(tonymak) Consider making this work on a big endian device.
34 TC3_LOG(ERROR)
35 << "VocabLevelTable is only working on a little endian device.";
36 return nullptr;
37 }
38 const flatbuffers::Vector<uint8_t>* trie_data = model->vocab_trie();
39 if (trie_data == nullptr) {
40 TC3_LOG(ERROR) << "vocab_trie is missing from the model file.";
41 return nullptr;
42 }
43 std::unique_ptr<marisa::Trie> vocab_trie(new marisa::Trie);
44 vocab_trie->map(trie_data->data(), trie_data->size());
45
46 return std::unique_ptr<VocabLevelTable>(new VocabLevelTable(
47 model, std::move(vocab_trie), BitVector(model->beginner_level()),
48 BitVector(model->do_not_trigger_in_upper_case())));
49 }
50
VocabLevelTable(const VocabModel * model,std::unique_ptr<marisa::Trie> vocab_trie,const BitVector beginner_level,const BitVector do_not_trigger_in_upper_case)51 VocabLevelTable::VocabLevelTable(const VocabModel* model,
52 std::unique_ptr<marisa::Trie> vocab_trie,
53 const BitVector beginner_level,
54 const BitVector do_not_trigger_in_upper_case)
55 : model_(model),
56 vocab_trie_(std::move(vocab_trie)),
57 beginner_level_(beginner_level),
58 do_not_trigger_in_upper_case_(do_not_trigger_in_upper_case) {}
59
Lookup(const std::string & vocab) const60 Optional<LookupResult> VocabLevelTable::Lookup(const std::string& vocab) const {
61 marisa::Agent agent;
62 agent.set_query(vocab.data(), vocab.size());
63 if (vocab_trie_->lookup(agent)) {
64 const int vector_idx = agent.key().id();
65 return Optional<LookupResult>({beginner_level_[vector_idx],
66 do_not_trigger_in_upper_case_[vector_idx]});
67 }
68 return Optional<LookupResult>();
69 }
70 } // namespace libtextclassifier3
71