1 // Copyright (C) 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 16 #define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "icing/text_classifier/lib3/utils/base/status.h" 22 #include "icing/text_classifier/lib3/utils/base/statusor.h" 23 #include "icing/join/join-children-fetcher.h" 24 #include "icing/schema/schema-store.h" 25 #include "icing/scoring/advanced_scoring/score-expression.h" 26 #include "icing/scoring/bm25f-calculator.h" 27 #include "icing/scoring/scorer.h" 28 #include "icing/store/document-store.h" 29 30 namespace icing { 31 namespace lib { 32 33 class AdvancedScorer : public Scorer { 34 public: 35 // Returns: 36 // A AdvancedScorer instance on success 37 // FAILED_PRECONDITION on any null pointer input 38 // INVALID_ARGUMENT if fails to create an instance 39 static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create( 40 const ScoringSpecProto& scoring_spec, double default_score, 41 const DocumentStore* document_store, const SchemaStore* schema_store, 42 int64_t current_time_ms, 43 const JoinChildrenFetcher* join_children_fetcher = nullptr); 44 GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)45 double GetScore(const DocHitInfo& hit_info, 46 const DocHitInfoIterator* query_it) override { 47 libtextclassifier3::StatusOr<double> result = 48 score_expression_->eval(hit_info, query_it); 49 if (!result.ok()) { 50 ICING_LOG(ERROR) << "Got an error when scoring a document:\n" 51 << result.status().error_message(); 52 return default_score_; 53 } 54 return std::move(result).ValueOrDie(); 55 } 56 PrepareToScore(std::unordered_map<std::string,std::unique_ptr<DocHitInfoIterator>> * query_term_iterators)57 void PrepareToScore( 58 std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* 59 query_term_iterators) override { 60 if (query_term_iterators == nullptr || query_term_iterators->empty()) { 61 return; 62 } 63 bm25f_calculator_->PrepareToScore(query_term_iterators); 64 } 65 is_constant()66 bool is_constant() const { return score_expression_->is_constant_double(); } 67 68 private: AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,std::unique_ptr<SectionWeights> section_weights,std::unique_ptr<Bm25fCalculator> bm25f_calculator,double default_score)69 explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, 70 std::unique_ptr<SectionWeights> section_weights, 71 std::unique_ptr<Bm25fCalculator> bm25f_calculator, 72 double default_score) 73 : score_expression_(std::move(score_expression)), 74 section_weights_(std::move(section_weights)), 75 bm25f_calculator_(std::move(bm25f_calculator)), 76 default_score_(default_score) { 77 if (is_constant()) { 78 ICING_LOG(WARNING) 79 << "The advanced scoring expression will evaluate to a constant."; 80 } 81 } 82 83 std::unique_ptr<ScoreExpression> score_expression_; 84 std::unique_ptr<SectionWeights> section_weights_; 85 std::unique_ptr<Bm25fCalculator> bm25f_calculator_; 86 double default_score_; 87 }; 88 89 } // namespace lib 90 } // namespace icing 91 92 #endif // ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 93