/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/bert_tokenizer.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/test-data-test-utils.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {

using ::testing::ElementsAre;

namespace {
constexpr char kTestVocabPath[] = "annotator/pod_ner/test_data/vocab.txt";

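// Shared helper: asserts that the given tokenizer splits "i'm question" into
// the expected wordpieces.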
void AssertTokenizerResults(std::unique_ptr<BertTokenizer> tokenizer) {
  auto results = tokenizer->Tokenize("i'm question");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question"));
}

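// The tokenizer can be constructed from an in-memory buffer, a file path, or
// an explicit vocabulary vector; the next three tests expect identical
// behavior from all three.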
TEST(BertTokenizerTest, TestTokenizerCreationFromBuffer) {
  std::string buffer = GetTestFileContent(kTestVocabPath);

  auto tokenizer =
      absl::make_unique<BertTokenizer>(buffer.data(), buffer.size());

  AssertTokenizerResults(std::move(tokenizer));
}

TEST(BertTokenizerTest, TestTokenizerCreationFromFile) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  AssertTokenizerResults(std::move(tokenizer));
}

TEST(BertTokenizerTest, TestTokenizerCreationFromVector) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  AssertTokenizerResults(std::move(tokenizer));
}

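// Wordpieces that continue a token rather than starting one carry the
// standard BERT "##" continuation prefix.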
TEST(BertTokenizerTest, TestTokenizerMultipleRows) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  auto results = tokenizer->Tokenize("i'm questionansweraskask");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question", "##ans",
                                            "##wer", "##ask", "##ask"));
}

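// TokenizeIntoWordpieces additionally reports begin/end offsets into the
// input for each wordpiece (in Unicode codepoints; see the non-ASCII test
// below) and, via row_lengths, the number of wordpieces each pre-tokenized
// token produced.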
TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  auto results = tokenizer->TokenizeIntoWordpieces("i'm questionansweraskask");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question", "##ans",
                                            "##wer", "##ask", "##ask"));
  EXPECT_THAT(results.wp_begin_offset, ElementsAre(0, 1, 2, 4, 12, 15, 18, 21));
  EXPECT_THAT(results.wp_end_offset, ElementsAre(1, 2, 3, 12, 15, 18, 21, 24));
  EXPECT_THAT(results.row_lengths, ElementsAre(1, 1, 1, 5));
}

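// A token of 100 two-byte "ń" characters (presumably exceeding the
// tokenizer's maximum token length) collapses to a single [UNK]; its end
// offset of 100 confirms that offsets count codepoints, not bytes.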
TEST(BertTokenizerTest, TestTokenizeIntoWordpiecesLongNonAscii) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  std::string token;
  for (int i = 0; i < 100; ++i) {
    token += "ń";
  }
  auto results = tokenizer->TokenizeIntoWordpieces(token);

  EXPECT_THAT(results.subwords, ElementsAre("[UNK]"));
  EXPECT_THAT(results.wp_begin_offset, ElementsAre(0));
  EXPECT_THAT(results.wp_end_offset, ElementsAre(100));
  EXPECT_THAT(results.row_lengths, ElementsAre(1));
}

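// A token that cannot be fully decomposed into vocabulary wordpieces falls
// back to kDefaultUnknownToken as a whole, even though its prefix "question"
// is in the vocabulary.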
TEST(BertTokenizerTest, TestTokenizerUnknownTokens) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  auto results = tokenizer->Tokenize("i'm questionansweraskask");

  EXPECT_THAT(results.subwords,
              ElementsAre("i", "'", "m", kDefaultUnknownToken));
}

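// LookupId maps a wordpiece to its zero-based position in the vocabulary and
// returns false for out-of-vocabulary strings.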
TEST(BertTokenizerTest, TestLookupId) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  int i;
  ASSERT_FALSE(tokenizer->LookupId("iDontExist", &i));

  ASSERT_TRUE(tokenizer->LookupId("i", &i));
  ASSERT_EQ(i, 0);
  ASSERT_TRUE(tokenizer->LookupId("'", &i));
  ASSERT_EQ(i, 1);
  ASSERT_TRUE(tokenizer->LookupId("m", &i));
  ASSERT_EQ(i, 2);
  ASSERT_TRUE(tokenizer->LookupId("question", &i));
  ASSERT_EQ(i, 3);
}

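// LookupWord is the inverse of LookupId: it maps an id back to its wordpiece
// and returns false for ids outside the vocabulary.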
TEST(BertTokenizerTest, TestLookupWord) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  absl::string_view result;
  ASSERT_FALSE(tokenizer->LookupWord(6, &result));

  ASSERT_TRUE(tokenizer->LookupWord(0, &result));
  ASSERT_EQ(result, "i");
  ASSERT_TRUE(tokenizer->LookupWord(1, &result));
  ASSERT_EQ(result, "'");
  ASSERT_TRUE(tokenizer->LookupWord(2, &result));
  ASSERT_EQ(result, "m");
  ASSERT_TRUE(tokenizer->LookupWord(3, &result));
  ASSERT_EQ(result, "question");
}

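// Contains reports vocabulary membership through its bool output parameter.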
TEST(BertTokenizerTest, TestContains) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  bool result;
  tokenizer->Contains("iDontExist", &result);
  ASSERT_FALSE(result);

  tokenizer->Contains("i", &result);
  ASSERT_TRUE(result);
  tokenizer->Contains("'", &result);
  ASSERT_TRUE(result);
  tokenizer->Contains("m", &result);
  ASSERT_TRUE(result);
  tokenizer->Contains("question", &result);
  ASSERT_TRUE(result);
}

TEST(BertTokenizerTest, TestVocabularySize) {
  std::vector<std::string> vocab;
  vocab.emplace_back("i");
  vocab.emplace_back("'");
  vocab.emplace_back("m");
  vocab.emplace_back("question");
  auto tokenizer = absl::make_unique<BertTokenizer>(vocab);

  ASSERT_EQ(tokenizer->VocabularySize(), 4);
}

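// PreTokenize is static and vocabulary-free: it only splits the input on
// whitespace and punctuation.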
TEST(BertTokenizerTest, SimpleEnglishWithPunctuation) {
  absl::string_view input = "I am fine, thanks!";

  std::vector<std::string> tokens = BertTokenizer::PreTokenize(input);

  EXPECT_THAT(tokens, testing::ElementsAreArray(
                          {"I", "am", "fine", ",", "thanks", "!"}));
}
}  // namespace
}  // namespace libtextclassifier3