1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cmath>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/core/kernels/range_sampler.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace {
28
29 using gtl::ArraySlice;
30 using gtl::MutableArraySlice;
31
32 class RangeSamplerTest : public ::testing::Test {
33 protected:
CheckProbabilitiesSumToOne()34 void CheckProbabilitiesSumToOne() {
35 double sum = 0;
36 for (int i = 0; i < sampler_->range(); i++) {
37 sum += sampler_->Probability(i);
38 }
39 EXPECT_NEAR(sum, 1.0, 1e-4);
40 }
CheckHistogram(int num_samples,float tolerance)41 void CheckHistogram(int num_samples, float tolerance) {
42 const int range = sampler_->range();
43 std::vector<int> h(range);
44 std::vector<int64> a(num_samples);
45 // Using a fixed random seed to make the test deterministic.
46 random::PhiloxRandom philox(123, 17);
47 random::SimplePhilox rnd(&philox);
48 sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
49 for (int i = 0; i < num_samples; i++) {
50 int64 val = a[i];
51 ASSERT_GE(val, 0);
52 ASSERT_LT(val, range);
53 h[val]++;
54 }
55 for (int val = 0; val < range; val++) {
56 EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
57 tolerance);
58 }
59 }
Update1()60 void Update1() {
61 // Add the value 3 ten times.
62 std::vector<int64> a(10);
63 for (int i = 0; i < 10; i++) {
64 a[i] = 3;
65 }
66 sampler_->Update(a);
67 }
Update2()68 void Update2() {
69 // Add the value n times.
70 int64 a[10];
71 for (int i = 0; i < 10; i++) {
72 a[i] = i;
73 }
74 for (int64 i = 1; i < 10; i++) {
75 sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
76 }
77 }
78 std::unique_ptr<RangeSampler> sampler_;
79 };
80
// Every value handed to a UniformSampler(10) must report exactly the same
// probability as value 0.
TEST_F(RangeSamplerTest, UniformProbabilities) {
  sampler_.reset(new UniformSampler(10));
  const auto reference = sampler_->Probability(0);
  for (int i = 0; i < 10; i++) {
    // EXPECT_EQ (rather than CHECK_EQ) so that a mismatch is reported as a
    // test failure instead of aborting the whole test binary.
    EXPECT_EQ(sampler_->Probability(i), reference);
  }
}
87
// The probabilities reported by a UniformSampler must add up to one.
TEST_F(RangeSamplerTest, UniformChecksum) {
  const int kSamplerArg = 10;
  sampler_.reset(new UniformSampler(kSamplerArg));
  CheckProbabilitiesSumToOne();
}
92
// Observed frequencies over 1000 draws from a UniformSampler must stay within
// 0.05 of the reported probabilities.
TEST_F(RangeSamplerTest, UniformHistogram) {
  const int kNumSamples = 1000;
  sampler_.reset(new UniformSampler(10));
  CheckHistogram(kNumSamples, 0.05);
}
97
// For a LogUniformSampler, doubling a value should roughly halve its
// probability: P(v) / P(v / 2) is expected to be 0.5 +/- 0.1 for
// v = 100, 200, 400, ... below the range.
TEST_F(RangeSamplerTest, LogUniformProbabilities) {
  const int range = 1000000;
  sampler_.reset(new LogUniformSampler(range));
  int value = 100;
  while (value < range) {
    const auto p_value = sampler_->Probability(value);
    const auto p_half = sampler_->Probability(value / 2);
    const float ratio = p_value / p_half;
    EXPECT_NEAR(ratio, 0.5, 0.1);
    value *= 2;
  }
}
106
// The probabilities reported by a LogUniformSampler must add up to one.
TEST_F(RangeSamplerTest, LogUniformChecksum) {
  const int range = 10;
  sampler_.reset(new LogUniformSampler(range));
  CheckProbabilitiesSumToOne();
}
111
// Observed frequencies over 1000 draws from a LogUniformSampler must stay
// within 0.05 of the reported probabilities.
TEST_F(RangeSamplerTest, LogUniformHistogram) {
  const int range = 10;
  const int kNumSamples = 1000;
  sampler_.reset(new LogUniformSampler(range));
  CheckHistogram(kNumSamples, 0.05);
}
116
// After adding the value 3 ten times, value 3 is expected to have probability
// 0.55 and every other value 0.05.
// NOTE(review): these numbers match each of the 10 values starting with a
// count of one (11/20 and 1/20) -- confirm against UnigramSampler.
TEST_F(RangeSamplerTest, UnigramProbabilities1) {
  sampler_.reset(new UnigramSampler(10));
  Update1();
  EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
  for (int value = 0; value < 10; ++value) {
    if (value == 3) {
      continue;  // Already checked above.
    }
    ASSERT_NEAR(sampler_->Probability(value), 0.05, 1e-4);
  }
}
// After Update2() (value v added v times), value v is expected to have
// probability (v + 1) / 55.
TEST_F(RangeSamplerTest, UnigramProbabilities2) {
  sampler_.reset(new UnigramSampler(10));
  Update2();
  for (int value = 0; value < 10; ++value) {
    const double expected = (value + 1) / 55.0;
    ASSERT_NEAR(sampler_->Probability(value), expected, 1e-4);
  }
}
// The probabilities of a UnigramSampler must still add up to one after its
// counts have been updated.
TEST_F(RangeSamplerTest, UnigramChecksum) {
  const int kSamplerArg = 10;
  sampler_.reset(new UnigramSampler(kSamplerArg));
  Update1();
  CheckProbabilitiesSumToOne();
}
139
// After Update1(), observed frequencies over 1000 draws from a UnigramSampler
// must stay within 0.05 of the reported probabilities.
TEST_F(RangeSamplerTest, UnigramHistogram) {
  const int kNumSamples = 1000;
  sampler_.reset(new UnigramSampler(10));
  Update1();
  CheckHistogram(kNumSamples, 0.05);
}
145
// Contents of the vocabulary file written by the FixedUnigram* tests: nine
// "<word>,<count>" lines (w1..w9) whose counts are the powers of two from 1 to
// 256. There is no newline after the last entry.
static constexpr char kVocabContent[] =
    "w1,1\nw2,2\nw3,4\n"
    "w4,8\nw5,16\nw6,32\n"
    "w7,64\nw8,128\nw9,256";
