• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/sentencepiece/encoder.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "utils/base/integral_types.h"
23 #include "utils/container/sorted-strings-table.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 
27 namespace libtextclassifier3 {
28 namespace {
29 
30 using testing::ElementsAre;
31 
// Checks that Encode picks the best-scoring segmentation, and that the encoder
// observes later changes made to the score array it was constructed with.
//
// Piece table (index: piece): 0: "hell", 1: "hello", 2: "o", 3: "there".
// The assertions below imply that, with the default constructor arguments, the
// output is wrapped in start code 0 and end code 1, and piece i is emitted as
// id i + 2 (hell=2, hello=3, o=4, there=5).
TEST(EncoderTest, SimpleTokenization) {
  // Four NUL-separated pieces; `offsets` holds the byte offset of each one.
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  // 18 bytes = offset 13 of "there" + its 5 characters (trailing NUL excluded).
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  const Encoder encoder(pieces.get(), /*num_pieces=*/4, scores);

  // Encodes `text`, expects Encode to report success, and returns the ids.
  const auto encode = [&encoder](const char* text) {
    std::vector<int> ids;
    EXPECT_TRUE(encoder.Encode(text, &ids));
    return ids;
  };

  // "hello" (-1.0) scores better than "hell" + "o" (-10.5), so it stays whole.
  EXPECT_THAT(encode("hellothere"), ElementsAre(0, 3, 5, 1));

  // Push the score of "hello" far below "hell" + "o". The encoder was handed
  // `scores` by pointer and this expectation relies on it seeing the update:
  // "hello" now gets split into "hell" + "o".
  scores[1] = -100.0;
  EXPECT_THAT(encode("hellothere"), ElementsAre(0, 2, 4, 5, 1));
}
56 
// Exercises boundary inputs for an encoder built with the default constructor
// arguments (no unknown code supplied). It covers adjacent overlapping pieces,
// empty input, and text containing a character no piece covers.
//
// Piece table (index: piece): 0: "hell", 1: "hello", 2: "o", 3: "there".
// Per the assertions, piece i is emitted as id i + 2, wrapped in start code 0
// and end code 1.
TEST(EncoderTest, HandlesEdgeCases) {
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  const Encoder encoder(pieces.get(), /*num_pieces=*/4, scores);

  // Encodes `text`, expects Encode to report success, and returns the ids.
  const auto encode = [&encoder](const char* text) {
    std::vector<int> ids;
    EXPECT_TRUE(encoder.Encode(text, &ids));
    return ids;
  };

  // The shorter piece first, followed by the longer piece that extends it.
  EXPECT_THAT(encode("hellhello"), ElementsAre(0, 2, 3, 1));
  // The same two pieces in the opposite order.
  EXPECT_THAT(encode("hellohell"), ElementsAre(0, 3, 2, 1));
  // Empty input still produces the start and end codes.
  EXPECT_THAT(encode(""), ElementsAre(0, 1));
  // "a" matches no piece. Without an unknown code, the expected output is only
  // the start and end codes, while Encode still reports success.
  EXPECT_THAT(encode("hellathere"), ElementsAre(0, 1));
}
86 
// Same inputs as the edge-case test, but the encoder is given an explicit
// unknown code, so characters no piece covers are emitted rather than
// discarding the segmentation.
//
// Piece table (index: piece): 0: "hell", 1: "hello", 2: "o", 3: "there".
// With encoding_offset=3, the assertions show piece i emitted as id i + 3
// (hell=3, hello=4, there=6). Uncovered text is emitted as unknown_code=2.
TEST(EncoderTest, HandlesOutOfDictionary) {
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  const Encoder encoder(pieces.get(),
                        /*num_pieces=*/4, scores,
                        /*start_code=*/0, /*end_code=*/1,
                        /*encoding_offset=*/3, /*unknown_code=*/2,
                        /*unknown_score=*/-100.0);

  // Encodes `text`, expects Encode to report success, and returns the ids.
  const auto encode = [&encoder](const char* text) {
    std::vector<int> ids;
    EXPECT_TRUE(encoder.Encode(text, &ids));
    return ids;
  };

  // Fully covered inputs: ids are shifted by the encoding offset of 3.
  EXPECT_THAT(encode("hellhello"), ElementsAre(0, 3, 4, 1));
  EXPECT_THAT(encode("hellohell"), ElementsAre(0, 4, 3, 1));
  // Empty input still produces the start and end codes.
  EXPECT_THAT(encode(""), ElementsAre(0, 1));
  // "a" matches no piece and is emitted as the unknown code between the
  // surrounding dictionary pieces.
  EXPECT_THAT(encode("hellathere"),
              ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
}
120 
121 }  // namespace
122 }  // namespace libtextclassifier3
123