1 // Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 16 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 17 18 #include <cstring> 19 #include <list> 20 #include <vector> 21 22 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h" 23 24 namespace tensorflow { 25 namespace boosted_trees { 26 namespace quantiles { 27 28 // Summary holding a sorted block of entries with upper bound guarantees 29 // over the approximation error. 30 template <typename ValueType, typename WeightType, 31 typename CompareFn = std::less<ValueType>> 32 class WeightedQuantilesSummary { 33 public: 34 using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>; 35 using BufferEntry = typename Buffer::BufferEntry; 36 37 struct SummaryEntry { SummaryEntrySummaryEntry38 SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, 39 const WeightType& max) { 40 // Explicitly initialize all of memory (including padding from memory 41 // alignment) to allow the struct to be msan-resistant "plain old data". 42 // 43 // POD = https://en.cppreference.com/w/cpp/named_req/PODType 44 memset(this, 0, sizeof(*this)); 45 46 value = v; 47 weight = w; 48 min_rank = min; 49 max_rank = max; 50 } 51 SummaryEntrySummaryEntry52 SummaryEntry() { 53 memset(this, 0, sizeof(*this)); 54 55 value = ValueType(); 56 weight = 0; 57 min_rank = 0; 58 max_rank = 0; 59 } 60 61 bool operator==(const SummaryEntry& other) const { 62 return value == other.value && weight == other.weight && 63 min_rank == other.min_rank && max_rank == other.max_rank; 64 } 65 friend std::ostream& operator<<(std::ostream& strm, 66 const SummaryEntry& entry) { 67 return strm << "{" << entry.value << ", " << entry.weight << ", " 68 << entry.min_rank << ", " << entry.max_rank << "}"; 69 } 70 71 // Max rank estimate for previous smaller value. PrevMaxRankSummaryEntry72 WeightType PrevMaxRank() const { return max_rank - weight; } 73 74 // Min rank estimate for next larger value. NextMinRankSummaryEntry75 WeightType NextMinRank() const { return min_rank + weight; } 76 77 ValueType value; 78 WeightType weight; 79 WeightType min_rank; 80 WeightType max_rank; 81 }; 82 83 // Re-construct summary from the specified buffer. BuildFromBufferEntries(const std::vector<BufferEntry> & buffer_entries)84 void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) { 85 entries_.clear(); 86 entries_.reserve(buffer_entries.size()); 87 WeightType cumulative_weight = 0; 88 for (const auto& entry : buffer_entries) { 89 WeightType current_weight = entry.weight; 90 entries_.emplace_back(entry.value, entry.weight, cumulative_weight, 91 cumulative_weight + current_weight); 92 cumulative_weight += current_weight; 93 } 94 } 95 96 // Re-construct summary from the specified summary entries. BuildFromSummaryEntries(const std::vector<SummaryEntry> & summary_entries)97 void BuildFromSummaryEntries( 98 const std::vector<SummaryEntry>& summary_entries) { 99 entries_.clear(); 100 entries_.reserve(summary_entries.size()); 101 entries_.insert(entries_.begin(), summary_entries.begin(), 102 summary_entries.end()); 103 } 104 105 // Merges two summaries through an algorithm that's derived from MergeSort 106 // for summary entries while guaranteeing that the max approximation error 107 // of the final merged summary is no greater than the approximation errors 108 // of each individual summary. 109 // For example consider summaries where each entry is of the form 110 // (element, weight, min rank, max rank): 111 // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5) 112 // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2) 113 // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7). Merge(const WeightedQuantilesSummary & other_summary)114 void Merge(const WeightedQuantilesSummary& other_summary) { 115 // Make sure we have something to merge. 116 const auto& other_entries = other_summary.entries_; 117 if (other_entries.empty()) { 118 return; 119 } 120 if (entries_.empty()) { 121 BuildFromSummaryEntries(other_summary.entries_); 122 return; 123 } 124 125 // Move current entries to make room for a new buffer. 126 std::vector<SummaryEntry> base_entries(std::move(entries_)); 127 entries_.clear(); 128 entries_.reserve(base_entries.size() + other_entries.size()); 129 130 // Merge entries maintaining ranks. The idea is to stack values 131 // in order which we can do in linear time as the two summaries are 132 // already sorted. We keep track of the next lower rank from either 133 // summary and update it as we pop elements from the summaries. 134 // We handle the special case when the next two elements from either 135 // summary are equal, in which case we just merge the two elements 136 // and simultaneously update both ranks. 137 auto it1 = base_entries.cbegin(); 138 auto it2 = other_entries.cbegin(); 139 WeightType next_min_rank1 = 0; 140 WeightType next_min_rank2 = 0; 141 while (it1 != base_entries.cend() && it2 != other_entries.cend()) { 142 if (kCompFn(it1->value, it2->value)) { // value1 < value2 143 // Take value1 and use the last added value2 to compute 144 // the min rank and the current value2 to compute the max rank. 145 entries_.emplace_back(it1->value, it1->weight, 146 it1->min_rank + next_min_rank2, 147 it1->max_rank + it2->PrevMaxRank()); 148 // Update next min rank 1. 149 next_min_rank1 = it1->NextMinRank(); 150 ++it1; 151 } else if (kCompFn(it2->value, it1->value)) { // value1 > value2 152 // Take value2 and use the last added value1 to compute 153 // the min rank and the current value1 to compute the max rank. 154 entries_.emplace_back(it2->value, it2->weight, 155 it2->min_rank + next_min_rank1, 156 it2->max_rank + it1->PrevMaxRank()); 157 // Update next min rank 2. 158 next_min_rank2 = it2->NextMinRank(); 159 ++it2; 160 } else { // value1 == value2 161 // Straight additive merger of the two entries into one. 162 entries_.emplace_back(it1->value, it1->weight + it2->weight, 163 it1->min_rank + it2->min_rank, 164 it1->max_rank + it2->max_rank); 165 // Update next min ranks for both. 166 next_min_rank1 = it1->NextMinRank(); 167 next_min_rank2 = it2->NextMinRank(); 168 ++it1; 169 ++it2; 170 } 171 } 172 173 // Fill in any residual. 174 while (it1 != base_entries.cend()) { 175 entries_.emplace_back(it1->value, it1->weight, 176 it1->min_rank + next_min_rank2, 177 it1->max_rank + other_entries.back().max_rank); 178 ++it1; 179 } 180 while (it2 != other_entries.cend()) { 181 entries_.emplace_back(it2->value, it2->weight, 182 it2->min_rank + next_min_rank1, 183 it2->max_rank + base_entries.back().max_rank); 184 ++it2; 185 } 186 } 187 188 // Compresses buffer into desired size. The size specification is 189 // considered a hint as we always keep the first and last elements and 190 // maintain strict approximation error bounds. 191 // The approximation error delta is taken as the max of either the requested 192 // min error or 1 / size_hint. 193 // After compression, the approximation error is guaranteed to increase 194 // by no more than that error delta. 195 // This algorithm is linear in the original size of the summary and is 196 // designed to be cache-friendly. 197 void Compress(int64 size_hint, double min_eps = 0) { 198 // No-op if we're already within the size requirement. 199 size_hint = std::max(size_hint, int64{2}); 200 if (entries_.size() <= size_hint) { 201 return; 202 } 203 204 // First compute the max error bound delta resulting from this compression. 205 double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps); 206 207 // Compress elements ensuring approximation bounds and elements diversity 208 // are both maintained. 209 int64 add_accumulator = 0, add_step = entries_.size(); 210 auto write_it = entries_.begin() + 1, last_it = write_it; 211 for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) { 212 auto next_it = read_it + 1; 213 while (next_it != entries_.end() && add_accumulator < add_step && 214 next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) { 215 add_accumulator += size_hint; 216 ++next_it; 217 } 218 if (read_it == next_it - 1) { 219 ++read_it; 220 } else { 221 read_it = next_it - 1; 222 } 223 (*write_it++) = (*read_it); 224 last_it = read_it; 225 add_accumulator -= add_step; 226 } 227 // Write last element and resize. 228 if (last_it + 1 != entries_.end()) { 229 (*write_it++) = entries_.back(); 230 } 231 entries_.resize(write_it - entries_.begin()); 232 } 233 234 // To construct the boundaries we first run a soft compress over a copy 235 // of the summary and retrieve the values. 236 // The resulting boundaries are guaranteed to both contain at least 237 // num_boundaries unique elements and maintain approximation bounds. GenerateBoundaries(int64 num_boundaries)238 std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const { 239 std::vector<ValueType> output; 240 if (entries_.empty()) { 241 return output; 242 } 243 244 // Generate soft compressed summary. 245 WeightedQuantilesSummary<ValueType, WeightType, CompareFn> 246 compressed_summary; 247 compressed_summary.BuildFromSummaryEntries(entries_); 248 // Set an epsilon for compression that's at most 1.0 / num_boundaries 249 // more than epsilon of original our summary since the compression operation 250 // adds ~1.0/num_boundaries to final approximation error. 251 float compression_eps = ApproximationError() + (1.0 / num_boundaries); 252 compressed_summary.Compress(num_boundaries, compression_eps); 253 254 // Remove the least important boundaries by the gap removing them would 255 // create. 256 std::list<int64> boundaries_to_keep; 257 for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) { 258 boundaries_to_keep.push_back(i); 259 } 260 while (boundaries_to_keep.size() > num_boundaries) { 261 std::list<int64>::iterator min_element = boundaries_to_keep.end(); 262 auto prev = boundaries_to_keep.begin(); 263 auto curr = prev; 264 ++curr; 265 auto next = curr; 266 ++next; 267 WeightType min_weight = TotalWeight(); 268 for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) { 269 WeightType new_weight = 270 compressed_summary.entries_[*next].PrevMaxRank() - 271 compressed_summary.entries_[*prev].NextMinRank(); 272 if (new_weight < min_weight) { 273 min_element = curr; 274 min_weight = new_weight; 275 } 276 } 277 boundaries_to_keep.erase(min_element); 278 } 279 280 // Return boundaries. 281 output.reserve(boundaries_to_keep.size()); 282 for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end(); 283 ++itr) { 284 output.push_back(compressed_summary.entries_[*itr].value); 285 } 286 return output; 287 } 288 289 // To construct the desired n-quantiles we repetitively query n ranks from the 290 // original summary. The following algorithm is an efficient cache-friendly 291 // O(n) implementation of that idea which avoids the cost of the repetitive 292 // full rank queries O(nlogn). GenerateQuantiles(int64 num_quantiles)293 std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const { 294 std::vector<ValueType> output; 295 if (entries_.empty()) { 296 return output; 297 } 298 num_quantiles = std::max(num_quantiles, int64{2}); 299 output.reserve(num_quantiles + 1); 300 301 // Make successive rank queries to get boundaries. 302 // We always keep the first (min) and last (max) entries. 303 for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) { 304 // This step boils down to finding the next element sub-range defined by 305 // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r. 306 WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles); 307 size_t next_idx = cur_idx + 1; 308 while (next_idx < entries_.size() && 309 d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) { 310 ++next_idx; 311 } 312 cur_idx = next_idx - 1; 313 314 // Determine insertion order. 315 if (next_idx == entries_.size() || 316 d_2 < entries_[cur_idx].NextMinRank() + 317 entries_[next_idx].PrevMaxRank()) { 318 output.push_back(entries_[cur_idx].value); 319 } else { 320 output.push_back(entries_[next_idx].value); 321 } 322 } 323 return output; 324 } 325 326 // Calculates current approximation error which should always be <= eps. ApproximationError()327 double ApproximationError() const { 328 if (entries_.empty()) { 329 return 0; 330 } 331 332 WeightType max_gap = 0; 333 for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) { 334 max_gap = std::max(max_gap, 335 std::max(it->max_rank - it->min_rank - it->weight, 336 it->PrevMaxRank() - (it - 1)->NextMinRank())); 337 } 338 return static_cast<double>(max_gap) / TotalWeight(); 339 } 340 MinValue()341 ValueType MinValue() const { 342 return !entries_.empty() ? entries_.front().value 343 : std::numeric_limits<ValueType>::max(); 344 } MaxValue()345 ValueType MaxValue() const { 346 return !entries_.empty() ? entries_.back().value 347 : std::numeric_limits<ValueType>::max(); 348 } TotalWeight()349 WeightType TotalWeight() const { 350 return !entries_.empty() ? entries_.back().max_rank : 0; 351 } Size()352 int64 Size() const { return entries_.size(); } Clear()353 void Clear() { entries_.clear(); } GetEntryList()354 const std::vector<SummaryEntry>& GetEntryList() const { return entries_; } 355 356 private: 357 // Comparison function. 358 static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 359 360 // Summary entries. 361 std::vector<SummaryEntry> entries_; 362 }; 363 364 template <typename ValueType, typename WeightType, typename CompareFn> 365 constexpr decltype(CompareFn()) 366 WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn; 367 368 } // namespace quantiles 369 } // namespace boosted_trees 370 } // namespace tensorflow 371 372 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 373