1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
17
#include <array>
#include <cinttypes>
#include <string>
#include <utility>
#include <vector>
22
23 #include "icing/text_classifier/lib3/utils/base/status.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/absl_ports/str_cat.h"
26 #include "icing/index/hit/doc-hit-info.h"
27 #include "icing/index/iterator/doc-hit-info-iterator.h"
28 #include "icing/legacy/core/icing-string-util.h"
29 #include "icing/schema/section.h"
30 #include "icing/store/document-id.h"
31
32 namespace icing {
33 namespace lib {
34
35 class DocHitInfoTermFrequencyPair {
36 public:
37 DocHitInfoTermFrequencyPair(
38 const DocHitInfo& doc_hit_info,
39 const Hit::TermFrequencyArray& hit_term_frequency = {})
doc_hit_info_(doc_hit_info)40 : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {}
41
UpdateSection(SectionId section_id,Hit::TermFrequency hit_term_frequency)42 void UpdateSection(SectionId section_id,
43 Hit::TermFrequency hit_term_frequency) {
44 doc_hit_info_.UpdateSection(section_id);
45 hit_term_frequency_[section_id] = hit_term_frequency;
46 }
47
MergeSectionsFrom(const DocHitInfoTermFrequencyPair & other)48 void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) {
49 SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask();
50 doc_hit_info_.MergeSectionsFrom(other_mask);
51 while (other_mask) {
52 SectionId section_id = __builtin_ctzll(other_mask);
53 hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id];
54 other_mask &= ~(UINT64_C(1) << section_id);
55 }
56 }
57
doc_hit_info()58 DocHitInfo doc_hit_info() const { return doc_hit_info_; }
59
hit_term_frequency(SectionId section_id)60 Hit::TermFrequency hit_term_frequency(SectionId section_id) const {
61 return hit_term_frequency_[section_id];
62 }
63
64 private:
65 DocHitInfo doc_hit_info_;
66 Hit::TermFrequencyArray hit_term_frequency_;
67 };
68
69 // Dummy class to help with testing. It starts with an kInvalidDocumentId doc
70 // hit info until an Advance is called (like normal DocHitInfoIterators). It
71 // will then proceed to return the doc_hit_infos in order as Advance's are
72 // called. After all doc_hit_infos are returned, Advance will return a NotFound
73 // error (also like normal DocHitInfoIterators).
74 class DocHitInfoIteratorDummy : public DocHitInfoIterator {
75 public:
76 DocHitInfoIteratorDummy() = default;
77 explicit DocHitInfoIteratorDummy(
78 std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos,
79 std::string term = "")
doc_hit_infos_(std::move (doc_hit_infos))80 : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {}
81
82 explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos,
83 std::string term = "",
84 int term_start_index = 0,
85 int unnormalized_term_length = 0)
term_(std::move (term))86 : term_(std::move(term)),
87 term_start_index_(term_start_index),
88 unnormalized_term_length_(unnormalized_term_length) {
89 for (auto& doc_hit_info : doc_hit_infos) {
90 doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info));
91 }
92 }
93
Advance()94 libtextclassifier3::Status Advance() override {
95 ++index_;
96 if (index_ < doc_hit_infos_.size()) {
97 doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info();
98 return libtextclassifier3::Status::OK;
99 }
100
101 return absl_ports::ResourceExhaustedError(
102 "No more DocHitInfos in iterator");
103 }
104
TrimRightMostNode()105 libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
106 DocHitInfoIterator::TrimmedNode node = {nullptr, term_, term_start_index_,
107 unnormalized_term_length_};
108 return node;
109 }
110
111 // Imitates behavior of DocHitInfoIteratorTermMain/DocHitInfoIteratorTermLite
112 void PopulateMatchedTermsStats(
113 std::vector<TermMatchInfo>* matched_terms_stats,
114 SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
115 if (index_ == -1 || index_ >= doc_hit_infos_.size()) {
116 // Current hit isn't valid, return.
117 return;
118 }
119 SectionIdMask section_mask =
120 doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
121 SectionIdMask section_mask_copy = section_mask;
122 std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
123 {Hit::kNoTermFrequency};
124 while (section_mask_copy) {
125 SectionId section_id = __builtin_ctzll(section_mask_copy);
126 section_term_frequencies.at(section_id) =
127 doc_hit_infos_.at(index_).hit_term_frequency(section_id);
128 section_mask_copy &= ~(UINT64_C(1) << section_id);
129 }
130 TermMatchInfo term_stats(term_, section_mask,
131 std::move(section_term_frequencies));
132
133 for (auto& cur_term_stats : *matched_terms_stats) {
134 if (cur_term_stats.term == term_stats.term) {
135 // Same docId and same term, we don't need to add the term and the term
136 // frequency should always be the same
137 return;
138 }
139 }
140 matched_terms_stats->push_back(term_stats);
141 }
142
set_hit_intersect_section_ids_mask(SectionIdMask hit_intersect_section_ids_mask)143 void set_hit_intersect_section_ids_mask(
144 SectionIdMask hit_intersect_section_ids_mask) {
145 hit_intersect_section_ids_mask_ = hit_intersect_section_ids_mask;
146 }
147
GetNumBlocksInspected()148 int32_t GetNumBlocksInspected() const override {
149 return num_blocks_inspected_;
150 }
151
SetNumBlocksInspected(int32_t num_blocks_inspected)152 void SetNumBlocksInspected(int32_t num_blocks_inspected) {
153 num_blocks_inspected_ = num_blocks_inspected;
154 }
155
GetNumLeafAdvanceCalls()156 int32_t GetNumLeafAdvanceCalls() const override {
157 return num_leaf_advance_calls_;
158 }
159
SetNumLeafAdvanceCalls(int32_t num_leaf_advance_calls)160 void SetNumLeafAdvanceCalls(int32_t num_leaf_advance_calls) {
161 num_leaf_advance_calls_ = num_leaf_advance_calls;
162 }
163
ToString()164 std::string ToString() const override {
165 std::string ret = "<";
166 for (auto& doc_hit_info_pair : doc_hit_infos_) {
167 absl_ports::StrAppend(
168 &ret, IcingStringUtil::StringPrintf(
169 "[%d,%" PRIu64 "]",
170 doc_hit_info_pair.doc_hit_info().document_id(),
171 doc_hit_info_pair.doc_hit_info().hit_section_ids_mask()));
172 }
173 absl_ports::StrAppend(&ret, ">");
174 return ret;
175 }
176
177 private:
178 int32_t index_ = -1;
179 int32_t num_blocks_inspected_ = 0;
180 int32_t num_leaf_advance_calls_ = 0;
181 std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_;
182 std::string term_;
183 int term_start_index_;
184 int unnormalized_term_length_;
185 };
186
GetDocumentIds(DocHitInfoIterator * iterator)187 inline std::vector<DocumentId> GetDocumentIds(DocHitInfoIterator* iterator) {
188 std::vector<DocumentId> ids;
189 while (iterator->Advance().ok()) {
190 ids.push_back(iterator->doc_hit_info().document_id());
191 }
192 return ids;
193 }
194
GetDocHitInfos(DocHitInfoIterator * iterator)195 inline std::vector<DocHitInfo> GetDocHitInfos(DocHitInfoIterator* iterator) {
196 std::vector<DocHitInfo> doc_hit_infos;
197 while (iterator->Advance().ok()) {
198 doc_hit_infos.push_back(iterator->doc_hit_info());
199 }
200 return doc_hit_infos;
201 }
202
203 } // namespace lib
204 } // namespace icing
205
206 #endif // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
207