1 // Copyright (C) 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 16 #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <string> 21 #include <string_view> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 27 #include "icing/text_classifier/lib3/utils/base/statusor.h" 28 #include "icing/absl_ports/canonical_errors.h" 29 #include "icing/index/embed/embedding-query-results.h" 30 #include "icing/index/hit/doc-hit-info.h" 31 #include "icing/index/iterator/doc-hit-info-iterator.h" 32 #include "icing/join/join-children-fetcher.h" 33 #include "icing/scoring/bm25f-calculator.h" 34 #include "icing/scoring/section-weights.h" 35 #include "icing/store/document-filter-data.h" 36 #include "icing/store/document-id.h" 37 #include "icing/store/document-store.h" 38 39 namespace icing { 40 namespace lib { 41 42 enum class ScoreExpressionType { 43 kDouble, 44 kDoubleList, 45 kDocument, // Only "this" is considered as document type. 46 // TODO(b/326656531): Instead of creating a vector index type, consider 47 // changing it to vector type so that the data is the vector directly. 48 kVectorIndex, 49 kString, 50 }; 51 52 class ScoreExpression { 53 public: 54 virtual ~ScoreExpression() = default; 55 56 // Evaluate the score expression to double with the current document. 57 // 58 // RETURNS: 59 // - The evaluated result as a double on success. 60 // - INVALID_ARGUMENT if a non-finite value is reached while evaluating the 61 // expression. 62 // - INTERNAL if there are inconsistencies. EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)63 virtual libtextclassifier3::StatusOr<double> EvaluateDouble( 64 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { 65 if (type() == ScoreExpressionType::kDouble) { 66 return absl_ports::UnimplementedError( 67 "All ScoreExpressions of type double must provide their own " 68 "implementation of EvaluateDouble!"); 69 } 70 return absl_ports::InternalError( 71 "Runtime type error: the expression should never be evaluated to a " 72 "double. There must be inconsistencies in the static type checking."); 73 } 74 EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)75 virtual libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 76 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { 77 if (type() == ScoreExpressionType::kDoubleList) { 78 return absl_ports::UnimplementedError( 79 "All ScoreExpressions of type double List must provide their own " 80 "implementation of EvaluateList!"); 81 } 82 return absl_ports::InternalError( 83 "Runtime type error: the expression should never be evaluated to a " 84 "double list. There must be inconsistencies in the static type " 85 "checking."); 86 } 87 EvaluateString()88 virtual libtextclassifier3::StatusOr<std::string_view> EvaluateString() 89 const { 90 if (type() == ScoreExpressionType::kString) { 91 return absl_ports::UnimplementedError( 92 "All ScoreExpressions of type string must provide their own " 93 "implementation of EvaluateString!"); 94 } 95 return absl_ports::InternalError( 96 "Runtime type error: the expression should never be evaluated to a " 97 "string. There must be inconsistencies in the static type checking."); 98 } 99 100 // Indicate the type to which the current expression will be evaluated. 101 virtual ScoreExpressionType type() const = 0; 102 103 // Indicate whether the current expression is a constant. 104 // Returns true if and only if the object is of ConstantScoreExpression or 105 // StringExpression type. is_constant()106 virtual bool is_constant() const { return false; } 107 }; 108 109 class ThisExpression : public ScoreExpression { 110 public: Create()111 static std::unique_ptr<ThisExpression> Create() { 112 return std::unique_ptr<ThisExpression>(new ThisExpression()); 113 } 114 type()115 ScoreExpressionType type() const override { 116 return ScoreExpressionType::kDocument; 117 } 118 119 private: 120 ThisExpression() = default; 121 }; 122 123 class ConstantScoreExpression : public ScoreExpression { 124 public: 125 static std::unique_ptr<ConstantScoreExpression> Create( 126 libtextclassifier3::StatusOr<double> c, 127 ScoreExpressionType type = ScoreExpressionType::kDouble) { 128 return std::unique_ptr<ConstantScoreExpression>( 129 new ConstantScoreExpression(c, type)); 130 } 131 EvaluateDouble(const DocHitInfo &,const DocHitInfoIterator *)132 libtextclassifier3::StatusOr<double> EvaluateDouble( 133 const DocHitInfo&, const DocHitInfoIterator*) const override { 134 return c_; 135 } 136 type()137 ScoreExpressionType type() const override { return type_; } 138 is_constant()139 bool is_constant() const override { return true; } 140 141 private: ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,ScoreExpressionType type)142 explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c, 143 ScoreExpressionType type) 144 : c_(c), type_(type) {} 145 146 libtextclassifier3::StatusOr<double> c_; 147 ScoreExpressionType type_; 148 }; 149 150 class StringExpression : public ScoreExpression { 151 public: Create(std::string str)152 static std::unique_ptr<StringExpression> Create(std::string str) { 153 return std::unique_ptr<StringExpression>( 154 new StringExpression(std::move(str))); 155 } 156 EvaluateString()157 libtextclassifier3::StatusOr<std::string_view> EvaluateString() 158 const override { 159 return str_; 160 } 161 type()162 ScoreExpressionType type() const override { 163 return ScoreExpressionType::kString; 164 } 165 is_constant()166 bool is_constant() const override { return true; } 167 168 private: StringExpression(std::string str)169 explicit StringExpression(std::string str) : str_(std::move(str)) {} 170 std::string str_; 171 }; 172 173 class OperatorScoreExpression : public ScoreExpression { 174 public: 175 enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv }; 176 177 // RETURNS: 178 // - An OperatorScoreExpression instance on success if not simplifiable. 179 // - A ConstantScoreExpression instance on success if simplifiable. 180 // - FAILED_PRECONDITION on any null pointer in children. 181 // - INVALID_ARGUMENT on type errors. 182 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 183 OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children); 184 185 libtextclassifier3::StatusOr<double> EvaluateDouble( 186 const DocHitInfo& hit_info, 187 const DocHitInfoIterator* query_it) const override; 188 type()189 ScoreExpressionType type() const override { 190 return ScoreExpressionType::kDouble; 191 } 192 193 private: OperatorScoreExpression(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)194 explicit OperatorScoreExpression( 195 OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) 196 : op_(op), children_(std::move(children)) {} 197 198 OperatorType op_; 199 std::vector<std::unique_ptr<ScoreExpression>> children_; 200 }; 201 202 class MathFunctionScoreExpression : public ScoreExpression { 203 public: 204 enum class FunctionType { 205 kLog, 206 kPow, 207 kMax, 208 kMin, 209 kLen, 210 kSum, 211 kAvg, 212 kSqrt, 213 kAbs, 214 kSin, 215 kCos, 216 kTan, 217 kMaxOrDefault, 218 kMinOrDefault, 219 }; 220 221 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 222 223 static const std::unordered_set<FunctionType> kVariableArgumentsFunctions; 224 225 static const std::unordered_set<FunctionType> kListArgumentFunctions; 226 227 // RETURNS: 228 // - A MathFunctionScoreExpression instance on success if not simplifiable. 229 // - A ConstantScoreExpression instance on success if simplifiable. 230 // - FAILED_PRECONDITION on any null pointer in args. 231 // - INVALID_ARGUMENT on type errors. 232 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 233 FunctionType function_type, 234 std::vector<std::unique_ptr<ScoreExpression>> args); 235 236 libtextclassifier3::StatusOr<double> EvaluateDouble( 237 const DocHitInfo& hit_info, 238 const DocHitInfoIterator* query_it) const override; 239 type()240 ScoreExpressionType type() const override { 241 return ScoreExpressionType::kDouble; 242 } 243 244 private: MathFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)245 explicit MathFunctionScoreExpression( 246 FunctionType function_type, 247 std::vector<std::unique_ptr<ScoreExpression>> args) 248 : function_type_(function_type), args_(std::move(args)) {} 249 250 FunctionType function_type_; 251 std::vector<std::unique_ptr<ScoreExpression>> args_; 252 }; 253 254 class ListOperationFunctionScoreExpression : public ScoreExpression { 255 public: 256 enum class FunctionType { kFilterByRange }; 257 258 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 259 260 // RETURNS: 261 // - A ListOperationFunctionScoreExpression instance on success. 262 // - FAILED_PRECONDITION on any null pointer in args. 263 // - INVALID_ARGUMENT on type errors. 264 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 265 FunctionType function_type, 266 std::vector<std::unique_ptr<ScoreExpression>> args); 267 268 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 269 const DocHitInfo& hit_info, 270 const DocHitInfoIterator* query_it) const override; 271 type()272 ScoreExpressionType type() const override { 273 return ScoreExpressionType::kDoubleList; 274 } 275 276 private: ListOperationFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)277 explicit ListOperationFunctionScoreExpression( 278 FunctionType function_type, 279 std::vector<std::unique_ptr<ScoreExpression>> args) 280 : function_type_(function_type), args_(std::move(args)) {} 281 282 FunctionType function_type_; 283 std::vector<std::unique_ptr<ScoreExpression>> args_; 284 }; 285 286 class DocumentFunctionScoreExpression : public ScoreExpression { 287 public: 288 enum class FunctionType { 289 kDocumentScore, 290 kCreationTimestamp, 291 kUsageCount, 292 kUsageLastUsedTimestamp, 293 }; 294 295 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 296 297 // RETURNS: 298 // - A DocumentFunctionScoreExpression instance on success. 299 // - FAILED_PRECONDITION on any null pointer in args. 300 // - INVALID_ARGUMENT on type errors. 301 static libtextclassifier3::StatusOr< 302 std::unique_ptr<DocumentFunctionScoreExpression>> 303 Create(FunctionType function_type, 304 std::vector<std::unique_ptr<ScoreExpression>> args, 305 const DocumentStore* document_store, double default_score, 306 int64_t current_time_ms); 307 308 libtextclassifier3::StatusOr<double> EvaluateDouble( 309 const DocHitInfo& hit_info, 310 const DocHitInfoIterator* query_it) const override; 311 type()312 ScoreExpressionType type() const override { 313 return ScoreExpressionType::kDouble; 314 } 315 316 private: DocumentFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)317 explicit DocumentFunctionScoreExpression( 318 FunctionType function_type, 319 std::vector<std::unique_ptr<ScoreExpression>> args, 320 const DocumentStore* document_store, double default_score, 321 int64_t current_time_ms) 322 : args_(std::move(args)), 323 document_store_(*document_store), 324 default_score_(default_score), 325 function_type_(function_type), 326 current_time_ms_(current_time_ms) {} 327 328 std::vector<std::unique_ptr<ScoreExpression>> args_; 329 const DocumentStore& document_store_; 330 double default_score_; 331 FunctionType function_type_; 332 int64_t current_time_ms_; 333 }; 334 335 class RelevanceScoreFunctionScoreExpression : public ScoreExpression { 336 public: 337 static constexpr std::string_view kFunctionName = "relevanceScore"; 338 339 // RETURNS: 340 // - A RelevanceScoreFunctionScoreExpression instance on success. 341 // - FAILED_PRECONDITION on any null pointer in args. 342 // - INVALID_ARGUMENT on type errors. 343 static libtextclassifier3::StatusOr< 344 std::unique_ptr<RelevanceScoreFunctionScoreExpression>> 345 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 346 Bm25fCalculator* bm25f_calculator, double default_score); 347 348 libtextclassifier3::StatusOr<double> EvaluateDouble( 349 const DocHitInfo& hit_info, 350 const DocHitInfoIterator* query_it) const override; 351 type()352 ScoreExpressionType type() const override { 353 return ScoreExpressionType::kDouble; 354 } 355 356 private: RelevanceScoreFunctionScoreExpression(Bm25fCalculator * bm25f_calculator,double default_score)357 explicit RelevanceScoreFunctionScoreExpression( 358 Bm25fCalculator* bm25f_calculator, double default_score) 359 : bm25f_calculator_(*bm25f_calculator), default_score_(default_score) {} 360 361 Bm25fCalculator& bm25f_calculator_; 362 double default_score_; 363 }; 364 365 class ChildrenRankingSignalsFunctionScoreExpression : public ScoreExpression { 366 public: 367 static constexpr std::string_view kFunctionName = "childrenRankingSignals"; 368 369 // RETURNS: 370 // - A ChildrenRankingSignalsFunctionScoreExpression instance on success. 371 // - FAILED_PRECONDITION on any null pointer in children. 372 // - INVALID_ARGUMENT on type errors. 373 static libtextclassifier3::StatusOr< 374 std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>> 375 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 376 const JoinChildrenFetcher* join_children_fetcher); 377 378 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 379 const DocHitInfo& hit_info, 380 const DocHitInfoIterator* query_it) const override; 381 type()382 ScoreExpressionType type() const override { 383 return ScoreExpressionType::kDoubleList; 384 } 385 386 private: ChildrenRankingSignalsFunctionScoreExpression(const JoinChildrenFetcher & join_children_fetcher)387 explicit ChildrenRankingSignalsFunctionScoreExpression( 388 const JoinChildrenFetcher& join_children_fetcher) 389 : join_children_fetcher_(join_children_fetcher) {} 390 const JoinChildrenFetcher& join_children_fetcher_; 391 }; 392 393 class PropertyWeightsFunctionScoreExpression : public ScoreExpression { 394 public: 395 static constexpr std::string_view kFunctionName = "propertyWeights"; 396 397 // RETURNS: 398 // - A PropertyWeightsFunctionScoreExpression instance on success. 399 // - FAILED_PRECONDITION on any null pointer in children. 400 // - INVALID_ARGUMENT on type errors. 401 static libtextclassifier3::StatusOr< 402 std::unique_ptr<PropertyWeightsFunctionScoreExpression>> 403 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 404 const DocumentStore* document_store, 405 const SectionWeights* section_weights, int64_t current_time_ms); 406 407 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 408 const DocHitInfo& hit_info, const DocHitInfoIterator*) const override; 409 type()410 ScoreExpressionType type() const override { 411 return ScoreExpressionType::kDoubleList; 412 } 413 414 SchemaTypeId GetSchemaTypeId(DocumentId document_id) const; 415 416 private: PropertyWeightsFunctionScoreExpression(const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)417 explicit PropertyWeightsFunctionScoreExpression( 418 const DocumentStore* document_store, 419 const SectionWeights* section_weights, int64_t current_time_ms) 420 : document_store_(*document_store), 421 section_weights_(*section_weights), 422 current_time_ms_(current_time_ms) {} 423 const DocumentStore& document_store_; 424 const SectionWeights& section_weights_; 425 int64_t current_time_ms_; 426 }; 427 428 class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression { 429 public: 430 static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding"; 431 432 // RETURNS: 433 // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if 434 // not simplifiable. 435 // - A ConstantScoreExpression instance on success if simplifiable. 436 // - FAILED_PRECONDITION on any null pointer in children. 437 // - INVALID_ARGUMENT on type errors. 438 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 439 std::vector<std::unique_ptr<ScoreExpression>> args); 440 441 libtextclassifier3::StatusOr<double> EvaluateDouble( 442 const DocHitInfo& hit_info, 443 const DocHitInfoIterator* query_it) const override; 444 type()445 ScoreExpressionType type() const override { 446 return ScoreExpressionType::kVectorIndex; 447 } 448 449 private: GetSearchSpecEmbeddingFunctionScoreExpression(std::unique_ptr<ScoreExpression> arg)450 explicit GetSearchSpecEmbeddingFunctionScoreExpression( 451 std::unique_ptr<ScoreExpression> arg) 452 : arg_(std::move(arg)) {} 453 std::unique_ptr<ScoreExpression> arg_; 454 }; 455 456 class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression { 457 public: 458 static constexpr std::string_view kFunctionName = "matchedSemanticScores"; 459 460 // RETURNS: 461 // - A MatchedSemanticScoresFunctionScoreExpression instance on success. 462 // - FAILED_PRECONDITION on any null pointer in children. 463 // - INVALID_ARGUMENT on type errors. 464 static libtextclassifier3::StatusOr< 465 std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> 466 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 467 SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, 468 const EmbeddingQueryResults* embedding_query_results); 469 470 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 471 const DocHitInfo& hit_info, 472 const DocHitInfoIterator* query_it) const override; 473 type()474 ScoreExpressionType type() const override { 475 return ScoreExpressionType::kDoubleList; 476 } 477 478 private: MatchedSemanticScoresFunctionScoreExpression(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,const EmbeddingQueryResults & embedding_query_results)479 explicit MatchedSemanticScoresFunctionScoreExpression( 480 std::vector<std::unique_ptr<ScoreExpression>> args, 481 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 482 const EmbeddingQueryResults& embedding_query_results) 483 : args_(std::move(args)), 484 metric_type_(metric_type), 485 embedding_query_results_(embedding_query_results) {} 486 487 std::vector<std::unique_ptr<ScoreExpression>> args_; 488 const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; 489 const EmbeddingQueryResults& embedding_query_results_; 490 }; 491 492 } // namespace lib 493 } // namespace icing 494 495 #endif // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 496