• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
17 
18 #include <cinttypes>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "icing/text_classifier/lib3/utils/base/status.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/absl_ports/str_cat.h"
26 #include "icing/index/hit/doc-hit-info.h"
27 #include "icing/index/iterator/doc-hit-info-iterator.h"
28 #include "icing/legacy/core/icing-string-util.h"
29 #include "icing/schema/section.h"
30 #include "icing/store/document-id.h"
31 
32 namespace icing {
33 namespace lib {
34 
35 class DocHitInfoTermFrequencyPair {
36  public:
37   DocHitInfoTermFrequencyPair(
38       const DocHitInfo& doc_hit_info,
39       const Hit::TermFrequencyArray& hit_term_frequency = {})
doc_hit_info_(doc_hit_info)40       : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {}
41 
UpdateSection(SectionId section_id,Hit::TermFrequency hit_term_frequency)42   void UpdateSection(SectionId section_id,
43                      Hit::TermFrequency hit_term_frequency) {
44     doc_hit_info_.UpdateSection(section_id);
45     hit_term_frequency_[section_id] = hit_term_frequency;
46   }
47 
MergeSectionsFrom(const DocHitInfoTermFrequencyPair & other)48   void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) {
49     SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask();
50     doc_hit_info_.MergeSectionsFrom(other_mask);
51     while (other_mask) {
52       SectionId section_id = __builtin_ctzll(other_mask);
53       hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id];
54       other_mask &= ~(UINT64_C(1) << section_id);
55     }
56   }
57 
doc_hit_info()58   DocHitInfo doc_hit_info() const { return doc_hit_info_; }
59 
hit_term_frequency(SectionId section_id)60   Hit::TermFrequency hit_term_frequency(SectionId section_id) const {
61     return hit_term_frequency_[section_id];
62   }
63 
64  private:
65   DocHitInfo doc_hit_info_;
66   Hit::TermFrequencyArray hit_term_frequency_;
67 };
68 
69 // Dummy class to help with testing. It starts with an kInvalidDocumentId doc
70 // hit info until an Advance is called (like normal DocHitInfoIterators). It
71 // will then proceed to return the doc_hit_infos in order as Advance's are
72 // called. After all doc_hit_infos are returned, Advance will return a NotFound
73 // error (also like normal DocHitInfoIterators).
74 class DocHitInfoIteratorDummy : public DocHitInfoIterator {
75  public:
76   DocHitInfoIteratorDummy() = default;
77   explicit DocHitInfoIteratorDummy(
78       std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos,
79       std::string term = "")
doc_hit_infos_(std::move (doc_hit_infos))80       : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {}
81 
82   explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos,
83                                    std::string term = "",
84                                    int term_start_index = 0,
85                                    int unnormalized_term_length = 0)
term_(std::move (term))86       : term_(std::move(term)),
87         term_start_index_(term_start_index),
88         unnormalized_term_length_(unnormalized_term_length) {
89     for (auto& doc_hit_info : doc_hit_infos) {
90       doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info));
91     }
92   }
93 
Advance()94   libtextclassifier3::Status Advance() override {
95     ++index_;
96     if (index_ < doc_hit_infos_.size()) {
97       doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info();
98       return libtextclassifier3::Status::OK;
99     }
100 
101     return absl_ports::ResourceExhaustedError(
102         "No more DocHitInfos in iterator");
103   }
104 
TrimRightMostNode()105   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
106     DocHitInfoIterator::TrimmedNode node = {nullptr, term_, term_start_index_,
107                                             unnormalized_term_length_};
108     return node;
109   }
110 
111   // Imitates behavior of DocHitInfoIteratorTermMain/DocHitInfoIteratorTermLite
112   void PopulateMatchedTermsStats(
113       std::vector<TermMatchInfo>* matched_terms_stats,
114       SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
115     if (index_ == -1 || index_ >= doc_hit_infos_.size()) {
116       // Current hit isn't valid, return.
117       return;
118     }
119     SectionIdMask section_mask =
120         doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
121     SectionIdMask section_mask_copy = section_mask;
122     std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
123         {Hit::kNoTermFrequency};
124     while (section_mask_copy) {
125       SectionId section_id = __builtin_ctzll(section_mask_copy);
126       section_term_frequencies.at(section_id) =
127           doc_hit_infos_.at(index_).hit_term_frequency(section_id);
128       section_mask_copy &= ~(UINT64_C(1) << section_id);
129     }
130     TermMatchInfo term_stats(term_, section_mask,
131                              std::move(section_term_frequencies));
132 
133     for (auto& cur_term_stats : *matched_terms_stats) {
134       if (cur_term_stats.term == term_stats.term) {
135         // Same docId and same term, we don't need to add the term and the term
136         // frequency should always be the same
137         return;
138       }
139     }
140     matched_terms_stats->push_back(term_stats);
141   }
142 
set_hit_intersect_section_ids_mask(SectionIdMask hit_intersect_section_ids_mask)143   void set_hit_intersect_section_ids_mask(
144       SectionIdMask hit_intersect_section_ids_mask) {
145     hit_intersect_section_ids_mask_ = hit_intersect_section_ids_mask;
146   }
147 
GetNumBlocksInspected()148   int32_t GetNumBlocksInspected() const override {
149     return num_blocks_inspected_;
150   }
151 
SetNumBlocksInspected(int32_t num_blocks_inspected)152   void SetNumBlocksInspected(int32_t num_blocks_inspected) {
153     num_blocks_inspected_ = num_blocks_inspected;
154   }
155 
GetNumLeafAdvanceCalls()156   int32_t GetNumLeafAdvanceCalls() const override {
157     return num_leaf_advance_calls_;
158   }
159 
SetNumLeafAdvanceCalls(int32_t num_leaf_advance_calls)160   void SetNumLeafAdvanceCalls(int32_t num_leaf_advance_calls) {
161     num_leaf_advance_calls_ = num_leaf_advance_calls;
162   }
163 
ToString()164   std::string ToString() const override {
165     std::string ret = "<";
166     for (auto& doc_hit_info_pair : doc_hit_infos_) {
167       absl_ports::StrAppend(
168           &ret, IcingStringUtil::StringPrintf(
169                     "[%d,%" PRIu64 "]",
170                     doc_hit_info_pair.doc_hit_info().document_id(),
171                     doc_hit_info_pair.doc_hit_info().hit_section_ids_mask()));
172     }
173     absl_ports::StrAppend(&ret, ">");
174     return ret;
175   }
176 
177  private:
178   int32_t index_ = -1;
179   int32_t num_blocks_inspected_ = 0;
180   int32_t num_leaf_advance_calls_ = 0;
181   std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_;
182   std::string term_;
183   int term_start_index_;
184   int unnormalized_term_length_;
185 };
186 
GetDocumentIds(DocHitInfoIterator * iterator)187 inline std::vector<DocumentId> GetDocumentIds(DocHitInfoIterator* iterator) {
188   std::vector<DocumentId> ids;
189   while (iterator->Advance().ok()) {
190     ids.push_back(iterator->doc_hit_info().document_id());
191   }
192   return ids;
193 }
194 
GetDocHitInfos(DocHitInfoIterator * iterator)195 inline std::vector<DocHitInfo> GetDocHitInfos(DocHitInfoIterator* iterator) {
196   std::vector<DocHitInfo> doc_hit_infos;
197   while (iterator->Advance().ok()) {
198     doc_hit_infos.push_back(iterator->doc_hit_info());
199   }
200   return doc_hit_infos;
201 }
202 
203 }  // namespace lib
204 }  // namespace icing
205 
206 #endif  // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
207