1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/index/embed/doc-hit-info-iterator-embedding.h"
16
17 #include <cstdint>
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "icing/text_classifier/lib3/utils/base/status.h"
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/index/embed/embedding-index.h"
27 #include "icing/index/embed/embedding-query-results.h"
28 #include "icing/index/embed/embedding-scorer.h"
29 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/hit/hit.h"
32 #include "icing/proto/search.pb.h"
33 #include "icing/schema/schema-store.h"
34 #include "icing/schema/section.h"
35 #include "icing/store/document-filter-data.h"
36 #include "icing/store/document-id.h"
37 #include "icing/store/document-store.h"
38 #include "icing/util/status-macros.h"
39
40 namespace icing {
41 namespace lib {
42
43 libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>>
Create(const PropertyProto::VectorProto * query,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,double score_low,double score_high,bool get_embedding_match_info,EmbeddingQueryResults::EmbeddingQueryMatchInfoMap * info_map,const EmbeddingIndex * embedding_index,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms)44 DocHitInfoIteratorEmbedding::Create(
45 const PropertyProto::VectorProto* query,
46 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
47 double score_low, double score_high, bool get_embedding_match_info,
48 EmbeddingQueryResults::EmbeddingQueryMatchInfoMap* info_map,
49 const EmbeddingIndex* embedding_index, const DocumentStore* document_store,
50 const SchemaStore* schema_store, int64_t current_time_ms) {
51 ICING_RETURN_ERROR_IF_NULL(query);
52 ICING_RETURN_ERROR_IF_NULL(embedding_index);
53 ICING_RETURN_ERROR_IF_NULL(info_map);
54 ICING_RETURN_ERROR_IF_NULL(document_store);
55 ICING_RETURN_ERROR_IF_NULL(schema_store);
56
57 libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
58 pl_accessor_or = embedding_index->GetAccessorForVector(*query);
59 std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
60 if (pl_accessor_or.ok()) {
61 pl_accessor = std::move(pl_accessor_or).ValueOrDie();
62 } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
63 // A not-found error should be fine, since that means there is no matching
64 // embedding hits in the index.
65 pl_accessor = nullptr;
66 } else {
67 // Otherwise, return the error as is.
68 return pl_accessor_or.status();
69 }
70
71 ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer,
72 EmbeddingScorer::Create(metric_type));
73
74 return std::unique_ptr<DocHitInfoIteratorEmbedding>(
75 new DocHitInfoIteratorEmbedding(
76 query, metric_type, std::move(embedding_scorer), score_low,
77 score_high, get_embedding_match_info, info_map, embedding_index,
78 std::move(pl_accessor), document_store, schema_store,
79 current_time_ms));
80 }
81
82 libtextclassifier3::StatusOr<const EmbeddingHit*>
AdvanceToNextEmbeddingHit()83 DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() {
84 if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) {
85 ICING_ASSIGN_OR_RETURN(cached_embedding_hits_,
86 posting_list_accessor_->GetNextHitsBatch());
87 cached_embedding_hits_idx_ = 0;
88 if (cached_embedding_hits_.empty()) {
89 no_more_hit_ = true;
90 return nullptr;
91 }
92 }
93 const EmbeddingHit& embedding_hit =
94 cached_embedding_hits_[cached_embedding_hits_idx_];
95 if (doc_hit_info_.document_id() == kInvalidDocumentId) {
96 doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id());
97 current_allowed_sections_mask_ =
98 ComputeAllowedSectionsMask(doc_hit_info_.document_id());
99
100 schema_type_id_ = document_store_.GetSchemaTypeId(
101 doc_hit_info_.document_id(), current_time_ms_);
102 if (schema_type_id_ == kInvalidSchemaTypeId) {
103 // This means that the document is deleted or expired, so update
104 // current_allowed_sections_mask_ to skip the document.
105 current_allowed_sections_mask_ = kSectionIdMaskNone;
106 }
107 } else if (doc_hit_info_.document_id() !=
108 embedding_hit.basic_hit().document_id()) {
109 return nullptr;
110 }
111 ++cached_embedding_hits_idx_;
112 return &embedding_hit;
113 }
114
115 libtextclassifier3::Status
AdvanceToNextUnfilteredDocument()116 DocHitInfoIteratorEmbedding::AdvanceToNextUnfilteredDocument() {
117 if (no_more_hit_ || posting_list_accessor_ == nullptr) {
118 return absl_ports::ResourceExhaustedError(
119 "No more DocHitInfos in iterator");
120 }
121
122 doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone);
123 schema_type_id_ = kInvalidSchemaTypeId;
124 EmbeddingMatchInfos* matched_infos = nullptr;
125 current_allowed_sections_mask_ = kSectionIdMaskAll;
126 SectionId current_section_id = kInvalidSectionId;
127 EmbeddingIndexingConfig::QuantizationType::Code quantization_type =
128 EmbeddingIndexingConfig::QuantizationType::NONE;
129 int current_section_match_count = 0;
130
131 while (true) {
132 ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit,
133 AdvanceToNextEmbeddingHit());
134 if (embedding_hit == nullptr) {
135 // No more hits for the current document.
136 break;
137 }
138
139 // Filter out the embedding hit according to the section restriction.
140 if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) &
141 current_allowed_sections_mask_) == 0) {
142 continue;
143 }
144
145 // We've reached a new section. Reset the match count and retrieve the
146 // quantization type for the new section.
147 if (current_section_id != embedding_hit->basic_hit().section_id()) {
148 current_section_match_count = 0;
149 current_section_id = embedding_hit->basic_hit().section_id();
150 // The schema type id is guaranteed to be valid here. Otherwise,
151 // current_allowed_sections_mask_ should be assigned to kSectionIdMaskNone
152 // by AdvanceToNextEmbeddingHit, and the embedding hit should have been
153 // skipped above.
154 ICING_ASSIGN_OR_RETURN(
155 quantization_type,
156 schema_store_.GetQuantizationType(
157 schema_type_id_, current_section_id));
158 }
159
160 // Calculate the semantic score.
161 ICING_ASSIGN_OR_RETURN(
162 float semantic_score,
163 embedding_index_.ScoreEmbeddingHit(*embedding_scorer_, query_,
164 *embedding_hit, quantization_type));
165
166 // If the semantic score is within the desired score range, update
167 // doc_hit_info_ and info_map_.
168 if (score_low_ <= semantic_score && semantic_score <= score_high_) {
169 doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id());
170 if (matched_infos == nullptr) {
171 matched_infos = &(info_map_[doc_hit_info_.document_id()]);
172 }
173 matched_infos->AppendScore(semantic_score);
174 if (get_embedding_match_info_) {
175 // Add the section info for this embedding match.
176 matched_infos->AppendSectionInfo(current_section_id,
177 current_section_match_count);
178 }
179 }
180 ++current_section_match_count;
181 }
182
183 if (doc_hit_info_.document_id() == kInvalidDocumentId) {
184 return absl_ports::ResourceExhaustedError(
185 "No more DocHitInfos in iterator");
186 }
187 return libtextclassifier3::Status::OK;
188 }
189
Advance()190 libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() {
191 do {
192 ICING_RETURN_IF_ERROR(AdvanceToNextUnfilteredDocument());
193 } while (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone);
194 ++num_advance_calls_;
195 return libtextclassifier3::Status::OK;
196 }
197
198 } // namespace lib
199 } // namespace icing
200