1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/lib/random/weighted_picker.h"
17
18 #include <string.h>
19 #include <algorithm>
20
21 #include "tensorflow/core/lib/random/simple_philox.h"
22
23 namespace tensorflow {
24 namespace random {
25
WeightedPicker(int N)26 WeightedPicker::WeightedPicker(int N) {
27 CHECK_GE(N, 0);
28 N_ = N;
29
30 // Find the number of levels
31 num_levels_ = 1;
32 while (LevelSize(num_levels_ - 1) < N) {
33 num_levels_++;
34 }
35
36 // Initialize the levels
37 level_ = new int32*[num_levels_];
38 for (int l = 0; l < num_levels_; l++) {
39 level_[l] = new int32[LevelSize(l)];
40 }
41
42 SetAllWeights(1);
43 }
44
~WeightedPicker()45 WeightedPicker::~WeightedPicker() {
46 for (int l = 0; l < num_levels_; l++) {
47 delete[] level_[l];
48 }
49 delete[] level_;
50 }
51
UnbiasedUniform(SimplePhilox * r,int32 n)52 static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
53 CHECK_LE(0, n);
54 const uint32 range = ~static_cast<uint32>(0);
55 if (n == 0) {
56 return r->Rand32() * n;
57 } else if (0 == (n & (n - 1))) {
58 // N is a power of two, so just mask off the lower bits.
59 return r->Rand32() & (n - 1);
60 } else {
61 // Reject all numbers that skew the distribution towards 0.
62
63 // Rand32's output is uniform in the half-open interval [0, 2^{32}).
64 // For any interval [m,n), the number of elements in it is n-m.
65
66 uint32 rem = (range % n) + 1;
67 uint32 rnd;
68
69 // rem = ((2^{32}-1) \bmod n) + 1
70 // 1 <= rem <= n
71
72 // NB: rem == n is impossible, since n is not a power of 2 (from
73 // earlier check).
74
75 do {
76 rnd = r->Rand32(); // rnd uniform over [0, 2^{32})
77 } while (rnd < rem); // reject [0, rem)
78 // rnd is uniform over [rem, 2^{32})
79 //
80 // The number of elements in the half-open interval is
81 //
82 // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
83 // = 2^{32}-1 - ((2^{32}-1) \bmod n)
84 // = n \cdot \lfloor (2^{32}-1)/n \rfloor
85 //
86 // therefore n evenly divides the number of integers in the
87 // interval.
88 //
89 // The function v \rightarrow v % n takes values from [bias,
90 // 2^{32}) to [0, n). Each integer in the range interval [0, n)
91 // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from
92 // the domain interval.
93 //
94 // Therefore, v % n is uniform over [0, n). QED.
95
96 return rnd % n;
97 }
98 }
99
Pick(SimplePhilox * rnd) const100 int WeightedPicker::Pick(SimplePhilox* rnd) const {
101 if (total_weight() == 0) return -1;
102
103 // using unbiased uniform distribution to avoid bias
104 // toward low elements resulting from a possible use
105 // of big weights.
106 return PickAt(UnbiasedUniform(rnd, total_weight()));
107 }
108
PickAt(int32 weight_index) const109 int WeightedPicker::PickAt(int32 weight_index) const {
110 if (weight_index < 0 || weight_index >= total_weight()) return -1;
111
112 int32 position = weight_index;
113 int index = 0;
114
115 for (int l = 1; l < num_levels_; l++) {
116 // Pick left or right child of "level_[l-1][index]"
117 const int32 left_weight = level_[l][2 * index];
118 if (position < left_weight) {
119 // Descend to left child
120 index = 2 * index;
121 } else {
122 // Descend to right child
123 index = 2 * index + 1;
124 position -= left_weight;
125 }
126 }
127 CHECK_GE(index, 0);
128 CHECK_LT(index, N_);
129 CHECK_LE(position, level_[num_levels_ - 1][index]);
130 return index;
131 }
132
set_weight(int index,int32 weight)133 void WeightedPicker::set_weight(int index, int32 weight) {
134 assert(index >= 0);
135 assert(index < N_);
136
137 // Adjust the sums all the way up to the root
138 const int32 delta = weight - get_weight(index);
139 for (int l = num_levels_ - 1; l >= 0; l--) {
140 level_[l][index] += delta;
141 index >>= 1;
142 }
143 }
144
SetAllWeights(int32 weight)145 void WeightedPicker::SetAllWeights(int32 weight) {
146 // Initialize leaves
147 int32* leaves = level_[num_levels_ - 1];
148 for (int i = 0; i < N_; i++) leaves[i] = weight;
149 for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
150
151 // Now sum up towards the root
152 RebuildTreeWeights();
153 }
154
SetWeightsFromArray(int N,const int32 * weights)155 void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
156 Resize(N);
157
158 // Initialize leaves
159 int32* leaves = level_[num_levels_ - 1];
160 for (int i = 0; i < N_; i++) leaves[i] = weights[i];
161 for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
162
163 // Now sum up towards the root
164 RebuildTreeWeights();
165 }
166
RebuildTreeWeights()167 void WeightedPicker::RebuildTreeWeights() {
168 for (int l = num_levels_ - 2; l >= 0; l--) {
169 int32* level = level_[l];
170 int32* children = level_[l + 1];
171 for (int i = 0; i < LevelSize(l); i++) {
172 level[i] = children[2 * i] + children[2 * i + 1];
173 }
174 }
175 }
176
Append(int32 weight)177 void WeightedPicker::Append(int32 weight) {
178 Resize(num_elements() + 1);
179 set_weight(num_elements() - 1, weight);
180 }
181
Resize(int new_size)182 void WeightedPicker::Resize(int new_size) {
183 CHECK_GE(new_size, 0);
184 if (new_size <= LevelSize(num_levels_ - 1)) {
185 // The new picker fits in the existing levels.
186
187 // First zero out any of the weights that are being dropped so
188 // that the levels are correct (only needed when shrinking)
189 for (int i = new_size; i < N_; i++) {
190 set_weight(i, 0);
191 }
192
193 // We do not need to set any new weights when enlarging because
194 // the unneeded entries always have weight zero.
195 N_ = new_size;
196 return;
197 }
198
199 // We follow the simple strategy of just copying the old
200 // WeightedPicker into a new WeightedPicker. The cost is
201 // O(N) regardless.
202 assert(new_size > N_);
203 WeightedPicker new_picker(new_size);
204 int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
205 int32* src = this->level_[this->num_levels_ - 1];
206 memcpy(dst, src, sizeof(dst[0]) * N_);
207 memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_));
208 new_picker.RebuildTreeWeights();
209
210 // Now swap the two pickers
211 std::swap(new_picker.N_, this->N_);
212 std::swap(new_picker.num_levels_, this->num_levels_);
213 std::swap(new_picker.level_, this->level_);
214 assert(this->N_ == new_size);
215 }
216
217 } // namespace random
218 } // namespace tensorflow
219