• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
19 
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "utils/base/arena.h"
#include "utils/flatbuffers/reflection.h"
#include "utils/grammar/parsing/derivation.h"
#include "utils/grammar/parsing/parse-tree.h"
#include "utils/grammar/semantics/value.h"
#include "utils/grammar/testing/value_generated.h"
#include "utils/grammar/text-context.h"
#include "utils/i18n/locale.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
#include "flatbuffers/base.h"
#include "flatbuffers/flatbuffers.h"
38 
39 namespace libtextclassifier3::grammar {
40 
41 inline std::ostream& operator<<(std::ostream& os, const ParseTree* parse_tree) {
42   return os << "ParseTree(lhs=" << parse_tree->lhs
43             << ", begin=" << parse_tree->codepoint_span.first
44             << ", end=" << parse_tree->codepoint_span.second << ")";
45 }
46 
47 inline std::ostream& operator<<(std::ostream& os,
48                                 const Derivation& derivation) {
49   return os << "Derivation(rule_id=" << derivation.rule_id << ", "
50             << "parse_tree=" << derivation.parse_tree << ")";
51 }
52 
53 MATCHER_P3(IsDerivation, rule_id, begin, end,
54            "is derivation of rule that " +
55                ::testing::DescribeMatcher<int>(rule_id, negation) +
56                ", begin that " +
57                ::testing::DescribeMatcher<int>(begin, negation) +
58                ", end that " + ::testing::DescribeMatcher<int>(end, negation)) {
59   return ::testing::ExplainMatchResult(CodepointSpan(begin, end),
60                                        arg.parse_tree->codepoint_span,
61                                        result_listener) &&
62          ::testing::ExplainMatchResult(rule_id, arg.rule_id, result_listener);
63 }
64 
65 // A test fixture with common auxiliary test methods.
66 class GrammarTest : public testing::Test {
67  protected:
GrammarTest()68   explicit GrammarTest()
69       : unilib_(CreateUniLibForTesting()),
70         arena_(/*block_size=*/16 << 10),
71         semantic_values_schema_(
72             GetTestFileContent("utils/grammar/testing/value.bfbs")),
73         tokenizer_(libtextclassifier3::TokenizationType_ICU, unilib_.get(),
74                    /*codepoint_ranges=*/{},
75                    /*internal_tokenizer_codepoint_ranges=*/{},
76                    /*split_on_script_change=*/false,
77                    /*icu_preserve_whitespace_tokens=*/false) {}
78 
TextContextForText(const std::string & text)79   TextContext TextContextForText(const std::string& text) {
80     TextContext context;
81     context.text = UTF8ToUnicodeText(text);
82     context.tokens = tokenizer_.Tokenize(context.text);
83     context.codepoints = context.text.Codepoints();
84     context.codepoints.push_back(context.text.end());
85     context.locales = {Locale::FromBCP47("en")};
86     context.context_span.first = 0;
87     context.context_span.second = context.tokens.size();
88     return context;
89   }
90 
91   // Creates a semantic expression union.
92   template <typename T>
AsSemanticExpressionUnion(T && expression)93   SemanticExpressionT AsSemanticExpressionUnion(T&& expression) {
94     SemanticExpressionT semantic_expression;
95     semantic_expression.expression.Set(std::forward<T>(expression));
96     return semantic_expression;
97   }
98 
99   template <typename T>
CreateExpression(T && expression)100   OwnedFlatbuffer<SemanticExpression> CreateExpression(T&& expression) {
101     return Pack<SemanticExpression>(
102         AsSemanticExpressionUnion(std::forward<T>(expression)));
103   }
104 
CreateEmptyExpression()105   OwnedFlatbuffer<SemanticExpression> CreateEmptyExpression() {
106     return Pack<SemanticExpression>(SemanticExpressionT());
107   }
108 
109   // Packs a flatbuffer.
110   template <typename T>
Pack(const typename T::NativeTableType && value)111   OwnedFlatbuffer<T> Pack(const typename T::NativeTableType&& value) {
112     flatbuffers::FlatBufferBuilder builder;
113     builder.Finish(T::Pack(builder, &value));
114     return OwnedFlatbuffer<T>(builder.Release());
115   }
116 
117   // Creates a test semantic value.
CreateSemanticValue(const TestValueT & value)118   const SemanticValue* CreateSemanticValue(const TestValueT& value) {
119     const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
120     return arena_.AllocAndInit<SemanticValue>(
121         semantic_values_schema_->objects()->Get(
122             TypeIdForName(semantic_values_schema_.get(),
123                           "libtextclassifier3.grammar.TestValue")
124                 .value()),
125         StringPiece(arena_.Memdup(value_buffer.data(), value_buffer.size()),
126                     value_buffer.size()));
127   }
128 
129   // Creates a primitive semantic value.
130   template <typename T>
CreatePrimitiveSemanticValue(const T value)131   const SemanticValue* CreatePrimitiveSemanticValue(const T value) {
132     return arena_.AllocAndInit<SemanticValue>(value);
133   }
134 
CreateConstExpression(const TestValueT & value)135   std::unique_ptr<SemanticExpressionT> CreateConstExpression(
136       const TestValueT& value) {
137     ConstValueExpressionT const_value;
138     const_value.base_type = reflection::BaseType::Obj;
139     const_value.type = TypeIdForName(semantic_values_schema_.get(),
140                                      "libtextclassifier3.grammar.TestValue")
141                            .value();
142     const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
143     const_value.value.assign(value_buffer.begin(), value_buffer.end());
144     auto semantic_expression = std::make_unique<SemanticExpressionT>();
145     semantic_expression->expression.Set(const_value);
146     return semantic_expression;
147   }
148 
CreateAndPackConstExpression(const TestValueT & value)149   OwnedFlatbuffer<SemanticExpression> CreateAndPackConstExpression(
150       const TestValueT& value) {
151     ConstValueExpressionT const_value;
152     const_value.base_type = reflection::BaseType::Obj;
153     const_value.type = TypeIdForName(semantic_values_schema_.get(),
154                                      "libtextclassifier3.grammar.TestValue")
155                            .value();
156     const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
157     const_value.value.assign(value_buffer.begin(), value_buffer.end());
158     return CreateExpression(const_value);
159   }
160 
CreateConstDateExpression(const TestDateT & value)161   std::unique_ptr<SemanticExpressionT> CreateConstDateExpression(
162       const TestDateT& value) {
163     ConstValueExpressionT const_value;
164     const_value.base_type = reflection::BaseType::Obj;
165     const_value.type = TypeIdForName(semantic_values_schema_.get(),
166                                      "libtextclassifier3.grammar.TestDate")
167                            .value();
168     const std::string value_buffer = PackFlatbuffer<TestDate>(&value);
169     const_value.value.assign(value_buffer.begin(), value_buffer.end());
170     auto semantic_expression = std::make_unique<SemanticExpressionT>();
171     semantic_expression->expression.Set(const_value);
172     return semantic_expression;
173   }
174 
CreateAndPackMergeValuesExpression(const std::vector<TestDateT> & values)175   OwnedFlatbuffer<SemanticExpression> CreateAndPackMergeValuesExpression(
176       const std::vector<TestDateT>& values) {
177     MergeValueExpressionT merge_expression;
178     merge_expression.type = TypeIdForName(semantic_values_schema_.get(),
179                                           "libtextclassifier3.grammar.TestDate")
180                                 .value();
181     for (const TestDateT& test_date : values) {
182       merge_expression.values.emplace_back(new SemanticExpressionT);
183       merge_expression.values.back() = CreateConstDateExpression(test_date);
184     }
185     return CreateExpression(std::move(merge_expression));
186   }
187 
188   template <typename T>
CreatePrimitiveConstExpression(const T value)189   std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
190       const T value) {
191     ConstValueExpressionT const_value;
192     const_value.base_type = flatbuffers_base_type<T>::value;
193     const_value.value.resize(sizeof(T));
194     flatbuffers::WriteScalar(const_value.value.data(), value);
195     auto semantic_expression = std::make_unique<SemanticExpressionT>();
196     semantic_expression->expression.Set(const_value);
197     return semantic_expression;
198   }
199 
200   template <typename T>
CreateAndPackPrimitiveConstExpression(const T value)201   OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
202       const T value) {
203     ConstValueExpressionT const_value;
204     const_value.base_type = flatbuffers_base_type<T>::value;
205     const_value.value.resize(sizeof(T));
206     flatbuffers::WriteScalar(const_value.value.data(), value);
207     return CreateExpression(const_value);
208   }
209 
210   template <>
CreateAndPackPrimitiveConstExpression(const StringPiece value)211   OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
212       const StringPiece value) {
213     ConstValueExpressionT const_value;
214     const_value.base_type = reflection::BaseType::String;
215     const_value.value.assign(value.data(), value.data() + value.size());
216     return CreateExpression(const_value);
217   }
218 
219   template <>
CreatePrimitiveConstExpression(const StringPiece value)220   std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
221       const StringPiece value) {
222     ConstValueExpressionT const_value;
223     const_value.base_type = reflection::BaseType::String;
224     const_value.value.assign(value.data(), value.data() + value.size());
225     auto semantic_expression = std::make_unique<SemanticExpressionT>();
226     semantic_expression->expression.Set(const_value);
227     return semantic_expression;
228   }
229 
230   const std::unique_ptr<UniLib> unilib_;
231   UnsafeArena arena_;
232   const OwnedFlatbuffer<reflection::Schema, std::string>
233       semantic_values_schema_;
234   const Tokenizer tokenizer_;
235 };
236 
237 }  // namespace libtextclassifier3::grammar
238 
239 #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
240