// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <list>
#include <ostream>
#include <vector>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Summary holding a sorted block of entries with upper bound guarantees
// over the approximation error.
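//
// Illustrative usage sketch (comment only, not exercised by this header;
// assumes BufferEntry is constructible from a (value, weight) pair and that
// the buffer entries are pre-sorted by value, as produced by the buffer
// header included above):
//
//   WeightedQuantilesSummary<double, double> summary;
//   summary.BuildFromBufferEntries({{1.0, 1.0}, {2.0, 2.0}, {5.0, 1.0}});
//   const auto quartiles = summary.GenerateQuantiles(4);  // 5 values.
//   const double eps = summary.ApproximationError();      // 0 for exact data.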
template <typename ValueType, typename WeightType,
          typename CompareFn = std::less<ValueType>>
class WeightedQuantilesSummary {
 public:
  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
  using BufferEntry = typename Buffer::BufferEntry;

  struct SummaryEntry {
    SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
                 const WeightType& max) {
      // Explicitly initialize all of memory (including padding from memory
      // alignment) to allow the struct to be msan-resistant "plain old data".
      //
      // POD = https://en.cppreference.com/w/cpp/named_req/PODType
      memset(this, 0, sizeof(*this));

      value = v;
      weight = w;
      min_rank = min;
      max_rank = max;
    }

    SummaryEntry() {
      memset(this, 0, sizeof(*this));

      value = ValueType();
      weight = 0;
      min_rank = 0;
      max_rank = 0;
    }

    bool operator==(const SummaryEntry& other) const {
      return value == other.value && weight == other.weight &&
             min_rank == other.min_rank && max_rank == other.max_rank;
    }
    friend std::ostream& operator<<(std::ostream& strm,
                                    const SummaryEntry& entry) {
      return strm << "{" << entry.value << ", " << entry.weight << ", "
                  << entry.min_rank << ", " << entry.max_rank << "}";
    }

    // Max rank estimate for previous smaller value.
    WeightType PrevMaxRank() const { return max_rank - weight; }

    // Min rank estimate for next larger value.
    WeightType NextMinRank() const { return min_rank + weight; }

    ValueType value;
    WeightType weight;
    WeightType min_rank;
    WeightType max_rank;
  };

  // Re-construct summary from the specified buffer.
  void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
    entries_.clear();
    entries_.reserve(buffer_entries.size());
    WeightType cumulative_weight = 0;
    for (const auto& entry : buffer_entries) {
      WeightType current_weight = entry.weight;
      entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
                            cumulative_weight + current_weight);
      cumulative_weight += current_weight;
    }
  }

  // Re-construct summary from the specified summary entries.
  void BuildFromSummaryEntries(
      const std::vector<SummaryEntry>& summary_entries) {
    entries_.clear();
    entries_.reserve(summary_entries.size());
    entries_.insert(entries_.begin(), summary_entries.begin(),
                    summary_entries.end());
  }

  // Merges two summaries through an algorithm derived from MergeSort over
  // summary entries, while guaranteeing that the max approximation error of
  // the merged summary is no greater than the larger of the two individual
  // summaries' approximation errors.
  // For example, consider summaries where each entry is of the form
  // (element, weight, min rank, max rank):
  // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
  // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
  // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
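  //
  // Illustrative sketch (comment only) reproducing the example above:
  //
  //   WeightedQuantilesSummary<int, int> s1, s2;
  //   s1.BuildFromSummaryEntries({{1, 3, 0, 3}, {4, 2, 3, 5}});
  //   s2.BuildFromSummaryEntries({{3, 1, 0, 1}, {4, 1, 1, 2}});
  //   s1.Merge(s2);  // s1 now holds (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).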
  void Merge(const WeightedQuantilesSummary& other_summary) {
    // Make sure we have something to merge.
    const auto& other_entries = other_summary.entries_;
    if (other_entries.empty()) {
      return;
    }
    if (entries_.empty()) {
      BuildFromSummaryEntries(other_summary.entries_);
      return;
    }

    // Move current entries to make room for a new buffer.
    std::vector<SummaryEntry> base_entries(std::move(entries_));
    entries_.clear();
    entries_.reserve(base_entries.size() + other_entries.size());

    // Merge entries maintaining ranks. The idea is to stack values
    // in order which we can do in linear time as the two summaries are
    // already sorted. We keep track of the next lower rank from either
    // summary and update it as we pop elements from the summaries.
    // We handle the special case when the next two elements from either
    // summary are equal, in which case we just merge the two elements
    // and simultaneously update both ranks.
    auto it1 = base_entries.cbegin();
    auto it2 = other_entries.cbegin();
    WeightType next_min_rank1 = 0;
    WeightType next_min_rank2 = 0;
    while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
      if (kCompFn(it1->value, it2->value)) {  // value1 < value2
        // Take value1 and use the last added value2 to compute
        // the min rank and the current value2 to compute the max rank.
        entries_.emplace_back(it1->value, it1->weight,
                              it1->min_rank + next_min_rank2,
                              it1->max_rank + it2->PrevMaxRank());
        // Update next min rank 1.
        next_min_rank1 = it1->NextMinRank();
        ++it1;
      } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
        // Take value2 and use the last added value1 to compute
        // the min rank and the current value1 to compute the max rank.
        entries_.emplace_back(it2->value, it2->weight,
                              it2->min_rank + next_min_rank1,
                              it2->max_rank + it1->PrevMaxRank());
        // Update next min rank 2.
        next_min_rank2 = it2->NextMinRank();
        ++it2;
      } else {  // value1 == value2
        // Straight additive merger of the two entries into one.
        entries_.emplace_back(it1->value, it1->weight + it2->weight,
                              it1->min_rank + it2->min_rank,
                              it1->max_rank + it2->max_rank);
        // Update next min ranks for both.
        next_min_rank1 = it1->NextMinRank();
        next_min_rank2 = it2->NextMinRank();
        ++it1;
        ++it2;
      }
    }

    // Fill in any residual.
    while (it1 != base_entries.cend()) {
      entries_.emplace_back(it1->value, it1->weight,
                            it1->min_rank + next_min_rank2,
                            it1->max_rank + other_entries.back().max_rank);
      ++it1;
    }
    while (it2 != other_entries.cend()) {
      entries_.emplace_back(it2->value, it2->weight,
                            it2->min_rank + next_min_rank1,
                            it2->max_rank + base_entries.back().max_rank);
      ++it2;
    }
  }

  // Compresses the summary to the desired size. The size specification is
  // considered a hint, as we always keep the first and last elements and
  // maintain strict approximation error bounds.
  // The approximation error delta is taken as the max of the requested
  // min error and 1 / size_hint.
  // After compression, the approximation error is guaranteed to increase
  // by no more than that error delta.
  // This algorithm is linear in the original size of the summary and is
  // designed to be cache-friendly.
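  //
  // Illustrative sketch (comment only; the numbers are hypothetical):
  //
  //   summary.Compress(/*size_hint=*/100);
  //
  // shrinks the summary to roughly 100 entries while increasing
  // ApproximationError() by at most max(1.0 / 100, min_eps) = 0.01
  // (with the default min_eps of 0).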
  void Compress(int64 size_hint, double min_eps = 0) {
    // No-op if we're already within the size requirement.
    size_hint = std::max(size_hint, int64{2});
    if (entries_.size() <= size_hint) {
      return;
    }

    // First compute the max error bound delta resulting from this compression.
    double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);

    // Compress elements, ensuring that approximation bounds and element
    // diversity are both maintained.
    int64 add_accumulator = 0, add_step = entries_.size();
    auto write_it = entries_.begin() + 1, last_it = write_it;
    for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
      auto next_it = read_it + 1;
      while (next_it != entries_.end() && add_accumulator < add_step &&
             next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
        add_accumulator += size_hint;
        ++next_it;
      }
      if (read_it == next_it - 1) {
        ++read_it;
      } else {
        read_it = next_it - 1;
      }
      (*write_it++) = (*read_it);
      last_it = read_it;
      add_accumulator -= add_step;
    }
    // Write last element and resize.
    if (last_it + 1 != entries_.end()) {
      (*write_it++) = entries_.back();
    }
    entries_.resize(write_it - entries_.begin());
  }

  // To construct the boundaries we first run a soft compress over a copy
  // of the summary and retrieve the values.
  // The resulting boundaries are guaranteed to contain at most
  // num_boundaries unique elements while maintaining approximation bounds.
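  //
  // Illustrative sketch (comment only; the bucket count is hypothetical):
  //
  //   const auto bucket_boundaries = summary.GenerateBoundaries(64);
  //
  // returns up to 64 candidate split values, repeatedly dropping the boundary
  // whose removal creates the smallest rank gap.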
  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }

    // Generate soft compressed summary.
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
        compressed_summary;
    compressed_summary.BuildFromSummaryEntries(entries_);
    // Set an epsilon for compression that's at most 1.0 / num_boundaries
    // more than the epsilon of our original summary, since the compression
    // operation adds ~1.0 / num_boundaries to the final approximation error.
    float compression_eps = ApproximationError() + (1.0 / num_boundaries);
    compressed_summary.Compress(num_boundaries, compression_eps);

    // Remove the least important boundaries by the gap that removing them
    // would create.
    std::list<int64> boundaries_to_keep;
    for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) {
      boundaries_to_keep.push_back(i);
    }
    while (boundaries_to_keep.size() > num_boundaries) {
      std::list<int64>::iterator min_element = boundaries_to_keep.end();
      auto prev = boundaries_to_keep.begin();
      auto curr = prev;
      ++curr;
      auto next = curr;
      ++next;
      WeightType min_weight = TotalWeight();
      for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) {
        WeightType new_weight =
            compressed_summary.entries_[*next].PrevMaxRank() -
            compressed_summary.entries_[*prev].NextMinRank();
        if (new_weight < min_weight) {
          min_element = curr;
          min_weight = new_weight;
        }
      }
      boundaries_to_keep.erase(min_element);
    }

    // Return boundaries.
    output.reserve(boundaries_to_keep.size());
    for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end();
         ++itr) {
      output.push_back(compressed_summary.entries_[*itr].value);
    }
    return output;
  }

  // To construct the desired n-quantiles we repeatedly query n ranks from the
  // original summary. The following algorithm is an efficient, cache-friendly
  // O(n) implementation of that idea which avoids the O(n log n) cost of
  // repeated full rank queries.
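  //
  // Illustrative sketch (comment only): a request for quartiles returns
  // num_quantiles + 1 values, including the min and max:
  //
  //   const auto quantiles = summary.GenerateQuantiles(4);
  //   // quantiles holds 5 values: min, ~P25, ~P50, ~P75, max.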
  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }
    num_quantiles = std::max(num_quantiles, int64{2});
    output.reserve(num_quantiles + 1);

    // Make successive rank queries to get boundaries.
    // We always keep the first (min) and last (max) entries.
    for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
      // This step boils down to finding the next element sub-range defined by
      // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
      WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
      size_t next_idx = cur_idx + 1;
      while (next_idx < entries_.size() &&
             d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
        ++next_idx;
      }
      cur_idx = next_idx - 1;

      // Determine insertion order.
      if (next_idx == entries_.size() ||
          d_2 < entries_[cur_idx].NextMinRank() +
                    entries_[next_idx].PrevMaxRank()) {
        output.push_back(entries_[cur_idx].value);
      } else {
        output.push_back(entries_[next_idx].value);
      }
    }
    return output;
  }

  // Calculates the current approximation error, which should always be <= eps.
  double ApproximationError() const {
    if (entries_.empty()) {
      return 0;
    }

    WeightType max_gap = 0;
    for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
      max_gap = std::max(max_gap,
                         std::max(it->max_rank - it->min_rank - it->weight,
                                  it->PrevMaxRank() - (it - 1)->NextMinRank()));
    }
    return static_cast<double>(max_gap) / TotalWeight();
  }

  ValueType MinValue() const {
    return !entries_.empty() ? entries_.front().value
                             : std::numeric_limits<ValueType>::max();
  }
  ValueType MaxValue() const {
    return !entries_.empty() ? entries_.back().value
                             : std::numeric_limits<ValueType>::max();
  }
  WeightType TotalWeight() const {
    return !entries_.empty() ? entries_.back().max_rank : 0;
  }
  int64 Size() const { return entries_.size(); }
  void Clear() { entries_.clear(); }
  const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }

 private:
  // Comparison function.
  static constexpr decltype(CompareFn()) kCompFn = CompareFn();

  // Summary entries.
  std::vector<SummaryEntry> entries_;
};

template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;

}  // namespace quantiles
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_