1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tokenizer.h"

#include <algorithm>
#include <utility>

#include "util/base/logging.h"
#include "util/strings/utf8.h"
23
24 namespace libtextclassifier2 {
25
Tokenizer(const std::vector<const TokenizationCodepointRange * > & codepoint_ranges,bool split_on_script_change)26 Tokenizer::Tokenizer(
27 const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
28 bool split_on_script_change)
29 : split_on_script_change_(split_on_script_change) {
30 for (const TokenizationCodepointRange* range : codepoint_ranges) {
31 codepoint_ranges_.emplace_back(range->UnPack());
32 }
33
34 std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
35 [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
36 const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
37 return a->start < b->start;
38 });
39 }
40
FindTokenizationRange(int codepoint) const41 const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
42 int codepoint) const {
43 auto it = std::lower_bound(
44 codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
45 [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
46 int codepoint) {
47 // This function compares range with the codepoint for the purpose of
48 // finding the first greater or equal range. Because of the use of
49 // std::lower_bound it needs to return true when range < codepoint;
50 // the first time it will return false the lower bound is found and
51 // returned.
52 //
53 // It might seem weird that the condition is range.end <= codepoint
54 // here but when codepoint == range.end it means it's actually just
55 // outside of the range, thus the range is less than the codepoint.
56 return range->end <= codepoint;
57 });
58 if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
59 (*it)->end > codepoint) {
60 return it->get();
61 } else {
62 return nullptr;
63 }
64 }
65
GetScriptAndRole(char32 codepoint,TokenizationCodepointRange_::Role * role,int * script) const66 void Tokenizer::GetScriptAndRole(char32 codepoint,
67 TokenizationCodepointRange_::Role* role,
68 int* script) const {
69 const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
70 if (range) {
71 *role = range->role;
72 *script = range->script_id;
73 } else {
74 *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
75 *script = kUnknownScript;
76 }
77 }
78
Tokenize(const std::string & text) const79 std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
80 UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
81 return Tokenize(text_unicode);
82 }
83
Tokenize(const UnicodeText & text_unicode) const84 std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
85 std::vector<Token> result;
86 Token new_token("", 0, 0);
87 int codepoint_index = 0;
88
89 int last_script = kInvalidScript;
90 for (auto it = text_unicode.begin(); it != text_unicode.end();
91 ++it, ++codepoint_index) {
92 TokenizationCodepointRange_::Role role;
93 int script;
94 GetScriptAndRole(*it, &role, &script);
95
96 if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
97 (split_on_script_change_ && last_script != kInvalidScript &&
98 last_script != script)) {
99 if (!new_token.value.empty()) {
100 result.push_back(new_token);
101 }
102 new_token = Token("", codepoint_index, codepoint_index);
103 }
104 if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
105 new_token.value += std::string(
106 it.utf8_data(),
107 it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
108 ++new_token.end;
109 }
110 if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
111 if (!new_token.value.empty()) {
112 result.push_back(new_token);
113 }
114 new_token = Token("", codepoint_index + 1, codepoint_index + 1);
115 }
116
117 last_script = script;
118 }
119 if (!new_token.value.empty()) {
120 result.push_back(new_token);
121 }
122
123 return result;
124 }
125
126 } // namespace libtextclassifier2
127