1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/bert_tokenizer.h"
18
19 #include "utils/test-data-test-utils.h"
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22
23 namespace libtextclassifier3 {
24
25 using ::testing::ElementsAre;
26
27 namespace {
28 constexpr char kTestVocabPath[] = "annotator/pod_ner/test_data/vocab.txt";
29
AssertTokenizerResults(std::unique_ptr<BertTokenizer> tokenizer)30 void AssertTokenizerResults(std::unique_ptr<BertTokenizer> tokenizer) {
31 auto results = tokenizer->Tokenize("i'm question");
32
33 EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question"));
34 }
35
// Verifies that a tokenizer built from an in-memory copy of the vocab file
// tokenizes correctly.
TEST(BertTokenizerTest, TestTokenizerCreationFromBuffer) {
  const std::string vocab_buffer = GetTestFileContent(kTestVocabPath);

  AssertTokenizerResults(absl::make_unique<BertTokenizer>(
      vocab_buffer.data(), vocab_buffer.size()));
}
44
// Verifies that a tokenizer built directly from the vocab file path
// tokenizes correctly.
TEST(BertTokenizerTest, TestTokenizerCreationFromFile) {
  AssertTokenizerResults(
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)));
}
51
// Verifies that a tokenizer built from an in-memory vocabulary vector
// tokenizes correctly.
TEST(BertTokenizerTest, TestTokenizerCreationFromVector) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};

  AssertTokenizerResults(absl::make_unique<BertTokenizer>(vocab));
}
62
// A word absent from the vocab but decomposable into known wordpieces should
// be split into a leading piece plus "##"-prefixed continuation pieces.
TEST(BertTokenizerTest, TestTokenizerMultipleRows) {
  BertTokenizer tokenizer(GetTestDataPath(kTestVocabPath));

  const auto results = tokenizer.Tokenize("i'm questionansweraskask");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question", "##ans",
                                            "##wer", "##ask", "##ask"));
}
72
// TokenizeIntoWordpieces should report, besides the subwords themselves,
// per-wordpiece begin/end byte offsets into the input and the number of
// wordpieces each original token produced (row_lengths).
TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) {
  BertTokenizer tokenizer(GetTestDataPath(kTestVocabPath));

  const auto results =
      tokenizer.TokenizeIntoWordpieces("i'm questionansweraskask");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question", "##ans",
                                            "##wer", "##ask", "##ask"));
  EXPECT_THAT(results.wp_begin_offset, ElementsAre(0, 1, 2, 4, 12, 15, 18, 21));
  EXPECT_THAT(results.wp_end_offset, ElementsAre(1, 2, 3, 12, 15, 18, 21, 24));
  // "i", "'", "m" each map to one wordpiece; the last token maps to five.
  EXPECT_THAT(results.row_lengths, ElementsAre(1, 1, 1, 5));
}
85
// A long run of a non-ASCII codepoint that is not in the vocab collapses to a
// single [UNK] wordpiece whose end offset counts codepoints (100), not bytes.
TEST(BertTokenizerTest, TestTokenizeIntoWordpiecesLongNonAscii) {
  BertTokenizer tokenizer(GetTestDataPath(kTestVocabPath));

  // Build a token of 100 copies of "ń" (a two-byte UTF-8 codepoint).
  constexpr int kNumCodepoints = 100;
  std::string token;
  token.reserve(2 * kNumCodepoints);
  for (int i = 0; i < kNumCodepoints; ++i) {
    token.append("ń");
  }

  const auto results = tokenizer.TokenizeIntoWordpieces(token);

  EXPECT_THAT(results.subwords, ElementsAre("[UNK]"));
  EXPECT_THAT(results.wp_begin_offset, ElementsAre(0));
  EXPECT_THAT(results.wp_end_offset, ElementsAre(kNumCodepoints));
  EXPECT_THAT(results.row_lengths, ElementsAre(1));
}
101
// A token that cannot be decomposed with the (tiny) vocabulary is replaced by
// the default unknown-token marker.
TEST(BertTokenizerTest, TestTokenizerUnknownTokens) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};
  BertTokenizer tokenizer(vocab);

  const auto results = tokenizer.Tokenize("i'm questionansweraskask");

  EXPECT_THAT(results.subwords,
              ElementsAre("i", "'", "m", kDefaultUnknownToken));
}
115
// LookupId fails for out-of-vocabulary words and otherwise returns each
// word's position in the vocabulary.
TEST(BertTokenizerTest, TestLookupId) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};
  BertTokenizer tokenizer(vocab);

  int id;
  ASSERT_FALSE(tokenizer.LookupId("iDontExist", &id));

  // Each vocab entry's id is its index in the construction vector.
  for (int expected_id = 0; expected_id < static_cast<int>(vocab.size());
       ++expected_id) {
    ASSERT_TRUE(tokenizer.LookupId(vocab[expected_id], &id));
    ASSERT_EQ(id, expected_id);
  }
}
136
// LookupWord fails for ids beyond the vocabulary size and otherwise maps each
// id back to the word at that index.
TEST(BertTokenizerTest, TestLookupWord) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};
  BertTokenizer tokenizer(vocab);

  absl::string_view word;
  ASSERT_FALSE(tokenizer.LookupWord(6, &word));

  for (int id = 0; id < static_cast<int>(vocab.size()); ++id) {
    ASSERT_TRUE(tokenizer.LookupWord(id, &word));
    ASSERT_EQ(word, vocab[id]);
  }
}
157
// Contains reports membership through its out-parameter: false for an
// out-of-vocabulary word, true for every word the tokenizer was built with.
TEST(BertTokenizerTest, TestContains) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};
  BertTokenizer tokenizer(vocab);

  bool contains;
  tokenizer.Contains("iDontExist", &contains);
  ASSERT_FALSE(contains);

  for (const std::string& word : vocab) {
    tokenizer.Contains(word, &contains);
    ASSERT_TRUE(contains);
  }
}
179
// VocabularySize reflects the number of entries the tokenizer was built with.
// NOTE(review): the test name has a typo ("TestLVocabularySize"); kept as-is
// because renaming would silently break any --gtest_filter patterns.
TEST(BertTokenizerTest, TestLVocabularySize) {
  const std::vector<std::string> vocab = {"i", "'", "m", "question"};
  BertTokenizer tokenizer(vocab);

  ASSERT_EQ(tokenizer.VocabularySize(), 4);
}
190
// PreTokenize splits on whitespace and isolates punctuation into its own
// tokens, before any wordpiece lookup happens.
TEST(BertTokenizerTest, SimpleEnglishWithPunctuation) {
  const std::vector<std::string> tokens =
      BertTokenizer::PreTokenize("I am fine, thanks!");

  EXPECT_THAT(tokens, ElementsAre("I", "am", "fine", ",", "thanks", "!"));
}
199 } // namespace
200 } // namespace libtextclassifier3
201