/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Unit tests for TokenFeatureExtractor: character n-gram hashing, the dense
// case / selection-mask / regexp features, digit remapping and lowercasing.
//
// NOTE(review): the output containers are declared below as std::vector<int>
// (sparse) and std::vector<float> (dense). These element types are assumed to
// match the out-parameters of TokenFeatureExtractor::Extract() -- confirm
// against smartselect/token-feature-extractor.h.

#include "smartselect/token-feature-extractor.h"

#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier {
namespace {

// Test-only subclass that inherits the base constructors and re-exports
// HashToken in its public section, so the tests can compute the expected
// hash of each character n-gram.
class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
 public:
  using TokenFeatureExtractor::TokenFeatureExtractor;
  using TokenFeatureExtractor::HashToken;
};

// With unicode_aware_features disabled, the sparse features must be the
// hashes of all 1-, 2- and 3-grams of the token (with '^' / '$' marking the
// token boundaries), and the dense features must hold the case feature
// followed by the selection-mask feature.
TEST(TokenFeatureExtractorTest, ExtractAscii) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.extract_case_feature = true;
  options.unicode_aware_features = false;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;

  // Capitalized token that is inside the selection span.
  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
                    &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("H"),
                  extractor.HashToken("e"),
                  extractor.HashToken("l"),
                  extractor.HashToken("l"),
                  extractor.HashToken("o"),
                  extractor.HashToken("^H"),
                  extractor.HashToken("He"),
                  extractor.HashToken("el"),
                  extractor.HashToken("ll"),
                  extractor.HashToken("lo"),
                  extractor.HashToken("o$"),
                  extractor.HashToken("^He"),
                  extractor.HashToken("Hel"),
                  extractor.HashToken("ell"),
                  extractor.HashToken("llo"),
                  extractor.HashToken("lo$")
                  // clang-format on
              }));
  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));

  // Lowercase token that is outside the selection span.
  sparse_features.clear();
  dense_features.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                    &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("w"),
                  extractor.HashToken("o"),
                  extractor.HashToken("r"),
                  extractor.HashToken("l"),
                  extractor.HashToken("d"),
                  extractor.HashToken("!"),
                  extractor.HashToken("^w"),
                  extractor.HashToken("wo"),
                  extractor.HashToken("or"),
                  extractor.HashToken("rl"),
                  extractor.HashToken("ld"),
                  extractor.HashToken("d!"),
                  extractor.HashToken("!$"),
                  extractor.HashToken("^wo"),
                  extractor.HashToken("wor"),
                  extractor.HashToken("orl"),
                  extractor.HashToken("rld"),
                  extractor.HashToken("ld!"),
                  extractor.HashToken("d!$"),
                  // clang-format on
              }));
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}

// Same scenario as ExtractAscii but with unicode_aware_features enabled and a
// token containing multi-byte characters: each n-gram must be built from
// whole characters rather than from individual bytes.
TEST(TokenFeatureExtractorTest, ExtractUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;

  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
                    &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("H"),
                  extractor.HashToken("ě"),
                  extractor.HashToken("l"),
                  extractor.HashToken("l"),
                  extractor.HashToken("ó"),
                  extractor.HashToken("^H"),
                  extractor.HashToken("Hě"),
                  extractor.HashToken("ěl"),
                  extractor.HashToken("ll"),
                  extractor.HashToken("ló"),
                  extractor.HashToken("ó$"),
                  extractor.HashToken("^Hě"),
                  extractor.HashToken("Hěl"),
                  extractor.HashToken("ěll"),
                  extractor.HashToken("lló"),
                  extractor.HashToken("ló$")
                  // clang-format on
              }));
  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));

  sparse_features.clear();
  dense_features.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                    &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("w"),
                  extractor.HashToken("o"),
                  extractor.HashToken("r"),
                  extractor.HashToken("l"),
                  extractor.HashToken("d"),
                  extractor.HashToken("!"),
                  extractor.HashToken("^w"),
                  extractor.HashToken("wo"),
                  extractor.HashToken("or"),
                  extractor.HashToken("rl"),
                  extractor.HashToken("ld"),
                  extractor.HashToken("d!"),
                  extractor.HashToken("!$"),
                  extractor.HashToken("^wo"),
                  extractor.HashToken("wor"),
                  extractor.HashToken("orl"),
                  extractor.HashToken("rld"),
                  extractor.HashToken("ld!"),
                  extractor.HashToken("d!$"),
                  // clang-format on
              }));
  // NOTE(review): the ASCII variant of this test expects 0.0 for the
  // selection-mask feature of a token outside the span, while -1.0 is
  // expected here -- confirm this difference is intended by the
  // implementation.
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
}

