• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/join/aggregation-scorer.h"
16 
17 #include <algorithm>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21 
22 #include "icing/proto/search.pb.h"
23 #include "icing/scoring/scored-document-hit.h"
24 
25 namespace icing {
26 namespace lib {
27 
28 class CountAggregationScorer : public AggregationScorer {
29  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)30   double GetScore(const ScoredDocumentHit& parent,
31                   const std::vector<ScoredDocumentHit>& children) override {
32     return children.size();
33   }
34 };
35 
36 class MinAggregationScorer : public AggregationScorer {
37  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)38   double GetScore(const ScoredDocumentHit& parent,
39                   const std::vector<ScoredDocumentHit>& children) override {
40     if (children.empty()) {
41       // Return 0 if there is no child document.
42       // For non-empty children with negative scores, they are considered "worse
43       // than" 0, so it is correct to return 0 for empty children to assign it a
44       // rank higher than non-empty children with negative scores.
45       return 0.0;
46     }
47     return std::min_element(children.begin(), children.end(),
48                             [](const ScoredDocumentHit& lhs,
49                                const ScoredDocumentHit& rhs) -> bool {
50                               return lhs.score() < rhs.score();
51                             })
52         ->score();
53   }
54 };
55 
56 class AverageAggregationScorer : public AggregationScorer {
57  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)58   double GetScore(const ScoredDocumentHit& parent,
59                   const std::vector<ScoredDocumentHit>& children) override {
60     if (children.empty()) {
61       // Return 0 if there is no child document.
62       // For non-empty children with negative scores, they are considered "worse
63       // than" 0, so it is correct to return 0 for empty children to assign it a
64       // rank higher than non-empty children with negative scores.
65       return 0.0;
66     }
67     return std::reduce(
68                children.begin(), children.end(), 0.0,
69                [](double prev, const ScoredDocumentHit& item) -> double {
70                  return prev + item.score();
71                }) /
72            children.size();
73   }
74 };
75 
76 class MaxAggregationScorer : public AggregationScorer {
77  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)78   double GetScore(const ScoredDocumentHit& parent,
79                   const std::vector<ScoredDocumentHit>& children) override {
80     if (children.empty()) {
81       // Return 0 if there is no child document.
82       // For non-empty children with negative scores, they are considered "worse
83       // than" 0, so it is correct to return 0 for empty children to assign it a
84       // rank higher than non-empty children with negative scores.
85       return 0.0;
86     }
87     return std::max_element(children.begin(), children.end(),
88                             [](const ScoredDocumentHit& lhs,
89                                const ScoredDocumentHit& rhs) -> bool {
90                               return lhs.score() < rhs.score();
91                             })
92         ->score();
93   }
94 };
95 
96 class SumAggregationScorer : public AggregationScorer {
97  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)98   double GetScore(const ScoredDocumentHit& parent,
99                   const std::vector<ScoredDocumentHit>& children) override {
100     return std::reduce(
101         children.begin(), children.end(), 0.0,
102         [](double prev, const ScoredDocumentHit& item) -> double {
103           return prev + item.score();
104         });
105   }
106 };
107 
108 class DefaultAggregationScorer : public AggregationScorer {
109  public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)110   double GetScore(const ScoredDocumentHit& parent,
111                   const std::vector<ScoredDocumentHit>& children) override {
112     return parent.score();
113   }
114 };
115 
Create(const JoinSpecProto & join_spec)116 std::unique_ptr<AggregationScorer> AggregationScorer::Create(
117     const JoinSpecProto& join_spec) {
118   switch (join_spec.aggregation_scoring_strategy()) {
119     case JoinSpecProto::AggregationScoringStrategy::COUNT:
120       return std::make_unique<CountAggregationScorer>();
121     case JoinSpecProto::AggregationScoringStrategy::MIN:
122       return std::make_unique<MinAggregationScorer>();
123     case JoinSpecProto::AggregationScoringStrategy::AVG:
124       return std::make_unique<AverageAggregationScorer>();
125     case JoinSpecProto::AggregationScoringStrategy::MAX:
126       return std::make_unique<MaxAggregationScorer>();
127     case JoinSpecProto::AggregationScoringStrategy::SUM:
128       return std::make_unique<SumAggregationScorer>();
129     case JoinSpecProto::AggregationScoringStrategy::NONE:
130       // No aggregation strategy means using parent document score, so fall
131       // through to return DefaultAggregationScorer.
132       [[fallthrough]];
133     default:
134       return std::make_unique<DefaultAggregationScorer>();
135   }
136 }
137 
138 }  // namespace lib
139 }  // namespace icing
140