// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <limits>
#include <ostream>
#include <utility>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Summary holding a sorted block of entries with upper bound guarantees
// over the approximation error.
29 template <typename ValueType, typename WeightType, 30 typename CompareFn = std::less<ValueType>> 31 class WeightedQuantilesSummary { 32 public: 33 using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>; 34 using BufferEntry = typename Buffer::BufferEntry; 35 36 struct SummaryEntry { SummaryEntrySummaryEntry37 SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, 38 const WeightType& max) { 39 value = v; 40 weight = w; 41 min_rank = min; 42 max_rank = max; 43 } 44 SummaryEntrySummaryEntry45 SummaryEntry() { 46 value = ValueType(); 47 weight = 0; 48 min_rank = 0; 49 max_rank = 0; 50 } 51 52 bool operator==(const SummaryEntry& other) const { 53 return value == other.value && weight == other.weight && 54 min_rank == other.min_rank && max_rank == other.max_rank; 55 } 56 friend std::ostream& operator<<(std::ostream& strm, 57 const SummaryEntry& entry) { 58 return strm << "{" << entry.value << ", " << entry.weight << ", " 59 << entry.min_rank << ", " << entry.max_rank << "}"; 60 } 61 62 // Max rank estimate for previous smaller value. PrevMaxRankSummaryEntry63 WeightType PrevMaxRank() const { return max_rank - weight; } 64 65 // Min rank estimate for next larger value. NextMinRankSummaryEntry66 WeightType NextMinRank() const { return min_rank + weight; } 67 68 ValueType value; 69 WeightType weight; 70 WeightType min_rank; 71 WeightType max_rank; 72 }; 73 74 // Re-construct summary from the specified buffer. 
BuildFromBufferEntries(const std::vector<BufferEntry> & buffer_entries)75 void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) { 76 entries_.clear(); 77 entries_.reserve(buffer_entries.size()); 78 WeightType cumulative_weight = 0; 79 for (const auto& entry : buffer_entries) { 80 WeightType current_weight = entry.weight; 81 entries_.emplace_back(entry.value, entry.weight, cumulative_weight, 82 cumulative_weight + current_weight); 83 cumulative_weight += current_weight; 84 } 85 } 86 87 // Re-construct summary from the specified summary entries. BuildFromSummaryEntries(const std::vector<SummaryEntry> & summary_entries)88 void BuildFromSummaryEntries( 89 const std::vector<SummaryEntry>& summary_entries) { 90 entries_.clear(); 91 entries_.reserve(summary_entries.size()); 92 entries_.insert(entries_.begin(), summary_entries.begin(), 93 summary_entries.end()); 94 } 95 96 // Merges two summaries through an algorithm that's derived from MergeSort 97 // for summary entries while guaranteeing that the max approximation error 98 // of the final merged summary is no greater than the approximation errors 99 // of each individual summary. 100 // For example consider summaries where each entry is of the form 101 // (element, weight, min rank, max rank): 102 // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5) 103 // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2) 104 // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7). Merge(const WeightedQuantilesSummary & other_summary)105 void Merge(const WeightedQuantilesSummary& other_summary) { 106 // Make sure we have something to merge. 107 const auto& other_entries = other_summary.entries_; 108 if (other_entries.empty()) { 109 return; 110 } 111 if (entries_.empty()) { 112 BuildFromSummaryEntries(other_summary.entries_); 113 return; 114 } 115 116 // Move current entries to make room for a new buffer. 
117 std::vector<SummaryEntry> base_entries(std::move(entries_)); 118 entries_.clear(); 119 entries_.reserve(base_entries.size() + other_entries.size()); 120 121 // Merge entries maintaining ranks. The idea is to stack values 122 // in order which we can do in linear time as the two summaries are 123 // already sorted. We keep track of the next lower rank from either 124 // summary and update it as we pop elements from the summaries. 125 // We handle the special case when the next two elements from either 126 // summary are equal, in which case we just merge the two elements 127 // and simultaneously update both ranks. 128 auto it1 = base_entries.cbegin(); 129 auto it2 = other_entries.cbegin(); 130 WeightType next_min_rank1 = 0; 131 WeightType next_min_rank2 = 0; 132 while (it1 != base_entries.cend() && it2 != other_entries.cend()) { 133 if (kCompFn(it1->value, it2->value)) { // value1 < value2 134 // Take value1 and use the last added value2 to compute 135 // the min rank and the current value2 to compute the max rank. 136 entries_.emplace_back(it1->value, it1->weight, 137 it1->min_rank + next_min_rank2, 138 it1->max_rank + it2->PrevMaxRank()); 139 // Update next min rank 1. 140 next_min_rank1 = it1->NextMinRank(); 141 ++it1; 142 } else if (kCompFn(it2->value, it1->value)) { // value1 > value2 143 // Take value2 and use the last added value1 to compute 144 // the min rank and the current value1 to compute the max rank. 145 entries_.emplace_back(it2->value, it2->weight, 146 it2->min_rank + next_min_rank1, 147 it2->max_rank + it1->PrevMaxRank()); 148 // Update next min rank 2. 149 next_min_rank2 = it2->NextMinRank(); 150 ++it2; 151 } else { // value1 == value2 152 // Straight additive merger of the two entries into one. 153 entries_.emplace_back(it1->value, it1->weight + it2->weight, 154 it1->min_rank + it2->min_rank, 155 it1->max_rank + it2->max_rank); 156 // Update next min ranks for both. 
157 next_min_rank1 = it1->NextMinRank(); 158 next_min_rank2 = it2->NextMinRank(); 159 ++it1; 160 ++it2; 161 } 162 } 163 164 // Fill in any residual. 165 while (it1 != base_entries.cend()) { 166 entries_.emplace_back(it1->value, it1->weight, 167 it1->min_rank + next_min_rank2, 168 it1->max_rank + other_entries.back().max_rank); 169 ++it1; 170 } 171 while (it2 != other_entries.cend()) { 172 entries_.emplace_back(it2->value, it2->weight, 173 it2->min_rank + next_min_rank1, 174 it2->max_rank + base_entries.back().max_rank); 175 ++it2; 176 } 177 } 178 179 // Compresses buffer into desired size. The size specification is 180 // considered a hint as we always keep the first and last elements and 181 // maintain strict approximation error bounds. 182 // The approximation error delta is taken as the max of either the requested 183 // min error or 1 / size_hint. 184 // After compression, the approximation error is guaranteed to increase 185 // by no more than that error delta. 186 // This algorithm is linear in the original size of the summary and is 187 // designed to be cache-friendly. 188 void Compress(int64 size_hint, double min_eps = 0) { 189 // No-op if we're already within the size requirement. 190 size_hint = std::max(size_hint, int64{2}); 191 if (entries_.size() <= size_hint) { 192 return; 193 } 194 195 // First compute the max error bound delta resulting from this compression. 196 double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps); 197 198 // Compress elements ensuring approximation bounds and elements diversity 199 // are both maintained. 
200 int64 add_accumulator = 0, add_step = entries_.size(); 201 auto write_it = entries_.begin() + 1, last_it = write_it; 202 for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) { 203 auto next_it = read_it + 1; 204 while (next_it != entries_.end() && add_accumulator < add_step && 205 next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) { 206 add_accumulator += size_hint; 207 ++next_it; 208 } 209 if (read_it == next_it - 1) { 210 ++read_it; 211 } else { 212 read_it = next_it - 1; 213 } 214 (*write_it++) = (*read_it); 215 last_it = read_it; 216 add_accumulator -= add_step; 217 } 218 // Write last element and resize. 219 if (last_it + 1 != entries_.end()) { 220 (*write_it++) = entries_.back(); 221 } 222 entries_.resize(write_it - entries_.begin()); 223 } 224 225 // To construct the boundaries we first run a soft compress over a copy 226 // of the summary and retrieve the values. 227 // The resulting boundaries are guaranteed to both contain at least 228 // num_boundaries unique elements and maintain approximation bounds. GenerateBoundaries(int64 num_boundaries)229 std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const { 230 std::vector<ValueType> output; 231 if (entries_.empty()) { 232 return output; 233 } 234 235 // Generate soft compressed summary. 236 WeightedQuantilesSummary<ValueType, WeightType, CompareFn> 237 compressed_summary; 238 compressed_summary.BuildFromSummaryEntries(entries_); 239 // Set an epsilon for compression that's at most 1.0 / num_boundaries 240 // more than epsilon of original our summary since the compression operation 241 // adds ~1.0/num_boundaries to final approximation error. 242 float compression_eps = ApproximationError() + (1.0 / num_boundaries); 243 compressed_summary.Compress(num_boundaries, compression_eps); 244 245 // Return boundaries. 
246 output.reserve(compressed_summary.entries_.size()); 247 for (const auto& entry : compressed_summary.entries_) { 248 output.push_back(entry.value); 249 } 250 return output; 251 } 252 253 // To construct the desired n-quantiles we repetitively query n ranks from the 254 // original summary. The following algorithm is an efficient cache-friendly 255 // O(n) implementation of that idea which avoids the cost of the repetitive 256 // full rank queries O(nlogn). GenerateQuantiles(int64 num_quantiles)257 std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const { 258 std::vector<ValueType> output; 259 if (entries_.empty()) { 260 return output; 261 } 262 num_quantiles = std::max(num_quantiles, int64{2}); 263 output.reserve(num_quantiles + 1); 264 265 // Make successive rank queries to get boundaries. 266 // We always keep the first (min) and last (max) entries. 267 for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) { 268 // This step boils down to finding the next element sub-range defined by 269 // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r. 270 WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles); 271 size_t next_idx = cur_idx + 1; 272 while (next_idx < entries_.size() && 273 d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) { 274 ++next_idx; 275 } 276 cur_idx = next_idx - 1; 277 278 // Determine insertion order. 279 if (next_idx == entries_.size() || 280 d_2 < entries_[cur_idx].NextMinRank() + 281 entries_[next_idx].PrevMaxRank()) { 282 output.push_back(entries_[cur_idx].value); 283 } else { 284 output.push_back(entries_[next_idx].value); 285 } 286 } 287 return output; 288 } 289 290 // Calculates current approximation error which should always be <= eps. 
ApproximationError()291 double ApproximationError() const { 292 if (entries_.empty()) { 293 return 0; 294 } 295 296 WeightType max_gap = 0; 297 for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) { 298 max_gap = std::max(max_gap, 299 std::max(it->max_rank - it->min_rank - it->weight, 300 it->PrevMaxRank() - (it - 1)->NextMinRank())); 301 } 302 return static_cast<double>(max_gap) / TotalWeight(); 303 } 304 MinValue()305 ValueType MinValue() const { 306 return !entries_.empty() ? entries_.front().value 307 : std::numeric_limits<ValueType>::max(); 308 } MaxValue()309 ValueType MaxValue() const { 310 return !entries_.empty() ? entries_.back().value 311 : std::numeric_limits<ValueType>::max(); 312 } TotalWeight()313 WeightType TotalWeight() const { 314 return !entries_.empty() ? entries_.back().max_rank : 0; 315 } Size()316 int64 Size() const { return entries_.size(); } Clear()317 void Clear() { entries_.clear(); } GetEntryList()318 const std::vector<SummaryEntry>& GetEntryList() const { return entries_; } 319 320 private: 321 // Comparison function. 322 static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 323 324 // Summary entries. 325 std::vector<SummaryEntry> entries_; 326 }; 327 328 template <typename ValueType, typename WeightType, typename CompareFn> 329 constexpr decltype(CompareFn()) 330 WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn; 331 332 } // namespace quantiles 333 } // namespace boosted_trees 334 } // namespace tensorflow 335 336 #endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 337