1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "utils/base/arena.h" 24 #include "utils/flatbuffers/reflection.h" 25 #include "utils/grammar/parsing/derivation.h" 26 #include "utils/grammar/parsing/parse-tree.h" 27 #include "utils/grammar/semantics/value.h" 28 #include "utils/grammar/testing/value_generated.h" 29 #include "utils/grammar/text-context.h" 30 #include "utils/i18n/locale.h" 31 #include "utils/jvm-test-utils.h" 32 #include "utils/test-data-test-utils.h" 33 #include "utils/tokenizer.h" 34 #include "utils/utf8/unilib.h" 35 #include "gmock/gmock.h" 36 #include "flatbuffers/base.h" 37 #include "flatbuffers/flatbuffers.h" 38 39 namespace libtextclassifier3::grammar { 40 41 inline std::ostream& operator<<(std::ostream& os, const ParseTree* parse_tree) { 42 return os << "ParseTree(lhs=" << parse_tree->lhs 43 << ", begin=" << parse_tree->codepoint_span.first 44 << ", end=" << parse_tree->codepoint_span.second << ")"; 45 } 46 47 inline std::ostream& operator<<(std::ostream& os, 48 const Derivation& derivation) { 49 return os << "Derivation(rule_id=" << derivation.rule_id << ", " 50 << "parse_tree=" << derivation.parse_tree << ")"; 51 } 52 53 MATCHER_P3(IsDerivation, rule_id, begin, end, 54 "is derivation of rule that " + 55 ::testing::DescribeMatcher<int>(rule_id, negation) + 56 ", begin that " + 57 ::testing::DescribeMatcher<int>(begin, negation) + 58 ", end that " + ::testing::DescribeMatcher<int>(end, negation)) { 59 return ::testing::ExplainMatchResult(CodepointSpan(begin, end), 60 arg.parse_tree->codepoint_span, 61 result_listener) && 62 ::testing::ExplainMatchResult(rule_id, arg.rule_id, result_listener); 63 } 64 65 // A test fixture with common auxiliary test methods. 66 class GrammarTest : public testing::Test { 67 protected: GrammarTest()68 explicit GrammarTest() 69 : unilib_(CreateUniLibForTesting()), 70 arena_(/*block_size=*/16 << 10), 71 semantic_values_schema_( 72 GetTestFileContent("utils/grammar/testing/value.bfbs")), 73 tokenizer_(libtextclassifier3::TokenizationType_ICU, unilib_.get(), 74 /*codepoint_ranges=*/{}, 75 /*internal_tokenizer_codepoint_ranges=*/{}, 76 /*split_on_script_change=*/false, 77 /*icu_preserve_whitespace_tokens=*/false) {} 78 TextContextForText(const std::string & text)79 TextContext TextContextForText(const std::string& text) { 80 TextContext context; 81 context.text = UTF8ToUnicodeText(text); 82 context.tokens = tokenizer_.Tokenize(context.text); 83 context.codepoints = context.text.Codepoints(); 84 context.codepoints.push_back(context.text.end()); 85 context.locales = {Locale::FromBCP47("en")}; 86 context.context_span.first = 0; 87 context.context_span.second = context.tokens.size(); 88 return context; 89 } 90 91 // Creates a semantic expression union. 92 template <typename T> AsSemanticExpressionUnion(T && expression)93 SemanticExpressionT AsSemanticExpressionUnion(T&& expression) { 94 SemanticExpressionT semantic_expression; 95 semantic_expression.expression.Set(std::forward<T>(expression)); 96 return semantic_expression; 97 } 98 99 template <typename T> CreateExpression(T && expression)100 OwnedFlatbuffer<SemanticExpression> CreateExpression(T&& expression) { 101 return Pack<SemanticExpression>( 102 AsSemanticExpressionUnion(std::forward<T>(expression))); 103 } 104 CreateEmptyExpression()105 OwnedFlatbuffer<SemanticExpression> CreateEmptyExpression() { 106 return Pack<SemanticExpression>(SemanticExpressionT()); 107 } 108 109 // Packs a flatbuffer. 110 template <typename T> Pack(const typename T::NativeTableType && value)111 OwnedFlatbuffer<T> Pack(const typename T::NativeTableType&& value) { 112 flatbuffers::FlatBufferBuilder builder; 113 builder.Finish(T::Pack(builder, &value)); 114 return OwnedFlatbuffer<T>(builder.Release()); 115 } 116 117 // Creates a test semantic value. CreateSemanticValue(const TestValueT & value)118 const SemanticValue* CreateSemanticValue(const TestValueT& value) { 119 const std::string value_buffer = PackFlatbuffer<TestValue>(&value); 120 return arena_.AllocAndInit<SemanticValue>( 121 semantic_values_schema_->objects()->Get( 122 TypeIdForName(semantic_values_schema_.get(), 123 "libtextclassifier3.grammar.TestValue") 124 .value()), 125 StringPiece(arena_.Memdup(value_buffer.data(), value_buffer.size()), 126 value_buffer.size())); 127 } 128 129 // Creates a primitive semantic value. 130 template <typename T> CreatePrimitiveSemanticValue(const T value)131 const SemanticValue* CreatePrimitiveSemanticValue(const T value) { 132 return arena_.AllocAndInit<SemanticValue>(value); 133 } 134 CreateConstExpression(const TestValueT & value)135 std::unique_ptr<SemanticExpressionT> CreateConstExpression( 136 const TestValueT& value) { 137 ConstValueExpressionT const_value; 138 const_value.base_type = reflection::BaseType::Obj; 139 const_value.type = TypeIdForName(semantic_values_schema_.get(), 140 "libtextclassifier3.grammar.TestValue") 141 .value(); 142 const std::string value_buffer = PackFlatbuffer<TestValue>(&value); 143 const_value.value.assign(value_buffer.begin(), value_buffer.end()); 144 auto semantic_expression = std::make_unique<SemanticExpressionT>(); 145 semantic_expression->expression.Set(const_value); 146 return semantic_expression; 147 } 148 CreateAndPackConstExpression(const TestValueT & value)149 OwnedFlatbuffer<SemanticExpression> CreateAndPackConstExpression( 150 const TestValueT& value) { 151 ConstValueExpressionT const_value; 152 const_value.base_type = reflection::BaseType::Obj; 153 const_value.type = TypeIdForName(semantic_values_schema_.get(), 154 "libtextclassifier3.grammar.TestValue") 155 .value(); 156 const std::string value_buffer = PackFlatbuffer<TestValue>(&value); 157 const_value.value.assign(value_buffer.begin(), value_buffer.end()); 158 return CreateExpression(const_value); 159 } 160 CreateConstDateExpression(const TestDateT & value)161 std::unique_ptr<SemanticExpressionT> CreateConstDateExpression( 162 const TestDateT& value) { 163 ConstValueExpressionT const_value; 164 const_value.base_type = reflection::BaseType::Obj; 165 const_value.type = TypeIdForName(semantic_values_schema_.get(), 166 "libtextclassifier3.grammar.TestDate") 167 .value(); 168 const std::string value_buffer = PackFlatbuffer<TestDate>(&value); 169 const_value.value.assign(value_buffer.begin(), value_buffer.end()); 170 auto semantic_expression = std::make_unique<SemanticExpressionT>(); 171 semantic_expression->expression.Set(const_value); 172 return semantic_expression; 173 } 174 CreateAndPackMergeValuesExpression(const std::vector<TestDateT> & values)175 OwnedFlatbuffer<SemanticExpression> CreateAndPackMergeValuesExpression( 176 const std::vector<TestDateT>& values) { 177 MergeValueExpressionT merge_expression; 178 merge_expression.type = TypeIdForName(semantic_values_schema_.get(), 179 "libtextclassifier3.grammar.TestDate") 180 .value(); 181 for (const TestDateT& test_date : values) { 182 merge_expression.values.emplace_back(new SemanticExpressionT); 183 merge_expression.values.back() = CreateConstDateExpression(test_date); 184 } 185 return CreateExpression(std::move(merge_expression)); 186 } 187 188 template <typename T> CreatePrimitiveConstExpression(const T value)189 std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression( 190 const T value) { 191 ConstValueExpressionT const_value; 192 const_value.base_type = flatbuffers_base_type<T>::value; 193 const_value.value.resize(sizeof(T)); 194 flatbuffers::WriteScalar(const_value.value.data(), value); 195 auto semantic_expression = std::make_unique<SemanticExpressionT>(); 196 semantic_expression->expression.Set(const_value); 197 return semantic_expression; 198 } 199 200 template <typename T> CreateAndPackPrimitiveConstExpression(const T value)201 OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression( 202 const T value) { 203 ConstValueExpressionT const_value; 204 const_value.base_type = flatbuffers_base_type<T>::value; 205 const_value.value.resize(sizeof(T)); 206 flatbuffers::WriteScalar(const_value.value.data(), value); 207 return CreateExpression(const_value); 208 } 209 210 template <> CreateAndPackPrimitiveConstExpression(const StringPiece value)211 OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression( 212 const StringPiece value) { 213 ConstValueExpressionT const_value; 214 const_value.base_type = reflection::BaseType::String; 215 const_value.value.assign(value.data(), value.data() + value.size()); 216 return CreateExpression(const_value); 217 } 218 219 template <> CreatePrimitiveConstExpression(const StringPiece value)220 std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression( 221 const StringPiece value) { 222 ConstValueExpressionT const_value; 223 const_value.base_type = reflection::BaseType::String; 224 const_value.value.assign(value.data(), value.data() + value.size()); 225 auto semantic_expression = std::make_unique<SemanticExpressionT>(); 226 semantic_expression->expression.Set(const_value); 227 return semantic_expression; 228 } 229 230 const std::unique_ptr<UniLib> unilib_; 231 UnsafeArena arena_; 232 const OwnedFlatbuffer<reflection::Schema, std::string> 233 semantic_values_schema_; 234 const Tokenizer tokenizer_; 235 }; 236 237 } // namespace libtextclassifier3::grammar 238 239 #endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_ 240