// Builds a FixedUnigramSampler from a vocabulary file and expects id i to have
// probability (2^i)^0.8 / 197.05.
TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
  Env* env = Env::Default();
  string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, vocab_path, 0.8, 0, 1, 0));
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  const double normalizer = 197.05;
  for (int id = 0; id < 9; ++id) {
    const double expected = pow(2, id * 0.8) / normalizer;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
// The probabilities of a file-backed FixedUnigramSampler must add up to one.
TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
  Env* env = Env::Default();
  string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  // Abort immediately if the vocabulary file cannot be written.
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, vocab_path, 0.8, 0, 1, 0));
  CheckProbabilitiesSumToOne();
}
173
// Observed frequencies over 1000 draws from a file-backed FixedUnigramSampler
// must stay within 0.05 of the reported probabilities.
TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
  Env* env = Env::Default();
  string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, vocab_path, 0.8, 0, 1, 0));
  const int kNumSamples = 1000;
  CheckHistogram(kNumSamples, 0.05);
}
// File-backed FixedUnigramSampler with one leading id expected to carry
// (near) zero probability; the vocabulary entries are expected at ids 1..9.
// NOTE(review): the constructor argument `1` after 0.8 is presumably the
// number of reserved ids -- confirm against FixedUnigramSampler.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
  Env* env = Env::Default();
  string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 10, vocab_path, 0.8, 1, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  for (int id = 1; id < 10; ++id) {
    const double expected = pow(2, (id - 1) * 0.8) / 197.05;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
// File-backed FixedUnigramSampler with two leading ids expected to carry
// (near) zero probability; the vocabulary entries are expected at ids 2..10.
// NOTE(review): the constructor argument `2` after 0.8 is presumably the
// number of reserved ids -- confirm against FixedUnigramSampler.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
  Env* env = Env::Default();
  string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 11, vocab_path, 0.8, 2, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  for (int id = 2; id < 11; ++id) {
    const double expected = pow(2, (id - 2) * 0.8) / 197.05;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
// Builds a FixedUnigramSampler from an in-memory weight vector and expects id
// i to have probability (2^i)^0.8 / 197.05.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
  std::vector<float> weights{1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  const double normalizer = 197.05;
  for (int id = 0; id < 9; ++id) {
    const double expected = pow(2, id * 0.8) / normalizer;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
// The probabilities of a vector-backed FixedUnigramSampler must add up to one.
TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
  std::vector<float> weights{1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
  CheckProbabilitiesSumToOne();
}
// Observed frequencies over 1000 draws from a vector-backed
// FixedUnigramSampler must stay within 0.05 of the reported probabilities.
TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
  std::vector<float> weights{1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
  const int kNumSamples = 1000;
  CheckHistogram(kNumSamples, 0.05);
}
// Vector-backed FixedUnigramSampler with one leading id expected to carry
// (near) zero probability; the weights are expected at ids 1..9.
// NOTE(review): the constructor argument `1` after 0.8 is presumably the
// number of reserved ids -- confirm against FixedUnigramSampler.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
  std::vector<float> weights{1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  for (int id = 1; id < 10; ++id) {
    const double expected = pow(2, (id - 1) * 0.8) / 197.05;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
// Vector-backed FixedUnigramSampler with two leading ids expected to carry
// (near) zero probability; the weights are expected at ids 2..10.
// NOTE(review): the constructor argument `2` after 0.8 is presumably the
// number of reserved ids -- confirm against FixedUnigramSampler.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
  std::vector<float> weights{1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
  for (int id = 2; id < 11; ++id) {
    const double expected = pow(2, (id - 2) * 0.8) / 197.05;
    ASSERT_NEAR(sampler_->Probability(id), expected, 1e-4);
  }
}
241
242 // AllSampler cannot call Sample or Probability directly.
243 // We will test SampleBatchGetExpectedCount instead.
// Requests a batch of ten values from an AllSampler(10) and expects the batch
// to be 0..9 in order, with an expected count of 1 for every batch entry and
// for the two extra values (0 and 9).
TEST_F(RangeSamplerTest, All) {
  const int batch_size = 10;
  sampler_.reset(new AllSampler(10));

  std::vector<int64> batch(batch_size);
  std::vector<float> batch_expected(batch_size);

  // Extra values whose expected counts are requested alongside the batch.
  std::vector<int64> extras;
  extras.push_back(0);
  extras.push_back(batch_size - 1);
  std::vector<float> extras_expected(extras.size());

  sampler_->SampleBatchGetExpectedCount(nullptr,  // no random numbers needed
                                        false, absl::MakeSpan(batch),
                                        absl::MakeSpan(batch_expected), extras,
                                        absl::MakeSpan(extras_expected));

  for (int i = 0; i < batch_size; ++i) {
    EXPECT_EQ(i, batch[i]);
    EXPECT_EQ(1, batch_expected[i]);
  }
  for (const float expected_count : extras_expected) {
    EXPECT_EQ(1, expected_count);
  }
}
264
TEST_F(RangeSamplerTest, Unique) {
  // We sample num_batches batches, each without replacement.
  //
  // We check that the returned expected counts roughly agree with each other
  // and with the average observed frequencies over the set of batches.
  random::PhiloxRandom philox(123, 17);
  random::SimplePhilox rnd(&philox);
  const int range = 100;
  const int batch_size = 50;
  const int num_batches = 100;
  sampler_.reset(new LogUniformSampler(range));
  // histogram[v] counts how often v appears across the trial batches below.
  std::vector<int> histogram(range);
  std::vector<int64> batch(batch_size);
  // Every value in [0, range), used to query the expected count of each one.
  std::vector<int64> all_values(range);
  for (int i = 0; i < range; i++) {
    all_values[i] = i;
  }
  std::vector<float> expected(range);

  // Sample one batch and get the expected counts of all values
  sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
                                        MutableArraySlice<float>(), all_values,
                                        absl::MakeSpan(expected));
  // Check that all elements are unique. EXPECT_EQ (rather than CHECK_EQ) so a
  // duplicate is reported as a test failure instead of aborting the binary;
  // the cast keeps the comparison unsigned on both sides.
  std::set<int64> s(batch.begin(), batch.end());
  EXPECT_EQ(static_cast<size_t>(batch_size), s.size());

  for (int trial = 0; trial < num_batches; trial++) {
    std::vector<float> trial_expected(range);
    sampler_->SampleBatchGetExpectedCount(
        &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
        all_values, absl::MakeSpan(trial_expected));
    // Expected counts reported for this trial must be within 50% of the ones
    // reported for the first batch.
    for (int i = 0; i < range; i++) {
      EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
    }
    for (int i = 0; i < batch_size; i++) {
      histogram[batch[i]]++;
    }
  }
  for (int i = 0; i < range; i++) {
    // Check that the computed expected count agrees with the average observed
    // count.
    const float average_count = static_cast<float>(histogram[i]) / num_batches;
    EXPECT_NEAR(expected[i], average_count, 0.2);
  }
}
311
// Samples 98 unique values from a LogUniformSampler(100) while avoiding two
// values, and checks that the batch is exactly [0, 100) minus the avoided two.
TEST_F(RangeSamplerTest, Avoid) {
  random::PhiloxRandom philox(123, 17);
  random::SimplePhilox rnd(&philox);
  sampler_.reset(new LogUniformSampler(100));
  std::vector<int64> avoided(2);
  avoided[0] = 17;
  avoided[1] = 23;
  std::vector<int64> batch(98);

  // We expect to pick all elements of [0, 100) except the avoided two.
  sampler_->SampleBatchGetExpectedCountAvoid(
      &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
      ArraySlice<int64>(), MutableArraySlice<float>(), avoided);

  // Accumulate in int64 to match the element type (no narrowing).
  int64 sum = 0;
  for (const int64 val : batch) {
    // The sum check below cannot catch an avoided value whose presence is
    // compensated by other missing values, so check each sample directly.
    EXPECT_NE(avoided[0], val);
    EXPECT_NE(avoided[1], val);
    sum += val;
  }

  // Every sampled value must be distinct.
  const std::set<int64> distinct(batch.begin(), batch.end());
  EXPECT_EQ(batch.size(), distinct.size());

  const int64 expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1];
  EXPECT_EQ(expected_sum, sum);
}
333
334 } // namespace
335
336 } // namespace tensorflow
337