1 // Copyright (C) 2024 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 16 #define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 17 18 #include <memory> 19 #include <string> 20 #include <string_view> 21 #include <utility> 22 #include <vector> 23 24 #include "icing/text_classifier/lib3/utils/base/status.h" 25 #include "icing/text_classifier/lib3/utils/base/statusor.h" 26 #include "icing/absl_ports/canonical_errors.h" 27 #include "icing/index/embed/embedding-hit.h" 28 #include "icing/index/embed/embedding-index.h" 29 #include "icing/index/embed/embedding-query-results.h" 30 #include "icing/index/embed/embedding-scorer.h" 31 #include "icing/index/embed/posting-list-embedding-hit-accessor.h" 32 #include "icing/index/iterator/doc-hit-info-iterator.h" 33 #include "icing/index/iterator/section-restrict-data.h" 34 #include "icing/proto/search.pb.h" 35 #include "icing/schema/section.h" 36 37 namespace icing { 38 namespace lib { 39 40 class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator { 41 public: 42 // Create a DocHitInfoIterator for iterating through all docs which have an 43 // embedding matched with the provided query with a score in the range of 44 // [score_low, score_high], using the provided metric_type. 45 // 46 // The iterator will store the matched embedding scores in score_map to 47 // prepare for scoring. 48 // 49 // The iterator will handle the section restriction logic internally by the 50 // provided section_restrict_data. 51 // 52 // Returns: 53 // - a DocHitInfoIteratorEmbedding instance on success. 54 // - Any error from posting lists. 55 static libtextclassifier3::StatusOr< 56 std::unique_ptr<DocHitInfoIteratorEmbedding>> 57 Create(const PropertyProto::VectorProto* query, 58 std::unique_ptr<SectionRestrictData> section_restrict_data, 59 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 60 double score_low, double score_high, 61 EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, 62 const EmbeddingIndex* embedding_index); 63 64 libtextclassifier3::Status Advance() override; 65 66 // The iterator will internally handle the section restriction logic by itself 67 // to have better control, so that it is able to filter out embedding hits 68 // from unwanted sections to avoid retrieving unnecessary vectors and 69 // calculate scores for them. full_section_restriction_applied()70 bool full_section_restriction_applied() const override { return true; } 71 TrimRightMostNode()72 libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override { 73 return absl_ports::InvalidArgumentError( 74 "Query suggestions for the semanticSearch function are not supported"); 75 } 76 GetCallStats()77 CallStats GetCallStats() const override { 78 return CallStats( 79 /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_, 80 /*num_leaf_advance_calls_main_index_in=*/0, 81 /*num_leaf_advance_calls_integer_index_in=*/0, 82 /*num_leaf_advance_calls_no_index_in=*/0, 83 /*num_blocks_inspected_in=*/0); 84 } 85 ToString()86 std::string ToString() const override { return "embedding_iterator"; } 87 88 // PopulateMatchedTermsStats is not applicable to embedding search. PopulateMatchedTermsStats(std::vector<TermMatchInfo> * matched_terms_stats,SectionIdMask filtering_section_mask)89 void PopulateMatchedTermsStats( 90 std::vector<TermMatchInfo>* matched_terms_stats, 91 SectionIdMask filtering_section_mask) const override {} 92 93 private: DocHitInfoIteratorEmbedding(const PropertyProto::VectorProto * query,std::unique_ptr<SectionRestrictData> section_restrict_data,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,std::unique_ptr<EmbeddingScorer> embedding_scorer,double score_low,double score_high,EmbeddingQueryResults::EmbeddingQueryScoreMap * score_map,const EmbeddingIndex * embedding_index,std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)94 explicit DocHitInfoIteratorEmbedding( 95 const PropertyProto::VectorProto* query, 96 std::unique_ptr<SectionRestrictData> section_restrict_data, 97 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 98 std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low, 99 double score_high, 100 EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, 101 const EmbeddingIndex* embedding_index, 102 std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor) 103 : query_(*query), 104 section_restrict_data_(std::move(section_restrict_data)), 105 metric_type_(metric_type), 106 embedding_scorer_(std::move(embedding_scorer)), 107 score_low_(score_low), 108 score_high_(score_high), 109 score_map_(*score_map), 110 embedding_index_(*embedding_index), 111 posting_list_accessor_(std::move(posting_list_accessor)), 112 cached_embedding_hits_idx_(0), 113 current_allowed_sections_mask_(kSectionIdMaskAll), 114 no_more_hit_(false), 115 num_advance_calls_(0) {} 116 117 // Advance to the next embedding hit of the current document. If the current 118 // document id is kInvalidDocumentId, the method will advance to the first 119 // embedding hit of the next document and update doc_hit_info_. 120 // 121 // This method also properly updates cached_embedding_hits_, 122 // cached_embedding_hits_idx_, current_allowed_sections_mask_, and 123 // no_more_hit_ to reflect the current state. 124 // 125 // Returns: 126 // - a const pointer to the next embedding hit on success. 127 // - nullptr, if there is no more hit for the current document, or no more 128 // hit in general if the current document id is kInvalidDocumentId. 129 // - Any error from posting lists. 130 libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit(); 131 132 // Similar to Advance(), this method advances the iterator to the next 133 // document, but it does not guarantee that the next document will have 134 // a matched embedding hit within the score range. 135 // 136 // Returns: 137 // - OK, if it is able to advance to a new document_id. 138 // - RESOUCE_EXHAUSTED, if we have run out of document_ids to iterate over. 139 // - Any error from posting lists. 140 libtextclassifier3::Status AdvanceToNextUnfilteredDocument(); 141 142 // Query information 143 const PropertyProto::VectorProto& query_; // Does not own 144 std::unique_ptr<SectionRestrictData> section_restrict_data_; // Nullable. 145 146 // Scoring arguments 147 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; 148 std::unique_ptr<EmbeddingScorer> embedding_scorer_; 149 double score_low_; 150 double score_high_; 151 152 // Score map 153 EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own 154 155 // Access to embeddings index data 156 const EmbeddingIndex& embedding_index_; 157 std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_; 158 159 // Cached data from the embeddings index 160 std::vector<EmbeddingHit> cached_embedding_hits_; 161 int cached_embedding_hits_idx_; 162 SectionIdMask current_allowed_sections_mask_; 163 bool no_more_hit_; 164 165 int num_advance_calls_; 166 }; 167 168 } // namespace lib 169 } // namespace icing 170 171 #endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 172