1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/join/aggregation-scorer.h"
16
17 #include <algorithm>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21
22 #include "icing/proto/search.pb.h"
23 #include "icing/scoring/scored-document-hit.h"
24
25 namespace icing {
26 namespace lib {
27
28 class CountAggregationScorer : public AggregationScorer {
29 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)30 double GetScore(const ScoredDocumentHit& parent,
31 const std::vector<ScoredDocumentHit>& children) override {
32 return children.size();
33 }
34 };
35
36 class MinAggregationScorer : public AggregationScorer {
37 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)38 double GetScore(const ScoredDocumentHit& parent,
39 const std::vector<ScoredDocumentHit>& children) override {
40 if (children.empty()) {
41 // Return 0 if there is no child document.
42 // For non-empty children with negative scores, they are considered "worse
43 // than" 0, so it is correct to return 0 for empty children to assign it a
44 // rank higher than non-empty children with negative scores.
45 return 0.0;
46 }
47 return std::min_element(children.begin(), children.end(),
48 [](const ScoredDocumentHit& lhs,
49 const ScoredDocumentHit& rhs) -> bool {
50 return lhs.score() < rhs.score();
51 })
52 ->score();
53 }
54 };
55
56 class AverageAggregationScorer : public AggregationScorer {
57 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)58 double GetScore(const ScoredDocumentHit& parent,
59 const std::vector<ScoredDocumentHit>& children) override {
60 if (children.empty()) {
61 // Return 0 if there is no child document.
62 // For non-empty children with negative scores, they are considered "worse
63 // than" 0, so it is correct to return 0 for empty children to assign it a
64 // rank higher than non-empty children with negative scores.
65 return 0.0;
66 }
67 return std::reduce(
68 children.begin(), children.end(), 0.0,
69 [](double prev, const ScoredDocumentHit& item) -> double {
70 return prev + item.score();
71 }) /
72 children.size();
73 }
74 };
75
76 class MaxAggregationScorer : public AggregationScorer {
77 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)78 double GetScore(const ScoredDocumentHit& parent,
79 const std::vector<ScoredDocumentHit>& children) override {
80 if (children.empty()) {
81 // Return 0 if there is no child document.
82 // For non-empty children with negative scores, they are considered "worse
83 // than" 0, so it is correct to return 0 for empty children to assign it a
84 // rank higher than non-empty children with negative scores.
85 return 0.0;
86 }
87 return std::max_element(children.begin(), children.end(),
88 [](const ScoredDocumentHit& lhs,
89 const ScoredDocumentHit& rhs) -> bool {
90 return lhs.score() < rhs.score();
91 })
92 ->score();
93 }
94 };
95
96 class SumAggregationScorer : public AggregationScorer {
97 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)98 double GetScore(const ScoredDocumentHit& parent,
99 const std::vector<ScoredDocumentHit>& children) override {
100 return std::reduce(
101 children.begin(), children.end(), 0.0,
102 [](double prev, const ScoredDocumentHit& item) -> double {
103 return prev + item.score();
104 });
105 }
106 };
107
108 class DefaultAggregationScorer : public AggregationScorer {
109 public:
GetScore(const ScoredDocumentHit & parent,const std::vector<ScoredDocumentHit> & children)110 double GetScore(const ScoredDocumentHit& parent,
111 const std::vector<ScoredDocumentHit>& children) override {
112 return parent.score();
113 }
114 };
115
Create(const JoinSpecProto & join_spec)116 std::unique_ptr<AggregationScorer> AggregationScorer::Create(
117 const JoinSpecProto& join_spec) {
118 switch (join_spec.aggregation_scoring_strategy()) {
119 case JoinSpecProto::AggregationScoringStrategy::COUNT:
120 return std::make_unique<CountAggregationScorer>();
121 case JoinSpecProto::AggregationScoringStrategy::MIN:
122 return std::make_unique<MinAggregationScorer>();
123 case JoinSpecProto::AggregationScoringStrategy::AVG:
124 return std::make_unique<AverageAggregationScorer>();
125 case JoinSpecProto::AggregationScoringStrategy::MAX:
126 return std::make_unique<MaxAggregationScorer>();
127 case JoinSpecProto::AggregationScoringStrategy::SUM:
128 return std::make_unique<SumAggregationScorer>();
129 case JoinSpecProto::AggregationScoringStrategy::NONE:
130 // No aggregation strategy means using parent document score, so fall
131 // through to return DefaultAggregationScorer.
132 [[fallthrough]];
133 default:
134 return std::make_unique<DefaultAggregationScorer>();
135 }
136 }
137
138 } // namespace lib
139 } // namespace icing
140