• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/scoring/scorer-factory.h"
16 
17 #include <memory>
18 #include <unordered_map>
19 
20 #include "icing/text_classifier/lib3/utils/base/statusor.h"
21 #include "icing/absl_ports/canonical_errors.h"
22 #include "icing/index/hit/doc-hit-info.h"
23 #include "icing/index/iterator/doc-hit-info-iterator.h"
24 #include "icing/proto/scoring.pb.h"
25 #include "icing/scoring/advanced_scoring/advanced-scorer.h"
26 #include "icing/scoring/bm25f-calculator.h"
27 #include "icing/scoring/scorer.h"
28 #include "icing/scoring/section-weights.h"
29 #include "icing/store/document-id.h"
30 #include "icing/store/document-store.h"
31 #include "icing/util/status-macros.h"
32 
33 namespace icing {
34 namespace lib {
35 
36 class DocumentScoreScorer : public Scorer {
37  public:
DocumentScoreScorer(const DocumentStore * document_store,double default_score)38   explicit DocumentScoreScorer(const DocumentStore* document_store,
39                                double default_score)
40       : document_store_(*document_store), default_score_(default_score) {}
41 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)42   double GetScore(const DocHitInfo& hit_info,
43                   const DocHitInfoIterator*) override {
44     ICING_ASSIGN_OR_RETURN(
45         DocumentAssociatedScoreData score_data,
46         document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()),
47         default_score_);
48 
49     return static_cast<double>(score_data.document_score());
50   }
51 
52  private:
53   const DocumentStore& document_store_;
54   double default_score_;
55 };
56 
57 class DocumentCreationTimestampScorer : public Scorer {
58  public:
DocumentCreationTimestampScorer(const DocumentStore * document_store,double default_score)59   explicit DocumentCreationTimestampScorer(const DocumentStore* document_store,
60                                            double default_score)
61       : document_store_(*document_store), default_score_(default_score) {}
62 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)63   double GetScore(const DocHitInfo& hit_info,
64                   const DocHitInfoIterator*) override {
65     ICING_ASSIGN_OR_RETURN(
66         DocumentAssociatedScoreData score_data,
67         document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()),
68         default_score_);
69 
70     return static_cast<double>(score_data.creation_timestamp_ms());
71   }
72 
73  private:
74   const DocumentStore& document_store_;
75   double default_score_;
76 };
77 
78 class RelevanceScoreScorer : public Scorer {
79  public:
RelevanceScoreScorer(std::unique_ptr<SectionWeights> section_weights,std::unique_ptr<Bm25fCalculator> bm25f_calculator,double default_score)80   explicit RelevanceScoreScorer(
81       std::unique_ptr<SectionWeights> section_weights,
82       std::unique_ptr<Bm25fCalculator> bm25f_calculator, double default_score)
83       : section_weights_(std::move(section_weights)),
84         bm25f_calculator_(std::move(bm25f_calculator)),
85         default_score_(default_score) {}
86 
PrepareToScore(std::unordered_map<std::string,std::unique_ptr<DocHitInfoIterator>> * query_term_iterators)87   void PrepareToScore(
88       std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>*
89           query_term_iterators) override {
90     bm25f_calculator_->PrepareToScore(query_term_iterators);
91   }
92 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)93   double GetScore(const DocHitInfo& hit_info,
94                   const DocHitInfoIterator* query_it) override {
95     if (!query_it) {
96       return default_score_;
97     }
98 
99     return static_cast<double>(
100         bm25f_calculator_->ComputeScore(query_it, hit_info, default_score_));
101   }
102 
103  private:
104   std::unique_ptr<SectionWeights> section_weights_;
105   std::unique_ptr<Bm25fCalculator> bm25f_calculator_;
106   double default_score_;
107 };
108 
109 // A scorer which assigns scores to documents based on usage reports.
110 class UsageScorer : public Scorer {
111  public:
UsageScorer(const DocumentStore * document_store,ScoringSpecProto::RankingStrategy::Code ranking_strategy,double default_score,int64_t current_time_ms)112   UsageScorer(const DocumentStore* document_store,
113               ScoringSpecProto::RankingStrategy::Code ranking_strategy,
114               double default_score, int64_t current_time_ms)
115       : document_store_(*document_store),
116         ranking_strategy_(ranking_strategy),
117         default_score_(default_score),
118         current_time_ms_(current_time_ms) {}
119 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)120   double GetScore(const DocHitInfo& hit_info,
121                   const DocHitInfoIterator*) override {
122     std::optional<UsageStore::UsageScores> usage_scores =
123         document_store_.GetUsageScores(hit_info.document_id(),
124                                        current_time_ms_);
125     if (!usage_scores) {
126       // If there's no UsageScores entry present for this doc, then just
127       // treat it as a default instance.
128       usage_scores = UsageStore::UsageScores();
129     }
130 
131     switch (ranking_strategy_) {
132       case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT:
133         return usage_scores->usage_type1_count;
134       case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT:
135         return usage_scores->usage_type2_count;
136       case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT:
137         return usage_scores->usage_type3_count;
138       case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP:
139         return usage_scores->usage_type1_last_used_timestamp_s * 1000.0;
140       case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP:
141         return usage_scores->usage_type2_last_used_timestamp_s * 1000.0;
142       case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
143         return usage_scores->usage_type3_last_used_timestamp_s * 1000.0;
144       default:
145         // This shouldn't happen if this scorer is used correctly.
146         return default_score_;
147     }
148   }
149 
150  private:
151   const DocumentStore& document_store_;
152   ScoringSpecProto::RankingStrategy::Code ranking_strategy_;
153   double default_score_;
154   int64_t current_time_ms_;
155 };
156 
157 // A special scorer which does nothing but assigns the default score to each
158 // document. This is used especially when no scoring is required in a query.
159 class NoScorer : public Scorer {
160  public:
NoScorer(double default_score)161   explicit NoScorer(double default_score) : default_score_(default_score) {}
162 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)163   double GetScore(const DocHitInfo& hit_info,
164                   const DocHitInfoIterator*) override {
165     return default_score_;
166   }
167 
168  private:
169   double default_score_;
170 };
171 
172 namespace scorer_factory {
173 
Create(const ScoringSpecProto & scoring_spec,double default_score,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms,const JoinChildrenFetcher * join_children_fetcher)174 libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
175     const ScoringSpecProto& scoring_spec, double default_score,
176     const DocumentStore* document_store, const SchemaStore* schema_store,
177     int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) {
178   ICING_RETURN_ERROR_IF_NULL(document_store);
179   ICING_RETURN_ERROR_IF_NULL(schema_store);
180 
181   if (!scoring_spec.advanced_scoring_expression().empty() &&
182       scoring_spec.rank_by() !=
183           ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION) {
184     return absl_ports::InvalidArgumentError(
185         "Advanced scoring is not enabled, but the advanced scoring expression "
186         "is not empty!");
187   }
188 
189   switch (scoring_spec.rank_by()) {
190     case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE:
191       return std::make_unique<DocumentScoreScorer>(document_store,
192                                                    default_score);
193     case ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP:
194       return std::make_unique<DocumentCreationTimestampScorer>(document_store,
195                                                                default_score);
196     case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: {
197       ICING_ASSIGN_OR_RETURN(
198           std::unique_ptr<SectionWeights> section_weights,
199           SectionWeights::Create(schema_store, scoring_spec));
200 
201       auto bm25f_calculator = std::make_unique<Bm25fCalculator>(
202           document_store, section_weights.get(), current_time_ms);
203       return std::make_unique<RelevanceScoreScorer>(std::move(section_weights),
204                                                     std::move(bm25f_calculator),
205                                                     default_score);
206     }
207     case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT:
208       [[fallthrough]];
209     case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT:
210       [[fallthrough]];
211     case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT:
212       [[fallthrough]];
213     case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP:
214       [[fallthrough]];
215     case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP:
216       [[fallthrough]];
217     case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
218       return std::make_unique<UsageScorer>(document_store,
219                                            scoring_spec.rank_by(),
220                                            default_score, current_time_ms);
221     case ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION:
222       if (scoring_spec.advanced_scoring_expression().empty()) {
223         return absl_ports::InvalidArgumentError(
224             "Advanced scoring is enabled, but the expression is empty!");
225       }
226       return AdvancedScorer::Create(scoring_spec, default_score, document_store,
227                                     schema_store, current_time_ms,
228                                     join_children_fetcher);
229     case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
230       // Use join aggregate score to rank. Since the aggregation score is
231       // calculated by child documents after joining (in JoinProcessor), we can
232       // simply use NoScorer for parent documents.
233       [[fallthrough]];
234     case ScoringSpecProto::RankingStrategy::NONE:
235       return std::make_unique<NoScorer>(default_score);
236   }
237 }
238 
239 }  // namespace scorer_factory
240 
241 }  // namespace lib
242 }  // namespace icing
243