• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
16 #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/absl_ports/canonical_errors.h"
29 #include "icing/index/embed/embedding-query-results.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/iterator/doc-hit-info-iterator.h"
32 #include "icing/join/join-children-fetcher.h"
33 #include "icing/scoring/bm25f-calculator.h"
34 #include "icing/scoring/section-weights.h"
35 #include "icing/store/document-filter-data.h"
36 #include "icing/store/document-id.h"
37 #include "icing/store/document-store.h"
38 
39 namespace icing {
40 namespace lib {
41 
// The static result type of a ScoreExpression, as reported by
// ScoreExpression::type().
enum class ScoreExpressionType {
  // Expressions of this type are expected to override EvaluateDouble().
  kDouble,

  // Expressions of this type are expected to override EvaluateList().
  kDoubleList,

  // Only "this" is considered as document type.
  kDocument,

  // TODO(b/326656531): Instead of creating a vector index type, consider
  // changing it to vector type so that the data is the vector directly.
  kVectorIndex,

  // Expressions of this type are expected to override EvaluateString().
  kString,
};
51 
52 class ScoreExpression {
53  public:
54   virtual ~ScoreExpression() = default;
55 
56   // Evaluate the score expression to double with the current document.
57   //
58   // RETURNS:
59   //   - The evaluated result as a double on success.
60   //   - INVALID_ARGUMENT if a non-finite value is reached while evaluating the
61   //                      expression.
62   //   - INTERNAL if there are inconsistencies.
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)63   virtual libtextclassifier3::StatusOr<double> EvaluateDouble(
64       const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
65     if (type() == ScoreExpressionType::kDouble) {
66       return absl_ports::UnimplementedError(
67           "All ScoreExpressions of type double must provide their own "
68           "implementation of EvaluateDouble!");
69     }
70     return absl_ports::InternalError(
71         "Runtime type error: the expression should never be evaluated to a "
72         "double. There must be inconsistencies in the static type checking.");
73   }
74 
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)75   virtual libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
76       const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
77     if (type() == ScoreExpressionType::kDoubleList) {
78       return absl_ports::UnimplementedError(
79           "All ScoreExpressions of type double List must provide their own "
80           "implementation of EvaluateList!");
81     }
82     return absl_ports::InternalError(
83         "Runtime type error: the expression should never be evaluated to a "
84         "double list. There must be inconsistencies in the static type "
85         "checking.");
86   }
87 
EvaluateString()88   virtual libtextclassifier3::StatusOr<std::string_view> EvaluateString()
89       const {
90     if (type() == ScoreExpressionType::kString) {
91       return absl_ports::UnimplementedError(
92           "All ScoreExpressions of type string must provide their own "
93           "implementation of EvaluateString!");
94     }
95     return absl_ports::InternalError(
96         "Runtime type error: the expression should never be evaluated to a "
97         "string. There must be inconsistencies in the static type checking.");
98   }
99 
100   // Indicate the type to which the current expression will be evaluated.
101   virtual ScoreExpressionType type() const = 0;
102 
103   // Indicate whether the current expression is a constant.
104   // Returns true if and only if the object is of ConstantScoreExpression or
105   // StringExpression type.
is_constant()106   virtual bool is_constant() const { return false; }
107 };
108 
109 class ThisExpression : public ScoreExpression {
110  public:
Create()111   static std::unique_ptr<ThisExpression> Create() {
112     return std::unique_ptr<ThisExpression>(new ThisExpression());
113   }
114 
type()115   ScoreExpressionType type() const override {
116     return ScoreExpressionType::kDocument;
117   }
118 
119  private:
120   ThisExpression() = default;
121 };
122 
123 class ConstantScoreExpression : public ScoreExpression {
124  public:
125   static std::unique_ptr<ConstantScoreExpression> Create(
126       libtextclassifier3::StatusOr<double> c,
127       ScoreExpressionType type = ScoreExpressionType::kDouble) {
128     return std::unique_ptr<ConstantScoreExpression>(
129         new ConstantScoreExpression(c, type));
130   }
131 
EvaluateDouble(const DocHitInfo &,const DocHitInfoIterator *)132   libtextclassifier3::StatusOr<double> EvaluateDouble(
133       const DocHitInfo&, const DocHitInfoIterator*) const override {
134     return c_;
135   }
136 
type()137   ScoreExpressionType type() const override { return type_; }
138 
is_constant()139   bool is_constant() const override { return true; }
140 
141  private:
ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,ScoreExpressionType type)142   explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,
143                                    ScoreExpressionType type)
144       : c_(c), type_(type) {}
145 
146   libtextclassifier3::StatusOr<double> c_;
147   ScoreExpressionType type_;
148 };
149 
150 class StringExpression : public ScoreExpression {
151  public:
Create(std::string str)152   static std::unique_ptr<StringExpression> Create(std::string str) {
153     return std::unique_ptr<StringExpression>(
154         new StringExpression(std::move(str)));
155   }
156 
EvaluateString()157   libtextclassifier3::StatusOr<std::string_view> EvaluateString()
158       const override {
159     return str_;
160   }
161 
type()162   ScoreExpressionType type() const override {
163     return ScoreExpressionType::kString;
164   }
165 
is_constant()166   bool is_constant() const override { return true; }
167 
168  private:
StringExpression(std::string str)169   explicit StringExpression(std::string str) : str_(std::move(str)) {}
170   std::string str_;
171 };
172 
173 class OperatorScoreExpression : public ScoreExpression {
174  public:
175   enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv };
176 
177   // RETURNS:
178   //   - An OperatorScoreExpression instance on success if not simplifiable.
179   //   - A ConstantScoreExpression instance on success if simplifiable.
180   //   - FAILED_PRECONDITION on any null pointer in children.
181   //   - INVALID_ARGUMENT on type errors.
182   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
183       OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children);
184 
185   libtextclassifier3::StatusOr<double> EvaluateDouble(
186       const DocHitInfo& hit_info,
187       const DocHitInfoIterator* query_it) const override;
188 
type()189   ScoreExpressionType type() const override {
190     return ScoreExpressionType::kDouble;
191   }
192 
193  private:
OperatorScoreExpression(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)194   explicit OperatorScoreExpression(
195       OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children)
196       : op_(op), children_(std::move(children)) {}
197 
198   OperatorType op_;
199   std::vector<std::unique_ptr<ScoreExpression>> children_;
200 };
201 
202 class MathFunctionScoreExpression : public ScoreExpression {
203  public:
204   enum class FunctionType {
205     kLog,
206     kPow,
207     kMax,
208     kMin,
209     kLen,
210     kSum,
211     kAvg,
212     kSqrt,
213     kAbs,
214     kSin,
215     kCos,
216     kTan,
217     kMaxOrDefault,
218     kMinOrDefault,
219   };
220 
221   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
222 
223   static const std::unordered_set<FunctionType> kVariableArgumentsFunctions;
224 
225   static const std::unordered_set<FunctionType> kListArgumentFunctions;
226 
227   // RETURNS:
228   //   - A MathFunctionScoreExpression instance on success if not simplifiable.
229   //   - A ConstantScoreExpression instance on success if simplifiable.
230   //   - FAILED_PRECONDITION on any null pointer in args.
231   //   - INVALID_ARGUMENT on type errors.
232   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
233       FunctionType function_type,
234       std::vector<std::unique_ptr<ScoreExpression>> args);
235 
236   libtextclassifier3::StatusOr<double> EvaluateDouble(
237       const DocHitInfo& hit_info,
238       const DocHitInfoIterator* query_it) const override;
239 
type()240   ScoreExpressionType type() const override {
241     return ScoreExpressionType::kDouble;
242   }
243 
244  private:
MathFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)245   explicit MathFunctionScoreExpression(
246       FunctionType function_type,
247       std::vector<std::unique_ptr<ScoreExpression>> args)
248       : function_type_(function_type), args_(std::move(args)) {}
249 
250   FunctionType function_type_;
251   std::vector<std::unique_ptr<ScoreExpression>> args_;
252 };
253 
254 class ListOperationFunctionScoreExpression : public ScoreExpression {
255  public:
256   enum class FunctionType { kFilterByRange };
257 
258   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
259 
260   // RETURNS:
261   //   - A ListOperationFunctionScoreExpression instance on success.
262   //   - FAILED_PRECONDITION on any null pointer in args.
263   //   - INVALID_ARGUMENT on type errors.
264   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
265       FunctionType function_type,
266       std::vector<std::unique_ptr<ScoreExpression>> args);
267 
268   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
269       const DocHitInfo& hit_info,
270       const DocHitInfoIterator* query_it) const override;
271 
type()272   ScoreExpressionType type() const override {
273     return ScoreExpressionType::kDoubleList;
274   }
275 
276  private:
ListOperationFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)277   explicit ListOperationFunctionScoreExpression(
278       FunctionType function_type,
279       std::vector<std::unique_ptr<ScoreExpression>> args)
280       : function_type_(function_type), args_(std::move(args)) {}
281 
282   FunctionType function_type_;
283   std::vector<std::unique_ptr<ScoreExpression>> args_;
284 };
285 
286 class DocumentFunctionScoreExpression : public ScoreExpression {
287  public:
288   enum class FunctionType {
289     kDocumentScore,
290     kCreationTimestamp,
291     kUsageCount,
292     kUsageLastUsedTimestamp,
293   };
294 
295   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
296 
297   // RETURNS:
298   //   - A DocumentFunctionScoreExpression instance on success.
299   //   - FAILED_PRECONDITION on any null pointer in args.
300   //   - INVALID_ARGUMENT on type errors.
301   static libtextclassifier3::StatusOr<
302       std::unique_ptr<DocumentFunctionScoreExpression>>
303   Create(FunctionType function_type,
304          std::vector<std::unique_ptr<ScoreExpression>> args,
305          const DocumentStore* document_store, double default_score,
306          int64_t current_time_ms);
307 
308   libtextclassifier3::StatusOr<double> EvaluateDouble(
309       const DocHitInfo& hit_info,
310       const DocHitInfoIterator* query_it) const override;
311 
type()312   ScoreExpressionType type() const override {
313     return ScoreExpressionType::kDouble;
314   }
315 
316  private:
DocumentFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)317   explicit DocumentFunctionScoreExpression(
318       FunctionType function_type,
319       std::vector<std::unique_ptr<ScoreExpression>> args,
320       const DocumentStore* document_store, double default_score,
321       int64_t current_time_ms)
322       : args_(std::move(args)),
323         document_store_(*document_store),
324         default_score_(default_score),
325         function_type_(function_type),
326         current_time_ms_(current_time_ms) {}
327 
328   std::vector<std::unique_ptr<ScoreExpression>> args_;
329   const DocumentStore& document_store_;
330   double default_score_;
331   FunctionType function_type_;
332   int64_t current_time_ms_;
333 };
334 
335 class RelevanceScoreFunctionScoreExpression : public ScoreExpression {
336  public:
337   static constexpr std::string_view kFunctionName = "relevanceScore";
338 
339   // RETURNS:
340   //   - A RelevanceScoreFunctionScoreExpression instance on success.
341   //   - FAILED_PRECONDITION on any null pointer in args.
342   //   - INVALID_ARGUMENT on type errors.
343   static libtextclassifier3::StatusOr<
344       std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
345   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
346          Bm25fCalculator* bm25f_calculator, double default_score);
347 
348   libtextclassifier3::StatusOr<double> EvaluateDouble(
349       const DocHitInfo& hit_info,
350       const DocHitInfoIterator* query_it) const override;
351 
type()352   ScoreExpressionType type() const override {
353     return ScoreExpressionType::kDouble;
354   }
355 
356  private:
RelevanceScoreFunctionScoreExpression(Bm25fCalculator * bm25f_calculator,double default_score)357   explicit RelevanceScoreFunctionScoreExpression(
358       Bm25fCalculator* bm25f_calculator, double default_score)
359       : bm25f_calculator_(*bm25f_calculator), default_score_(default_score) {}
360 
361   Bm25fCalculator& bm25f_calculator_;
362   double default_score_;
363 };
364 
365 class ChildrenRankingSignalsFunctionScoreExpression : public ScoreExpression {
366  public:
367   static constexpr std::string_view kFunctionName = "childrenRankingSignals";
368 
369   // RETURNS:
370   //   - A ChildrenRankingSignalsFunctionScoreExpression instance on success.
371   //   - FAILED_PRECONDITION on any null pointer in children.
372   //   - INVALID_ARGUMENT on type errors.
373   static libtextclassifier3::StatusOr<
374       std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>>
375   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
376          const JoinChildrenFetcher* join_children_fetcher);
377 
378   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
379       const DocHitInfo& hit_info,
380       const DocHitInfoIterator* query_it) const override;
381 
type()382   ScoreExpressionType type() const override {
383     return ScoreExpressionType::kDoubleList;
384   }
385 
386  private:
ChildrenRankingSignalsFunctionScoreExpression(const JoinChildrenFetcher & join_children_fetcher)387   explicit ChildrenRankingSignalsFunctionScoreExpression(
388       const JoinChildrenFetcher& join_children_fetcher)
389       : join_children_fetcher_(join_children_fetcher) {}
390   const JoinChildrenFetcher& join_children_fetcher_;
391 };
392 
393 class PropertyWeightsFunctionScoreExpression : public ScoreExpression {
394  public:
395   static constexpr std::string_view kFunctionName = "propertyWeights";
396 
397   // RETURNS:
398   //   - A PropertyWeightsFunctionScoreExpression instance on success.
399   //   - FAILED_PRECONDITION on any null pointer in children.
400   //   - INVALID_ARGUMENT on type errors.
401   static libtextclassifier3::StatusOr<
402       std::unique_ptr<PropertyWeightsFunctionScoreExpression>>
403   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
404          const DocumentStore* document_store,
405          const SectionWeights* section_weights, int64_t current_time_ms);
406 
407   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
408       const DocHitInfo& hit_info, const DocHitInfoIterator*) const override;
409 
type()410   ScoreExpressionType type() const override {
411     return ScoreExpressionType::kDoubleList;
412   }
413 
414   SchemaTypeId GetSchemaTypeId(DocumentId document_id) const;
415 
416  private:
PropertyWeightsFunctionScoreExpression(const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)417   explicit PropertyWeightsFunctionScoreExpression(
418       const DocumentStore* document_store,
419       const SectionWeights* section_weights, int64_t current_time_ms)
420       : document_store_(*document_store),
421         section_weights_(*section_weights),
422         current_time_ms_(current_time_ms) {}
423   const DocumentStore& document_store_;
424   const SectionWeights& section_weights_;
425   int64_t current_time_ms_;
426 };
427 
428 class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression {
429  public:
430   static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding";
431 
432   // RETURNS:
433   //   - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if
434   //     not simplifiable.
435   //   - A ConstantScoreExpression instance on success if simplifiable.
436   //   - FAILED_PRECONDITION on any null pointer in children.
437   //   - INVALID_ARGUMENT on type errors.
438   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
439       std::vector<std::unique_ptr<ScoreExpression>> args);
440 
441   libtextclassifier3::StatusOr<double> EvaluateDouble(
442       const DocHitInfo& hit_info,
443       const DocHitInfoIterator* query_it) const override;
444 
type()445   ScoreExpressionType type() const override {
446     return ScoreExpressionType::kVectorIndex;
447   }
448 
449  private:
GetSearchSpecEmbeddingFunctionScoreExpression(std::unique_ptr<ScoreExpression> arg)450   explicit GetSearchSpecEmbeddingFunctionScoreExpression(
451       std::unique_ptr<ScoreExpression> arg)
452       : arg_(std::move(arg)) {}
453   std::unique_ptr<ScoreExpression> arg_;
454 };
455 
456 class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression {
457  public:
458   static constexpr std::string_view kFunctionName = "matchedSemanticScores";
459 
460   // RETURNS:
461   //   - A MatchedSemanticScoresFunctionScoreExpression instance on success.
462   //   - FAILED_PRECONDITION on any null pointer in children.
463   //   - INVALID_ARGUMENT on type errors.
464   static libtextclassifier3::StatusOr<
465       std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
466   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
467          SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
468          const EmbeddingQueryResults* embedding_query_results);
469 
470   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
471       const DocHitInfo& hit_info,
472       const DocHitInfoIterator* query_it) const override;
473 
type()474   ScoreExpressionType type() const override {
475     return ScoreExpressionType::kDoubleList;
476   }
477 
478  private:
MatchedSemanticScoresFunctionScoreExpression(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,const EmbeddingQueryResults & embedding_query_results)479   explicit MatchedSemanticScoresFunctionScoreExpression(
480       std::vector<std::unique_ptr<ScoreExpression>> args,
481       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
482       const EmbeddingQueryResults& embedding_query_results)
483       : args_(std::move(args)),
484         metric_type_(metric_type),
485         embedding_query_results_(embedding_query_results) {}
486 
487   std::vector<std::unique_ptr<ScoreExpression>> args_;
488   const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
489   const EmbeddingQueryResults& embedding_query_results_;
490 };
491 
492 }  // namespace lib
493 }  // namespace icing
494 
495 #endif  // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
496