• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "smartselect/cached-features.h"
18 
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 
22 namespace libtextclassifier {
23 namespace {
24 
25 class TestingCachedFeatures : public CachedFeatures {
26  public:
27   using CachedFeatures::CachedFeatures;
28   using CachedFeatures::RemapV0FeatureVector;
29 };
30 
// Checks that Get() for a click position returns the per-token feature
// vectors of the clicked token and its context_size neighbors on each side,
// laid out one token after another (here 5 tokens x 5 features).
TEST(CachedFeaturesTest, Simple) {
  // Three real tokens surrounded by two empty padding tokens on each side.
  std::vector<Token> tokens;
  tokens.push_back(Token());
  tokens.push_back(Token());
  tokens.push_back(Token("Hello", 0, 1));
  tokens.push_back(Token("World", 1, 2));
  tokens.push_back(Token("today!", 2, 3));
  tokens.push_back(Token());
  tokens.push_back(Token());

  // Token i gets sparse feature {i} and dense feature {-i}, so every output
  // value identifies the token it was computed from.
  const int num_tokens = static_cast<int>(tokens.size());
  std::vector<std::vector<int>> sparse_features(tokens.size());
  std::vector<std::vector<float>> dense_features(tokens.size());
  for (int i = 0; i < num_tokens; ++i) {
    sparse_features[i].push_back(i);
    dense_features[i].push_back(-i);
  }

  constexpr int kFeatureVectorSize = 5;
  // A click at position 2 with context_size 2 covers tokens 0..4.
  constexpr int kNumOutputTokens = 5;
  TestingCachedFeatures feature_extractor(
      tokens, /*context_size=*/2, sparse_features, dense_features,
      // Parameter names differ from the locals above to avoid shadowing.
      [](const std::vector<int>& token_sparse,
         const std::vector<float>& token_dense, float* out) {
        out[0] = token_sparse[0];
        out[1] = token_sparse[0];
        out[2] = token_dense[0];
        out[3] = token_dense[0];
        out[4] = 123;
        return true;
      },
      /*feature_vector_size=*/kFeatureVectorSize);

  VectorSpan<float> features;
  VectorSpan<Token> output_tokens;
  // ASSERT (fatal) rather than EXPECT: if Get() fails, the checks below must
  // not go on to index an unset span.
  ASSERT_TRUE(feature_extractor.Get(2, &features, &output_tokens));

  const std::vector<float> actual(features.begin(), features.end());
  ASSERT_EQ(static_cast<int>(actual.size()),
            kNumOutputTokens * kFeatureVectorSize);
  for (int i = 0; i < kNumOutputTokens; ++i) {
    const int base = i * kFeatureVectorSize;
    EXPECT_EQ(actual[base + 0], i) << "Feature " << i;
    EXPECT_EQ(actual[base + 1], i) << "Feature " << i;
    EXPECT_EQ(actual[base + 2], -i) << "Feature " << i;
    EXPECT_EQ(actual[base + 3], -i) << "Feature " << i;
    EXPECT_EQ(actual[base + 4], 123) << "Feature " << i;
  }
}
74 
// Checks Get()'s bounds handling. It must reject click positions whose
// context window (click_pos +/- context_size) does not fit inside the token
// list, and accept every position whose window does. The boundary positions
// on both sides are included.
TEST(CachedFeaturesTest, InvalidInput) {
  std::vector<Token> tokens;
  tokens.push_back(Token());
  tokens.push_back(Token());
  tokens.push_back(Token("Hello", 0, 1));
  tokens.push_back(Token("World", 1, 2));
  tokens.push_back(Token("today!", 2, 3));
  tokens.push_back(Token());
  tokens.push_back(Token());

  std::vector<std::vector<int>> sparse_features(tokens.size());
  std::vector<std::vector<float>> dense_features(tokens.size());

  // Feature values do not matter here, so the extractor ignores its inputs.
  TestingCachedFeatures feature_extractor(
      tokens, /*context_size=*/2, sparse_features, dense_features,
      [](const std::vector<int>& /*sparse_features*/,
         const std::vector<float>& /*dense_features*/,
         float* /*features*/) { return true; },
      /*feature_vector_size=*/5);

  VectorSpan<float> features;
  VectorSpan<Token> output_tokens;
  // With 7 tokens and context_size 2, the valid click positions are 2..4.
  EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));
  EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
  EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));
  EXPECT_FALSE(feature_extractor.Get(1, &features, &output_tokens));
  EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
  EXPECT_TRUE(feature_extractor.Get(3, &features, &output_tokens));
  EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
  EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));
  EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));
}
105 
// Checks RemapV0FeatureVector() on a 5-token x 5-feature vector holding the
// values 0..24. Expected results:
//  * Mode 0 leaves the order unchanged.
//  * Mode 2 puts the first two features of every token first, in token order,
//    followed by the remaining three features of every token.
TEST(CachedFeaturesTest, RemapV0FeatureVector) {
  std::vector<Token> tokens;
  tokens.push_back(Token());
  tokens.push_back(Token());
  tokens.push_back(Token("Hello", 0, 1));
  tokens.push_back(Token("World", 1, 2));
  tokens.push_back(Token("today!", 2, 3));
  tokens.push_back(Token());
  tokens.push_back(Token());

  std::vector<std::vector<int>> sparse_features(tokens.size());
  std::vector<std::vector<float>> dense_features(tokens.size());

  // The extractor is only needed to reach RemapV0FeatureVector(), so it
  // ignores its inputs.
  TestingCachedFeatures feature_extractor(
      tokens, /*context_size=*/2, sparse_features, dense_features,
      [](const std::vector<int>& /*sparse_features*/,
         const std::vector<float>& /*dense_features*/,
         float* /*features*/) { return true; },
      /*feature_vector_size=*/5);

  // features_orig = {0, 1, ..., 24}.
  std::vector<float> features_orig(5 * 5);
  const int num_values = static_cast<int>(features_orig.size());
  for (int i = 0; i < num_values; ++i) {
    features_orig[i] = static_cast<float>(i);
  }
  VectorSpan<float> features;

  // Each mode is applied to a fresh span over features_orig.
  feature_extractor.SetV0FeatureMode(0);
  features = VectorSpan<float>(features_orig);
  feature_extractor.RemapV0FeatureVector(&features);
  EXPECT_EQ(
      std::vector<float>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
      std::vector<float>(features.begin(), features.end()));

  feature_extractor.SetV0FeatureMode(2);
  features = VectorSpan<float>(features_orig);
  feature_extractor.RemapV0FeatureVector(&features);
  EXPECT_EQ(std::vector<float>({0, 1, 5, 6,  10, 11, 15, 16, 20, 21, 2,  3, 4,
                                7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
            std::vector<float>(features.begin(), features.end()));
}
147 
148 }  // namespace
149 }  // namespace libtextclassifier
150