• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_TESTING_COMMON_MATCHERS_H_
16 #define ICING_TESTING_COMMON_MATCHERS_H_
17 
18 #include <algorithm>
19 #include <cinttypes>
20 #include <cmath>
21 #include <string>
22 #include <vector>
23 
24 #include "icing/text_classifier/lib3/utils/base/status.h"
25 #include "icing/text_classifier/lib3/utils/base/status_macros.h"
26 #include "gmock/gmock.h"
27 #include "gtest/gtest.h"
28 #include "icing/absl_ports/str_join.h"
29 #include "icing/index/hit/doc-hit-info.h"
30 #include "icing/index/hit/hit.h"
31 #include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
32 #include "icing/legacy/core/icing-string-util.h"
33 #include "icing/portable/equals-proto.h"
34 #include "icing/proto/search.pb.h"
35 #include "icing/proto/status.pb.h"
36 #include "icing/schema/joinable-property.h"
37 #include "icing/schema/schema-store.h"
38 #include "icing/schema/section.h"
39 #include "icing/scoring/scored-document-hit.h"
40 
41 namespace icing {
42 namespace lib {
43 
44 // Used to match Token(Token::Type type, std::string_view text)
45 MATCHER_P2(EqualsToken, type, text, "") {
46   std::string arg_string(arg.text.data(), arg.text.length());
47   if (arg.type != type || arg.text != text) {
48     *result_listener << IcingStringUtil::StringPrintf(
49         "(Expected: type=%d, text=\"%s\". Actual: type=%d, text=\"%s\")", type,
50         text, arg.type, arg_string.c_str());
51     return false;
52   }
53   return true;
54 }
55 
56 // Used to match a DocHitInfo
57 MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") {
58   const DocHitInfo& actual = arg;
59   SectionIdMask section_mask = kSectionIdMaskNone;
60   for (SectionId section_id : section_ids) {
61     section_mask |= UINT64_C(1) << section_id;
62   }
63   *result_listener << IcingStringUtil::StringPrintf(
64       "(actual is {document_id=%d, section_mask=%" PRIu64
65       "}, but expected was "
66       "{document_id=%d, section_mask=%" PRIu64 "}.)",
67       actual.document_id(), actual.hit_section_ids_mask(), document_id,
68       section_mask);
69   return actual.document_id() == document_id &&
70          actual.hit_section_ids_mask() == section_mask;
71 }
72 
73 struct ExtractTermFrequenciesResult {
74   std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies = {0};
75   SectionIdMask section_mask = kSectionIdMaskNone;
76 };
77 // Extracts the term frequencies represented by the section_ids_tf_map.
78 // Returns:
79 //   - a SectionIdMask representing all sections that appears as entries in the
80 //     map, even if they have an entry with term_frequency==0
81 //   - an array representing the term frequencies for each section. Sections not
82 //     present in section_ids_tf_map have a term frequency of 0.
83 ExtractTermFrequenciesResult ExtractTermFrequencies(
84     const std::unordered_map<SectionId, Hit::TermFrequency>&
85         section_ids_tf_map);
86 
// Value returned by CheckTermFrequency().
struct CheckTermFrequencyResult {
  // Debug rendering of the expected term frequency array.
  std::string expected_term_frequencies_str;
  // Debug rendering of the actual term frequency array.
  std::string actual_term_frequencies_str;
  // Whether the two arrays matched; defaults to true.
  bool term_frequencies_match = true;
};
92 // Checks that the term frequencies in actual_term_frequencies match those
93 // specified in expected_section_ids_tf_map. If there is no entry in
94 // expected_section_ids_tf_map, then it is assumed that the term frequency for
95 // that section is 0.
96 // Returns:
97 //   - a bool indicating if the term frequencies match
98 //   - debug strings representing the contents of the actual and expected term
99 //     term frequency arrays.
100 CheckTermFrequencyResult CheckTermFrequency(
101     const std::array<Hit::TermFrequency, kTotalNumSections>&
102         expected_term_frequencies,
103     const std::array<Hit::TermFrequency, kTotalNumSections>&
104         actual_term_frequencies);
105 
106 // Used to match a DocHitInfo
107 MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id,
108            section_ids_to_term_frequencies_map, "") {
109   const DocHitInfoTermFrequencyPair& actual = arg;
110   std::array<Hit::TermFrequency, kTotalNumSections> actual_tf_array;
111   for (SectionId section_id = 0; section_id < kTotalNumSections; ++section_id) {
112     actual_tf_array[section_id] = actual.hit_term_frequency(section_id);
113   }
114   ExtractTermFrequenciesResult expected =
115       ExtractTermFrequencies(section_ids_to_term_frequencies_map);
116   CheckTermFrequencyResult check_tf_result =
117       CheckTermFrequency(expected.term_frequencies, actual_tf_array);
118 
119   *result_listener << IcingStringUtil::StringPrintf(
120       "(actual is {document_id=%d, section_mask=%" PRIu64
121       ", term_frequencies=%s}, but expected was "
122       "{document_id=%d, section_mask=%" PRIu64 ", term_frequencies=%s}.)",
123       actual.doc_hit_info().document_id(),
124       actual.doc_hit_info().hit_section_ids_mask(),
125       check_tf_result.actual_term_frequencies_str.c_str(), document_id,
126       expected.section_mask,
127       check_tf_result.expected_term_frequencies_str.c_str());
128   return actual.doc_hit_info().document_id() == document_id &&
129          actual.doc_hit_info().hit_section_ids_mask() ==
130              expected.section_mask &&
131          check_tf_result.term_frequencies_match;
132 }
133 
134 MATCHER_P2(EqualsTermMatchInfo, term, section_ids_to_term_frequencies_map, "") {
135   const TermMatchInfo& actual = arg;
136   std::string term_str(term);
137   ExtractTermFrequenciesResult expected =
138       ExtractTermFrequencies(section_ids_to_term_frequencies_map);
139   CheckTermFrequencyResult check_tf_result =
140       CheckTermFrequency(expected.term_frequencies, actual.term_frequencies);
141   *result_listener << IcingStringUtil::StringPrintf(
142       "(actual is {term=%s, section_mask=%" PRIu64
143       ", term_frequencies=%s}, but expected was "
144       "{term=%s, section_mask=%" PRIu64 ", term_frequencies=%s}.)",
145       actual.term.data(), actual.section_ids_mask,
146       check_tf_result.actual_term_frequencies_str.c_str(), term_str.data(),
147       expected.section_mask,
148       check_tf_result.expected_term_frequencies_str.c_str());
149   return actual.term == term &&
150          actual.section_ids_mask == expected.section_mask &&
151          check_tf_result.term_frequencies_match;
152 }
153 
154 class ScoredDocumentHitFormatter {
155  public:
operator()156   std::string operator()(const ScoredDocumentHit& scored_document_hit) {
157     return IcingStringUtil::StringPrintf(
158         "(document_id=%d, hit_section_id_mask=%" PRId64 ", score=%.2f)",
159         scored_document_hit.document_id(),
160         scored_document_hit.hit_section_id_mask(), scored_document_hit.score());
161   }
162 };
163 
164 class ScoredDocumentHitEqualComparator {
165  public:
operator()166   bool operator()(const ScoredDocumentHit& lhs,
167                   const ScoredDocumentHit& rhs) const {
168     return lhs.document_id() == rhs.document_id() &&
169            lhs.hit_section_id_mask() == rhs.hit_section_id_mask() &&
170            std::fabs(lhs.score() - rhs.score()) < 1e-6;
171   }
172 };
173 
174 // Used to match a ScoredDocumentHit
175 MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") {
176   ScoredDocumentHitEqualComparator equal_comparator;
177   if (!equal_comparator(arg, expected_scored_document_hit)) {
178     ScoredDocumentHitFormatter formatter;
179     *result_listener << "Expected: " << formatter(expected_scored_document_hit)
180                      << ". Actual: " << formatter(arg);
181     return false;
182   }
183   return true;
184 }
185 
186 // Used to match a JoinedScoredDocumentHit
187 MATCHER_P(EqualsJoinedScoredDocumentHit, expected_joined_scored_document_hit,
188           "") {
189   ScoredDocumentHitEqualComparator equal_comparator;
190   if (std::fabs(arg.final_score() -
191                 expected_joined_scored_document_hit.final_score()) > 1e-6 ||
192       !equal_comparator(
193           arg.parent_scored_document_hit(),
194           expected_joined_scored_document_hit.parent_scored_document_hit()) ||
195       arg.child_scored_document_hits().size() !=
196           expected_joined_scored_document_hit.child_scored_document_hits()
197               .size() ||
198       !std::equal(
199           arg.child_scored_document_hits().cbegin(),
200           arg.child_scored_document_hits().cend(),
201           expected_joined_scored_document_hit.child_scored_document_hits()
202               .cbegin(),
203           equal_comparator)) {
204     ScoredDocumentHitFormatter formatter;
205 
206     *result_listener << IcingStringUtil::StringPrintf(
207         "Expected: final_score=%.2f, parent_scored_document_hit=%s, "
208         "child_scored_document_hits=[%s]. Actual: final_score=%.2f, "
209         "parent_scored_document_hit=%s, child_scored_document_hits=[%s]",
210         expected_joined_scored_document_hit.final_score(),
211         formatter(
212             expected_joined_scored_document_hit.parent_scored_document_hit())
213             .c_str(),
214         absl_ports::StrJoin(
215             expected_joined_scored_document_hit.child_scored_document_hits(),
216             ",", formatter)
217             .c_str(),
218         arg.final_score(), formatter(arg.parent_scored_document_hit()).c_str(),
219         absl_ports::StrJoin(arg.child_scored_document_hits(), ",", formatter)
220             .c_str());
221     return false;
222   }
223   return true;
224 }
225 
226 MATCHER_P(EqualsSetSchemaResult, expected, "") {
227   const SchemaStore::SetSchemaResult& actual = arg;
228 
229   if (actual.success == expected.success &&
230       actual.old_schema_type_ids_changed ==
231           expected.old_schema_type_ids_changed &&
232       actual.schema_types_deleted_by_name ==
233           expected.schema_types_deleted_by_name &&
234       actual.schema_types_deleted_by_id ==
235           expected.schema_types_deleted_by_id &&
236       actual.schema_types_incompatible_by_name ==
237           expected.schema_types_incompatible_by_name &&
238       actual.schema_types_incompatible_by_id ==
239           expected.schema_types_incompatible_by_id &&
240       actual.schema_types_new_by_name == expected.schema_types_new_by_name &&
241       actual.schema_types_changed_fully_compatible_by_name ==
242           expected.schema_types_changed_fully_compatible_by_name &&
243       actual.schema_types_index_incompatible_by_name ==
244           expected.schema_types_index_incompatible_by_name) {
245     return true;
246   }
247 
248   // Format schema_type_ids_changed
249   std::string actual_old_schema_type_ids_changed = absl_ports::StrCat(
250       "[",
251       absl_ports::StrJoin(actual.old_schema_type_ids_changed, ",",
252                           absl_ports::NumberFormatter()),
253       "]");
254 
255   std::string expected_old_schema_type_ids_changed = absl_ports::StrCat(
256       "[",
257       absl_ports::StrJoin(expected.old_schema_type_ids_changed, ",",
258                           absl_ports::NumberFormatter()),
259       "]");
260 
261   // Format schema_types_deleted_by_name
262   std::string actual_schema_types_deleted_by_name = absl_ports::StrCat(
263       "[", absl_ports::StrJoin(actual.schema_types_deleted_by_name, ","), "]");
264 
265   std::string expected_schema_types_deleted_by_name = absl_ports::StrCat(
266       "[", absl_ports::StrJoin(expected.schema_types_deleted_by_name, ","),
267       "]");
268 
269   // Format schema_types_deleted_by_id
270   std::string actual_schema_types_deleted_by_id = absl_ports::StrCat(
271       "[",
272       absl_ports::StrJoin(actual.schema_types_deleted_by_id, ",",
273                           absl_ports::NumberFormatter()),
274       "]");
275 
276   std::string expected_schema_types_deleted_by_id = absl_ports::StrCat(
277       "[",
278       absl_ports::StrJoin(expected.schema_types_deleted_by_id, ",",
279                           absl_ports::NumberFormatter()),
280       "]");
281 
282   // Format schema_types_incompatible_by_name
283   std::string actual_schema_types_incompatible_by_name = absl_ports::StrCat(
284       "[", absl_ports::StrJoin(actual.schema_types_incompatible_by_name, ","),
285       "]");
286 
287   std::string expected_schema_types_incompatible_by_name = absl_ports::StrCat(
288       "[", absl_ports::StrJoin(expected.schema_types_incompatible_by_name, ","),
289       "]");
290 
291   // Format schema_types_incompatible_by_id
292   std::string actual_schema_types_incompatible_by_id = absl_ports::StrCat(
293       "[",
294       absl_ports::StrJoin(actual.schema_types_incompatible_by_id, ",",
295                           absl_ports::NumberFormatter()),
296       "]");
297 
298   std::string expected_schema_types_incompatible_by_id = absl_ports::StrCat(
299       "[",
300       absl_ports::StrJoin(expected.schema_types_incompatible_by_id, ",",
301                           absl_ports::NumberFormatter()),
302       "]");
303 
304   // Format schema_types_new_by_name
305   std::string actual_schema_types_new_by_name = absl_ports::StrCat(
306       "[", absl_ports::StrJoin(actual.schema_types_new_by_name, ","), "]");
307 
308   std::string expected_schema_types_new_by_name = absl_ports::StrCat(
309       "[", absl_ports::StrJoin(expected.schema_types_new_by_name, ","), "]");
310 
311   // Format schema_types_changed_fully_compatible_by_name
312   std::string actual_schema_types_changed_fully_compatible_by_name =
313       absl_ports::StrCat(
314           "[",
315           absl_ports::StrJoin(
316               actual.schema_types_changed_fully_compatible_by_name, ","),
317           "]");
318 
319   std::string expected_schema_types_changed_fully_compatible_by_name =
320       absl_ports::StrCat(
321           "[",
322           absl_ports::StrJoin(
323               expected.schema_types_changed_fully_compatible_by_name, ","),
324           "]");
325 
326   // Format schema_types_deleted_by_id
327   std::string actual_schema_types_index_incompatible_by_name =
328       absl_ports::StrCat(
329           "[",
330           absl_ports::StrJoin(actual.schema_types_index_incompatible_by_name,
331                               ","),
332           "]");
333 
334   std::string expected_schema_types_index_incompatible_by_name =
335       absl_ports::StrCat(
336           "[",
337           absl_ports::StrJoin(expected.schema_types_index_incompatible_by_name,
338                               ","),
339           "]");
340 
341   *result_listener << IcingStringUtil::StringPrintf(
342       "\nExpected {\n"
343       "\tsuccess=%d,\n"
344       "\told_schema_type_ids_changed=%s,\n"
345       "\tschema_types_deleted_by_name=%s,\n"
346       "\tschema_types_deleted_by_id=%s,\n"
347       "\tschema_types_incompatible_by_name=%s,\n"
348       "\tschema_types_incompatible_by_id=%s\n"
349       "\tschema_types_new_by_name=%s,\n"
350       "\tschema_types_index_incompatible_by_name=%s,\n"
351       "\tschema_types_changed_fully_compatible_by_name=%s\n"
352       "}\n"
353       "Actual {\n"
354       "\tsuccess=%d,\n"
355       "\told_schema_type_ids_changed=%s,\n"
356       "\tschema_types_deleted_by_name=%s,\n"
357       "\tschema_types_deleted_by_id=%s,\n"
358       "\tschema_types_incompatible_by_name=%s,\n"
359       "\tschema_types_incompatible_by_id=%s\n"
360       "\tschema_types_new_by_name=%s,\n"
361       "\tschema_types_index_incompatible_by_name=%s,\n"
362       "\tschema_types_changed_fully_compatible_by_name=%s\n"
363       "}\n",
364       expected.success, expected_old_schema_type_ids_changed.c_str(),
365       expected_schema_types_deleted_by_name.c_str(),
366       expected_schema_types_deleted_by_id.c_str(),
367       expected_schema_types_incompatible_by_name.c_str(),
368       expected_schema_types_incompatible_by_id.c_str(),
369       expected_schema_types_new_by_name.c_str(),
370       expected_schema_types_changed_fully_compatible_by_name.c_str(),
371       expected_schema_types_index_incompatible_by_name.c_str(), actual.success,
372       actual_old_schema_type_ids_changed.c_str(),
373       actual_schema_types_deleted_by_name.c_str(),
374       actual_schema_types_deleted_by_id.c_str(),
375       actual_schema_types_incompatible_by_name.c_str(),
376       actual_schema_types_incompatible_by_id.c_str(),
377       actual_schema_types_new_by_name.c_str(),
378       actual_schema_types_changed_fully_compatible_by_name.c_str(),
379       actual_schema_types_index_incompatible_by_name.c_str());
380   return false;
381 }
382 
383 MATCHER_P3(EqualsSectionMetadata, expected_id, expected_property_path,
384            expected_property_config_proto, "") {
385   const SectionMetadata& actual = arg;
386   return actual.id == expected_id && actual.path == expected_property_path &&
387          actual.data_type == expected_property_config_proto.data_type() &&
388          actual.tokenizer ==
389              expected_property_config_proto.string_indexing_config()
390                  .tokenizer_type() &&
391          actual.term_match_type ==
392              expected_property_config_proto.string_indexing_config()
393                  .term_match_type() &&
394          actual.numeric_match_type ==
395              expected_property_config_proto.integer_indexing_config()
396                  .numeric_match_type();
397 }
398 
399 MATCHER_P3(EqualsJoinablePropertyMetadata, expected_id, expected_property_path,
400            expected_property_config_proto, "") {
401   const JoinablePropertyMetadata& actual = arg;
402   return actual.id == expected_id && actual.path == expected_property_path &&
403          actual.data_type == expected_property_config_proto.data_type() &&
404          actual.value_type ==
405              expected_property_config_proto.joinable_config().value_type();
406 }
407 
408 std::string StatusCodeToString(libtextclassifier3::StatusCode code);
409 
410 std::string ProtoStatusCodeToString(StatusProto::Code code);
411 
412 MATCHER(IsOk, "") {
413   libtextclassifier3::StatusAdapter adapter(arg);
414   if (adapter.status().ok()) {
415     return true;
416   }
417   *result_listener << IcingStringUtil::StringPrintf(
418       "Expected OK, actual was (%s:%s)",
419       StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
420       adapter.status().error_message().c_str());
421   return false;
422 }
423 
424 MATCHER_P(IsOkAndHolds, matcher, "") {
425   if (!arg.ok()) {
426     *result_listener << IcingStringUtil::StringPrintf(
427         "Expected OK, actual was (%s:%s)",
428         StatusCodeToString(arg.status().CanonicalCode()).c_str(),
429         arg.status().error_message().c_str());
430     return false;
431   }
432   return ExplainMatchResult(matcher, arg.ValueOrDie(), result_listener);
433 }
434 
435 MATCHER_P(StatusIs, status_code, "") {
436   libtextclassifier3::StatusAdapter adapter(arg);
437   if (adapter.status().CanonicalCode() == status_code) {
438     return true;
439   }
440   *result_listener << IcingStringUtil::StringPrintf(
441       "Expected (%s:), actual was (%s:%s)",
442       StatusCodeToString(status_code).c_str(),
443       StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
444       adapter.status().error_message().c_str());
445   return false;
446 }
447 
448 MATCHER_P2(StatusIs, status_code, error_matcher, "") {
449   libtextclassifier3::StatusAdapter adapter(arg);
450   if (adapter.status().CanonicalCode() != status_code) {
451     *result_listener << IcingStringUtil::StringPrintf(
452         "Expected (%s:), actual was (%s:%s)",
453         StatusCodeToString(status_code).c_str(),
454         StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
455         adapter.status().error_message().c_str());
456     return false;
457   }
458   return ExplainMatchResult(error_matcher, adapter.status().error_message(),
459                             result_listener);
460 }
461 
462 MATCHER(ProtoIsOk, "") {
463   if (arg.code() == StatusProto::OK) {
464     return true;
465   }
466   *result_listener << IcingStringUtil::StringPrintf(
467       "Expected OK, actual was (%s:%s)",
468       ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
469   return false;
470 }
471 
472 MATCHER_P(ProtoStatusIs, status_code, "") {
473   if (arg.code() == status_code) {
474     return true;
475   }
476   *result_listener << IcingStringUtil::StringPrintf(
477       "Expected (%s:), actual was (%s:%s)",
478       ProtoStatusCodeToString(status_code).c_str(),
479       ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
480   return false;
481 }
482 
483 MATCHER_P2(ProtoStatusIs, status_code, error_matcher, "") {
484   if (arg.code() != status_code) {
485     *result_listener << IcingStringUtil::StringPrintf(
486         "Expected (%s:), actual was (%s:%s)",
487         ProtoStatusCodeToString(status_code).c_str(),
488         ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
489     return false;
490   }
491   return ExplainMatchResult(error_matcher, arg.message(), result_listener);
492 }
493 
494 MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
495   SearchResultProto actual_copy = arg;
496   actual_copy.clear_query_stats();
497   actual_copy.clear_debug_info();
498   for (SearchResultProto::ResultProto& result :
499        *actual_copy.mutable_results()) {
500     // Joined results
501     for (SearchResultProto::ResultProto& joined_result :
502          *result.mutable_joined_results()) {
503       joined_result.clear_score();
504     }
505     result.clear_score();
506   }
507 
508   SearchResultProto expected_copy = expected;
509   expected_copy.clear_query_stats();
510   expected_copy.clear_debug_info();
511   for (SearchResultProto::ResultProto& result :
512        *expected_copy.mutable_results()) {
513     // Joined results
514     for (SearchResultProto::ResultProto& joined_result :
515          *result.mutable_joined_results()) {
516       joined_result.clear_score();
517     }
518     result.clear_score();
519   }
520   return ExplainMatchResult(portable_equals_proto::EqualsProto(expected_copy),
521                             actual_copy, result_listener);
522 }
523 
// TODO(tjbarron) Remove this once icing has switched to depend on TC3 Status
// Token-pasting helpers; the extra level of indirection makes sure that
// arguments such as __COUNTER__ are expanded before they are concatenated.
#define ICING_STATUS_MACROS_CONCAT_NAME(x, y) \
  ICING_STATUS_MACROS_CONCAT_IMPL(x, y)
#define ICING_STATUS_MACROS_CONCAT_IMPL(x, y) x##y

// Expects/asserts that `func` satisfies the IsOk() matcher.
#define ICING_EXPECT_OK(func) EXPECT_THAT(func, IsOk())
#define ICING_ASSERT_OK(func) ASSERT_THAT(func, IsOk())

// Evaluates `rexpr` exactly once into a uniquely named temporary, asserts that
// its status is OK and then move-assigns the held value to `lhs`.
#define ICING_ASSERT_OK_AND_ASSIGN(lhs, rexpr)                             \
  ICING_ASSERT_OK_AND_ASSIGN_IMPL(                                         \
      ICING_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
      rexpr)
#define ICING_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
  auto statusor = (rexpr);                                    \
  ICING_ASSERT_OK(statusor.status());                         \
  lhs = std::move(statusor).ValueOrDie()

// Evaluates `rexpr` exactly once into a uniquely named temporary, asserts that
// it converts to true and then move-assigns its value() to `lhs`. Going through
// a temporary keeps an expression with side effects (e.g. a function call)
// from running twice.
#define ICING_ASSERT_HAS_VALUE_AND_ASSIGN(lhs, rexpr)                     \
  ICING_ASSERT_HAS_VALUE_AND_ASSIGN_IMPL(                                 \
      ICING_STATUS_MACROS_CONCAT_NAME(_optional_value, __COUNTER__), lhs, \
      rexpr)
#define ICING_ASSERT_HAS_VALUE_AND_ASSIGN_IMPL(optional, lhs, rexpr) \
  auto optional = (rexpr);                                           \
  ASSERT_TRUE(optional);                                             \
  lhs = std::move(optional).value()
543 
544 }  // namespace lib
545 }  // namespace icing
546 
547 #endif  // ICING_TESTING_COMMON_MATCHERS_H_
548