• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/index/embed/doc-hit-info-iterator-embedding.h"
16 
17 #include <cstdint>
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "icing/text_classifier/lib3/utils/base/status.h"
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/index/embed/embedding-index.h"
27 #include "icing/index/embed/embedding-query-results.h"
28 #include "icing/index/embed/embedding-scorer.h"
29 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/hit/hit.h"
32 #include "icing/proto/search.pb.h"
33 #include "icing/schema/schema-store.h"
34 #include "icing/schema/section.h"
35 #include "icing/store/document-filter-data.h"
36 #include "icing/store/document-id.h"
37 #include "icing/store/document-store.h"
38 #include "icing/util/status-macros.h"
39 
40 namespace icing {
41 namespace lib {
42 
43 libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>>
Create(const PropertyProto::VectorProto * query,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,double score_low,double score_high,bool get_embedding_match_info,EmbeddingQueryResults::EmbeddingQueryMatchInfoMap * info_map,const EmbeddingIndex * embedding_index,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms)44 DocHitInfoIteratorEmbedding::Create(
45     const PropertyProto::VectorProto* query,
46     SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
47     double score_low, double score_high, bool get_embedding_match_info,
48     EmbeddingQueryResults::EmbeddingQueryMatchInfoMap* info_map,
49     const EmbeddingIndex* embedding_index, const DocumentStore* document_store,
50     const SchemaStore* schema_store, int64_t current_time_ms) {
51   ICING_RETURN_ERROR_IF_NULL(query);
52   ICING_RETURN_ERROR_IF_NULL(embedding_index);
53   ICING_RETURN_ERROR_IF_NULL(info_map);
54   ICING_RETURN_ERROR_IF_NULL(document_store);
55   ICING_RETURN_ERROR_IF_NULL(schema_store);
56 
57   libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
58       pl_accessor_or = embedding_index->GetAccessorForVector(*query);
59   std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
60   if (pl_accessor_or.ok()) {
61     pl_accessor = std::move(pl_accessor_or).ValueOrDie();
62   } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
63     // A not-found error should be fine, since that means there is no matching
64     // embedding hits in the index.
65     pl_accessor = nullptr;
66   } else {
67     // Otherwise, return the error as is.
68     return pl_accessor_or.status();
69   }
70 
71   ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer,
72                          EmbeddingScorer::Create(metric_type));
73 
74   return std::unique_ptr<DocHitInfoIteratorEmbedding>(
75       new DocHitInfoIteratorEmbedding(
76           query, metric_type, std::move(embedding_scorer), score_low,
77           score_high, get_embedding_match_info, info_map, embedding_index,
78           std::move(pl_accessor), document_store, schema_store,
79           current_time_ms));
80 }
81 
82 libtextclassifier3::StatusOr<const EmbeddingHit*>
AdvanceToNextEmbeddingHit()83 DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() {
84   if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) {
85     ICING_ASSIGN_OR_RETURN(cached_embedding_hits_,
86                            posting_list_accessor_->GetNextHitsBatch());
87     cached_embedding_hits_idx_ = 0;
88     if (cached_embedding_hits_.empty()) {
89       no_more_hit_ = true;
90       return nullptr;
91     }
92   }
93   const EmbeddingHit& embedding_hit =
94       cached_embedding_hits_[cached_embedding_hits_idx_];
95   if (doc_hit_info_.document_id() == kInvalidDocumentId) {
96     doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id());
97     current_allowed_sections_mask_ =
98         ComputeAllowedSectionsMask(doc_hit_info_.document_id());
99 
100     schema_type_id_ = document_store_.GetSchemaTypeId(
101         doc_hit_info_.document_id(), current_time_ms_);
102     if (schema_type_id_ == kInvalidSchemaTypeId) {
103       // This means that the document is deleted or expired, so update
104       // current_allowed_sections_mask_ to skip the document.
105       current_allowed_sections_mask_ = kSectionIdMaskNone;
106     }
107   } else if (doc_hit_info_.document_id() !=
108              embedding_hit.basic_hit().document_id()) {
109     return nullptr;
110   }
111   ++cached_embedding_hits_idx_;
112   return &embedding_hit;
113 }
114 
115 libtextclassifier3::Status
AdvanceToNextUnfilteredDocument()116 DocHitInfoIteratorEmbedding::AdvanceToNextUnfilteredDocument() {
117   if (no_more_hit_ || posting_list_accessor_ == nullptr) {
118     return absl_ports::ResourceExhaustedError(
119         "No more DocHitInfos in iterator");
120   }
121 
122   doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone);
123   schema_type_id_ = kInvalidSchemaTypeId;
124   EmbeddingMatchInfos* matched_infos = nullptr;
125   current_allowed_sections_mask_ = kSectionIdMaskAll;
126   SectionId current_section_id = kInvalidSectionId;
127   EmbeddingIndexingConfig::QuantizationType::Code quantization_type =
128       EmbeddingIndexingConfig::QuantizationType::NONE;
129   int current_section_match_count = 0;
130 
131   while (true) {
132     ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit,
133                            AdvanceToNextEmbeddingHit());
134     if (embedding_hit == nullptr) {
135       // No more hits for the current document.
136       break;
137     }
138 
139     // Filter out the embedding hit according to the section restriction.
140     if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) &
141          current_allowed_sections_mask_) == 0) {
142       continue;
143     }
144 
145     // We've reached a new section. Reset the match count and retrieve the
146     // quantization type for the new section.
147     if (current_section_id != embedding_hit->basic_hit().section_id()) {
148       current_section_match_count = 0;
149       current_section_id = embedding_hit->basic_hit().section_id();
150       // The schema type id is guaranteed to be valid here. Otherwise,
151       // current_allowed_sections_mask_ should be assigned to kSectionIdMaskNone
152       // by AdvanceToNextEmbeddingHit, and the embedding hit should have been
153       // skipped above.
154       ICING_ASSIGN_OR_RETURN(
155           quantization_type,
156           schema_store_.GetQuantizationType(
157               schema_type_id_, current_section_id));
158     }
159 
160     // Calculate the semantic score.
161     ICING_ASSIGN_OR_RETURN(
162         float semantic_score,
163         embedding_index_.ScoreEmbeddingHit(*embedding_scorer_, query_,
164                                            *embedding_hit, quantization_type));
165 
166     // If the semantic score is within the desired score range, update
167     // doc_hit_info_ and info_map_.
168     if (score_low_ <= semantic_score && semantic_score <= score_high_) {
169       doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id());
170       if (matched_infos == nullptr) {
171         matched_infos = &(info_map_[doc_hit_info_.document_id()]);
172       }
173       matched_infos->AppendScore(semantic_score);
174       if (get_embedding_match_info_) {
175         // Add the section info for this embedding match.
176         matched_infos->AppendSectionInfo(current_section_id,
177                                          current_section_match_count);
178       }
179     }
180     ++current_section_match_count;
181   }
182 
183   if (doc_hit_info_.document_id() == kInvalidDocumentId) {
184     return absl_ports::ResourceExhaustedError(
185         "No more DocHitInfos in iterator");
186   }
187   return libtextclassifier3::Status::OK;
188 }
189 
Advance()190 libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() {
191   do {
192     ICING_RETURN_IF_ERROR(AdvanceToNextUnfilteredDocument());
193   } while (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone);
194   ++num_advance_calls_;
195   return libtextclassifier3::Status::OK;
196 }
197 
198 }  // namespace lib
199 }  // namespace icing
200