• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "common/embedding-feature-extractor.h"
18 
19 #include "lang_id/language-identifier-features.h"
20 #include "lang_id/light-sentence-features.h"
21 #include "lang_id/light-sentence.h"
22 #include "lang_id/relevant-script-feature.h"
23 #include "gtest/gtest.h"
24 
25 namespace libtextclassifier {
26 namespace nlp_core {
27 
28 class EmbeddingFeatureExtractorTest : public ::testing::Test {
29  public:
SetUp()30   void SetUp() override {
31     // Make sure all relevant features are registered:
32     lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
33     lang_id::RelevantScriptFeature::RegisterClass();
34   }
35 };
36 
37 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
38 class TestEmbeddingFeatureExtractor
39     : public EmbeddingFeatureExtractor<lang_id::LightSentenceExtractor,
40                                        lang_id::LightSentence> {
41  public:
ArgPrefix() const42   const std::string ArgPrefix() const override { return "test"; }
43 };
44 
// Empty feature, name and dimension lists form a valid configuration: Init()
// succeeds and the extractor reports zero embedding spaces.
TEST_F(EmbeddingFeatureExtractorTest, NoEmbeddingSpaces) {
  TaskContext task_context;
  for (const char *key :
       {"test_features", "test_embedding_names", "test_embedding_dims"}) {
    task_context.SetParameter(key, "");
  }

  TestEmbeddingFeatureExtractor extractor;
  ASSERT_TRUE(extractor.Init(&task_context));
  EXPECT_EQ(extractor.NumEmbeddings(), 0);
}
54 
// Two well-formed n-gram feature functions, each with its own named embedding
// space. After Init(), every space reports the id_dim from its spec as
// EmbeddingSize() and the matching entry of "test_embedding_dims" as
// EmbeddingDims().
TEST_F(EmbeddingFeatureExtractorTest, GoodSpec) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=5000,size=3);"
      "continuous-bag-of-ngrams(id_dim=7000,size=4)";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram;quadgram");
  task_context.SetParameter("test_embedding_dims", "16;24");

  TestEmbeddingFeatureExtractor extractor;
  ASSERT_TRUE(extractor.Init(&task_context));
  EXPECT_EQ(extractor.NumEmbeddings(), 2);

  // Space 0 ("trigram"): id_dim=5000, dims=16.
  EXPECT_EQ(extractor.EmbeddingSize(0), 5000);
  EXPECT_EQ(extractor.EmbeddingDims(0), 16);

  // Space 1 ("quadgram"): id_dim=7000, dims=24.
  EXPECT_EQ(extractor.EmbeddingSize(1), 7000);
  EXPECT_EQ(extractor.EmbeddingDims(1), 24);
}
71 
// Two feature functions but only one embedding name. The dims list has two
// entries, so the name count is the only inconsistency, and Init() must
// reject the configuration.
TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsNames) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=5000,size=3);"
      "continuous-bag-of-ngrams(id_dim=7000,size=4)";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram");
  task_context.SetParameter("test_embedding_dims", "16;16");

  TestEmbeddingFeatureExtractor extractor;
  const bool initialized = extractor.Init(&task_context);
  ASSERT_FALSE(initialized);
}
83 
// Two feature functions and two names, but three embedding dimensions.
// Init() must reject the surplus dims entry.
TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsDims) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=5000,size=3);"
      "continuous-bag-of-ngrams(id_dim=7000,size=4)";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram;quadgram");
  task_context.SetParameter("test_embedding_dims", "16;16;32");

  TestEmbeddingFeatureExtractor extractor;
  const bool initialized = extractor.Init(&task_context);
  ASSERT_FALSE(initialized);
}
95 
// The first feature's parameter list is never closed ("(id_dim=5000;"), so
// the spec is syntactically malformed and Init() must fail.
TEST_F(EmbeddingFeatureExtractorTest, BrokenSpec) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=5000;"
      "continuous-bag-of-ngrams(id_dim=7000,size=4)";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram;quadgram");
  task_context.SetParameter("test_embedding_dims", "16;16");

  TestEmbeddingFeatureExtractor extractor;
  const bool initialized = extractor.Init(&task_context);
  ASSERT_FALSE(initialized);
}
107 
// The second spec entry names "no-such-feature", which is not one of the
// feature classes the fixture registers. Init() must fail even though the
// name and dims counts are consistent.
TEST_F(EmbeddingFeatureExtractorTest, MissingFeature) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=5000,size=3);"
      "no-such-feature";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram;foo");
  task_context.SetParameter("test_embedding_dims", "16;16");

  TestEmbeddingFeatureExtractor extractor;
  const bool initialized = extractor.Init(&task_context);
  ASSERT_FALSE(initialized);
}
119 
// Mixes two different feature types: an n-gram feature with an explicit
// id_dim, and the relevant-scripts feature, which takes no parameters in the
// spec.
TEST_F(EmbeddingFeatureExtractorTest, MultipleFeatures) {
  const std::string feature_spec =
      "continuous-bag-of-ngrams(id_dim=1000,size=3);"
      "continuous-bag-of-relevant-scripts";

  TaskContext task_context;
  task_context.SetParameter("test_features", feature_spec);
  task_context.SetParameter("test_embedding_names", "trigram;script");
  task_context.SetParameter("test_embedding_dims", "8;16");

  TestEmbeddingFeatureExtractor extractor;
  ASSERT_TRUE(extractor.Init(&task_context));
  EXPECT_EQ(extractor.NumEmbeddings(), 2);

  // Space 0 ("trigram"): size comes straight from id_dim in the spec.
  EXPECT_EQ(extractor.EmbeddingSize(0), 1000);
  EXPECT_EQ(extractor.EmbeddingDims(0), 8);

  // Space 1 ("script"): continuous-bag-of-relevant-scripts has its own
  // hard-wired vocabulary size. The test deliberately does not pin that
  // value; it only requires it to be positive.
  EXPECT_GT(extractor.EmbeddingSize(1), 0);
  EXPECT_EQ(extractor.EmbeddingDims(1), 16);
}
140 
141 }  // namespace nlp_core
142 }  // namespace libtextclassifier
143