1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/scoring/scorer-factory.h"
16
17 #include <memory>
18 #include <unordered_map>
19
20 #include "icing/text_classifier/lib3/utils/base/statusor.h"
21 #include "icing/absl_ports/canonical_errors.h"
22 #include "icing/index/hit/doc-hit-info.h"
23 #include "icing/index/iterator/doc-hit-info-iterator.h"
24 #include "icing/proto/scoring.pb.h"
25 #include "icing/scoring/advanced_scoring/advanced-scorer.h"
26 #include "icing/scoring/bm25f-calculator.h"
27 #include "icing/scoring/scorer.h"
28 #include "icing/scoring/section-weights.h"
29 #include "icing/store/document-id.h"
30 #include "icing/store/document-store.h"
31 #include "icing/util/status-macros.h"
32
33 namespace icing {
34 namespace lib {
35
36 class DocumentScoreScorer : public Scorer {
37 public:
DocumentScoreScorer(const DocumentStore * document_store,double default_score)38 explicit DocumentScoreScorer(const DocumentStore* document_store,
39 double default_score)
40 : document_store_(*document_store), default_score_(default_score) {}
41
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)42 double GetScore(const DocHitInfo& hit_info,
43 const DocHitInfoIterator*) override {
44 ICING_ASSIGN_OR_RETURN(
45 DocumentAssociatedScoreData score_data,
46 document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()),
47 default_score_);
48
49 return static_cast<double>(score_data.document_score());
50 }
51
52 private:
53 const DocumentStore& document_store_;
54 double default_score_;
55 };
56
57 class DocumentCreationTimestampScorer : public Scorer {
58 public:
DocumentCreationTimestampScorer(const DocumentStore * document_store,double default_score)59 explicit DocumentCreationTimestampScorer(const DocumentStore* document_store,
60 double default_score)
61 : document_store_(*document_store), default_score_(default_score) {}
62
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)63 double GetScore(const DocHitInfo& hit_info,
64 const DocHitInfoIterator*) override {
65 ICING_ASSIGN_OR_RETURN(
66 DocumentAssociatedScoreData score_data,
67 document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()),
68 default_score_);
69
70 return static_cast<double>(score_data.creation_timestamp_ms());
71 }
72
73 private:
74 const DocumentStore& document_store_;
75 double default_score_;
76 };
77
78 class RelevanceScoreScorer : public Scorer {
79 public:
RelevanceScoreScorer(std::unique_ptr<SectionWeights> section_weights,std::unique_ptr<Bm25fCalculator> bm25f_calculator,double default_score)80 explicit RelevanceScoreScorer(
81 std::unique_ptr<SectionWeights> section_weights,
82 std::unique_ptr<Bm25fCalculator> bm25f_calculator, double default_score)
83 : section_weights_(std::move(section_weights)),
84 bm25f_calculator_(std::move(bm25f_calculator)),
85 default_score_(default_score) {}
86
PrepareToScore(std::unordered_map<std::string,std::unique_ptr<DocHitInfoIterator>> * query_term_iterators)87 void PrepareToScore(
88 std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>*
89 query_term_iterators) override {
90 bm25f_calculator_->PrepareToScore(query_term_iterators);
91 }
92
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)93 double GetScore(const DocHitInfo& hit_info,
94 const DocHitInfoIterator* query_it) override {
95 if (!query_it) {
96 return default_score_;
97 }
98
99 return static_cast<double>(
100 bm25f_calculator_->ComputeScore(query_it, hit_info, default_score_));
101 }
102
103 private:
104 std::unique_ptr<SectionWeights> section_weights_;
105 std::unique_ptr<Bm25fCalculator> bm25f_calculator_;
106 double default_score_;
107 };
108
109 // A scorer which assigns scores to documents based on usage reports.
110 class UsageScorer : public Scorer {
111 public:
UsageScorer(const DocumentStore * document_store,ScoringSpecProto::RankingStrategy::Code ranking_strategy,double default_score,int64_t current_time_ms)112 UsageScorer(const DocumentStore* document_store,
113 ScoringSpecProto::RankingStrategy::Code ranking_strategy,
114 double default_score, int64_t current_time_ms)
115 : document_store_(*document_store),
116 ranking_strategy_(ranking_strategy),
117 default_score_(default_score),
118 current_time_ms_(current_time_ms) {}
119
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)120 double GetScore(const DocHitInfo& hit_info,
121 const DocHitInfoIterator*) override {
122 std::optional<UsageStore::UsageScores> usage_scores =
123 document_store_.GetUsageScores(hit_info.document_id(),
124 current_time_ms_);
125 if (!usage_scores) {
126 // If there's no UsageScores entry present for this doc, then just
127 // treat it as a default instance.
128 usage_scores = UsageStore::UsageScores();
129 }
130
131 switch (ranking_strategy_) {
132 case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT:
133 return usage_scores->usage_type1_count;
134 case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT:
135 return usage_scores->usage_type2_count;
136 case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT:
137 return usage_scores->usage_type3_count;
138 case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP:
139 return usage_scores->usage_type1_last_used_timestamp_s * 1000.0;
140 case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP:
141 return usage_scores->usage_type2_last_used_timestamp_s * 1000.0;
142 case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
143 return usage_scores->usage_type3_last_used_timestamp_s * 1000.0;
144 default:
145 // This shouldn't happen if this scorer is used correctly.
146 return default_score_;
147 }
148 }
149
150 private:
151 const DocumentStore& document_store_;
152 ScoringSpecProto::RankingStrategy::Code ranking_strategy_;
153 double default_score_;
154 int64_t current_time_ms_;
155 };
156
157 // A special scorer which does nothing but assigns the default score to each
158 // document. This is used especially when no scoring is required in a query.
159 class NoScorer : public Scorer {
160 public:
NoScorer(double default_score)161 explicit NoScorer(double default_score) : default_score_(default_score) {}
162
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator *)163 double GetScore(const DocHitInfo& hit_info,
164 const DocHitInfoIterator*) override {
165 return default_score_;
166 }
167
168 private:
169 double default_score_;
170 };
171
172 namespace scorer_factory {
173
Create(const ScoringSpecProto & scoring_spec,double default_score,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms,const JoinChildrenFetcher * join_children_fetcher)174 libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
175 const ScoringSpecProto& scoring_spec, double default_score,
176 const DocumentStore* document_store, const SchemaStore* schema_store,
177 int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) {
178 ICING_RETURN_ERROR_IF_NULL(document_store);
179 ICING_RETURN_ERROR_IF_NULL(schema_store);
180
181 if (!scoring_spec.advanced_scoring_expression().empty() &&
182 scoring_spec.rank_by() !=
183 ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION) {
184 return absl_ports::InvalidArgumentError(
185 "Advanced scoring is not enabled, but the advanced scoring expression "
186 "is not empty!");
187 }
188
189 switch (scoring_spec.rank_by()) {
190 case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE:
191 return std::make_unique<DocumentScoreScorer>(document_store,
192 default_score);
193 case ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP:
194 return std::make_unique<DocumentCreationTimestampScorer>(document_store,
195 default_score);
196 case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: {
197 ICING_ASSIGN_OR_RETURN(
198 std::unique_ptr<SectionWeights> section_weights,
199 SectionWeights::Create(schema_store, scoring_spec));
200
201 auto bm25f_calculator = std::make_unique<Bm25fCalculator>(
202 document_store, section_weights.get(), current_time_ms);
203 return std::make_unique<RelevanceScoreScorer>(std::move(section_weights),
204 std::move(bm25f_calculator),
205 default_score);
206 }
207 case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT:
208 [[fallthrough]];
209 case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT:
210 [[fallthrough]];
211 case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT:
212 [[fallthrough]];
213 case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP:
214 [[fallthrough]];
215 case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP:
216 [[fallthrough]];
217 case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
218 return std::make_unique<UsageScorer>(document_store,
219 scoring_spec.rank_by(),
220 default_score, current_time_ms);
221 case ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION:
222 if (scoring_spec.advanced_scoring_expression().empty()) {
223 return absl_ports::InvalidArgumentError(
224 "Advanced scoring is enabled, but the expression is empty!");
225 }
226 return AdvancedScorer::Create(scoring_spec, default_score, document_store,
227 schema_store, current_time_ms,
228 join_children_fetcher);
229 case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
230 // Use join aggregate score to rank. Since the aggregation score is
231 // calculated by child documents after joining (in JoinProcessor), we can
232 // simply use NoScorer for parent documents.
233 [[fallthrough]];
234 case ScoringSpecProto::RankingStrategy::NONE:
235 return std::make_unique<NoScorer>(default_score);
236 }
237 }
238
239 } // namespace scorer_factory
240
241 } // namespace lib
242 } // namespace icing
243