// With only the case feature enabled (unicode-aware), the single dense value
// must be 1.0 for tokens starting with an uppercase letter and -1.0
// otherwise, including for non-ASCII letters.
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));

  sparse_features.clear();
  dense_features.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));

  sparse_features.clear();
  dense_features.clear();
  extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));

  sparse_features.clear();
  dense_features.clear();
  extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
}

// With remap_digits enabled (non-unicode path), tokens that differ only in
// their digit values must yield identical sparse features, while a token with
// a different number of digits must not.
TEST(TokenFeatureExtractorTest, DigitRemapping) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the reused output, as the other tests do, so the comparison does
  // not depend on whether Extract() overwrites or appends.
  sparse_features2.clear();
  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features2)));
}

// Same as DigitRemapping, but exercising the unicode-aware code path.
TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = true;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the reused output, as the other tests do, so the comparison does
  // not depend on whether Extract() overwrites or appends.
  sparse_features2.clear();
  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features2)));
}

// With lowercase_tokens enabled (non-unicode path), tokens that differ only
// in letter case must yield identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseAscii) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.lowercase_tokens = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the reused output, as the other tests do, so the comparison does
  // not depend on whether Extract() overwrites or appends.
  sparse_features2.clear();
  extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}

// With lowercase_tokens enabled on the unicode-aware path, non-ASCII letters
// that differ only in case must yield identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.lowercase_tokens = true;
  options.unicode_aware_features = true;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}

// Each configured regular expression must contribute one dense feature, in
// configuration order: 1.0 when the token matches and -1.0 when it does not.
TEST(TokenFeatureExtractorTest, RegexFeatures) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = false;
  options.unicode_aware_features = false;
  options.regexp_features.push_back("^[a-z]+$");  // all lower case.
  options.regexp_features.push_back("^[0-9]+$");  // all digits.
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));

  dense_features.clear();
  extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));

  dense_features.clear();
  extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));

  dense_features.clear();
  extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
                    &dense_features);
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
}

// Feeds a 26-character token to an extractor configured with a single
// chargram order of 22. The expected n-grams contain a '\1' character in
// place of the middle letters "klmnop".
// NOTE(review): presumably '\1' is the extractor's placeholder for the
// dropped middle of an over-long token -- confirm against the
// implementation.
TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{22};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  // Test that this runs. ASAN should catch problems.
  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
                    &sparse_features, &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
                  extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
                  // clang-format on
              }));
}

// For pure-ASCII inputs (including the empty string), the unicode-aware and
// the non-unicode extractors must produce identical sparse and dense
// features when the token is inside the selection span.
TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;

  TestingTokenFeatureExtractor extractor_unicode(options);

  // Same options, except for the code path under test.
  options.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(options);

  for (const std::string& input :
       {"https://www.abcdefgh.com/in/xxxkkkvayio",
        "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
        "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w",
        "abcd", "x", "Hello", "Hey,", "Hi", ""}) {
    std::vector<int> sparse_features_unicode;
    std::vector<float> dense_features_unicode;
    extractor_unicode.Extract(Token{input, 0, 0}, true,
                              &sparse_features_unicode,
                              &dense_features_unicode);

    std::vector<int> sparse_features_ascii;
    std::vector<float> dense_features_ascii;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
                            &dense_features_ascii);

    // Stream the input so a failure identifies the offending token.
    EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
    EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
  }
}

// A default-constructed Token outside the selection span must yield a single
// sparse feature -- the hash of the empty string -- and the dense features
// {-1.0, 0.0}.
TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.extract_case_feature = true;
  options.unicode_aware_features = false;
  options.extract_selection_mask_feature = true;

  TestingTokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;

  extractor.Extract(Token(), false, &sparse_features, &dense_features);

  EXPECT_THAT(sparse_features,
              testing::ElementsAreArray({extractor.HashToken("")}));
  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}

}  // namespace
}  // namespace libtextclassifier