• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/kernels/range_sampler.h"

#include <cmath>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 namespace {
28 
29 using gtl::ArraySlice;
30 using gtl::MutableArraySlice;
31 
32 class RangeSamplerTest : public ::testing::Test {
33  protected:
CheckProbabilitiesSumToOne()34   void CheckProbabilitiesSumToOne() {
35     double sum = 0;
36     for (int i = 0; i < sampler_->range(); i++) {
37       sum += sampler_->Probability(i);
38     }
39     EXPECT_NEAR(sum, 1.0, 1e-4);
40   }
CheckHistogram(int num_samples,float tolerance)41   void CheckHistogram(int num_samples, float tolerance) {
42     const int range = sampler_->range();
43     std::vector<int> h(range);
44     std::vector<int64> a(num_samples);
45     // Using a fixed random seed to make the test deterministic.
46     random::PhiloxRandom philox(123, 17);
47     random::SimplePhilox rnd(&philox);
48     sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
49     for (int i = 0; i < num_samples; i++) {
50       int64 val = a[i];
51       ASSERT_GE(val, 0);
52       ASSERT_LT(val, range);
53       h[val]++;
54     }
55     for (int val = 0; val < range; val++) {
56       EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
57                   tolerance);
58     }
59   }
Update1()60   void Update1() {
61     // Add the value 3 ten times.
62     std::vector<int64> a(10);
63     for (int i = 0; i < 10; i++) {
64       a[i] = 3;
65     }
66     sampler_->Update(a);
67   }
Update2()68   void Update2() {
69     // Add the value n times.
70     int64 a[10];
71     for (int i = 0; i < 10; i++) {
72       a[i] = i;
73     }
74     for (int64 i = 1; i < 10; i++) {
75       sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
76     }
77   }
78   std::unique_ptr<RangeSampler> sampler_;
79 };
80 
// Every value of a uniform sampler over 10 values must report exactly the same
// probability as value 0.
TEST_F(RangeSamplerTest, UniformProbabilities) {
  sampler_.reset(new UniformSampler(10));
  // EXPECT_EQ rather than CHECK_EQ: a mismatch is reported as an ordinary test
  // failure instead of aborting the whole test binary. Value 0 is the
  // reference, so the comparison starts at 1.
  for (int i = 1; i < 10; i++) {
    EXPECT_EQ(sampler_->Probability(i), sampler_->Probability(0));
  }
}
87 
// The per-value probabilities of a uniform sampler over 10 values add to one.
TEST_F(RangeSamplerTest, UniformChecksum) {
  constexpr int kRange = 10;
  sampler_.reset(new UniformSampler(kRange));
  CheckProbabilitiesSumToOne();
}
92 
// 1000 draws from a uniform sampler over 10 values: each observed frequency
// must be within 0.05 of the reported probability.
TEST_F(RangeSamplerTest, UniformHistogram) {
  constexpr int kRange = 10;
  constexpr int kNumSamples = 1000;
  constexpr float kTolerance = 0.05;
  sampler_.reset(new UniformSampler(kRange));
  CheckHistogram(kNumSamples, kTolerance);
}
97 
// For v = 100, 200, 400, ... below the range, expects
// Probability(v) / Probability(v / 2) to be 0.5 +/- 0.1, i.e. doubling a
// value roughly halves its probability.
TEST_F(RangeSamplerTest, LogUniformProbabilities) {
  const int range = 1000000;
  sampler_.reset(new LogUniformSampler(range));
  int value = 100;
  while (value < range) {
    const float ratio =
        sampler_->Probability(value) / sampler_->Probability(value / 2);
    EXPECT_NEAR(ratio, 0.5, 0.1);
    value *= 2;
  }
}
106 
// The per-value probabilities of a log-uniform sampler over 10 values add to
// one.
TEST_F(RangeSamplerTest, LogUniformChecksum) {
  constexpr int kRange = 10;
  sampler_.reset(new LogUniformSampler(kRange));
  CheckProbabilitiesSumToOne();
}
111 
// 1000 draws from a log-uniform sampler over 10 values: each observed
// frequency must be within 0.05 of the reported probability.
TEST_F(RangeSamplerTest, LogUniformHistogram) {
  constexpr int kRange = 10;
  constexpr int kNumSamples = 1000;
  constexpr float kTolerance = 0.05;
  sampler_.reset(new LogUniformSampler(kRange));
  CheckHistogram(kNumSamples, kTolerance);
}
116 
// After Update1() records value 3 ten times, value 3 should have probability
// 0.55 and every other value 0.05.
TEST_F(RangeSamplerTest, UnigramProbabilities1) {
  sampler_.reset(new UnigramSampler(10));
  Update1();
  EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
  for (int value = 0; value < 10; ++value) {
    if (value == 3) continue;
    ASSERT_NEAR(sampler_->Probability(value), 0.05, 1e-4);
  }
}
// After Update2(), value v should have probability (v + 1) / 55.
TEST_F(RangeSamplerTest, UnigramProbabilities2) {
  sampler_.reset(new UnigramSampler(10));
  Update2();
  const double denominator = 55.0;
  for (int value = 0; value < 10; ++value) {
    const double expected = (value + 1) / denominator;
    ASSERT_NEAR(sampler_->Probability(value), expected, 1e-4);
  }
}
// After Update1(), the unigram sampler's probabilities still add to one.
TEST_F(RangeSamplerTest, UnigramChecksum) {
  constexpr int kRange = 10;
  sampler_.reset(new UnigramSampler(kRange));
  Update1();
  CheckProbabilitiesSumToOne();
}
139 
// After Update1(), 1000 draws from the unigram sampler must match the
// reported probabilities to within 0.05.
TEST_F(RangeSamplerTest, UnigramHistogram) {
  constexpr int kRange = 10;
  constexpr int kNumSamples = 1000;
  constexpr float kTolerance = 0.05;
  sampler_.reset(new UnigramSampler(kRange));
  Update1();
  CheckHistogram(kNumSamples, kTolerance);
}
145 
// Vocabulary file contents for the FixedUnigramSampler tests: nine
// "word,count" lines whose counts are 1, 2, 4, ..., 256. The last line has no
// trailing newline.
constexpr char kVocabContent[] =
    "w1,1\nw2,2\nw3,4\n"
    "w4,8\nw5,16\nw6,32\n"
    "w7,64\nw8,128\nw9,256";
// Loads the nine-word vocabulary (counts 2^0 .. 2^8) from a file with no
// reserved ids. Each count is raised to the 0.8 power passed to the sampler,
// so value i should have probability 2^(0.8 * i) / sum_k 2^(0.8 * k).
TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
  Env* env = Env::Default();
  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
  // Exact normalizer 1^0.8 + 2^0.8 + 4^0.8 + ... + 256^0.8 (~197.05), computed
  // instead of hard-coding a rounded constant.
  double normalizer = 0;
  for (int k = 0; k < 9; k++) normalizer += std::pow(2.0, k * 0.8);
  for (int i = 0; i < 9; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(2.0, i * 0.8) / normalizer, 1e-4);
  }
}
// A FixedUnigramSampler loaded from the vocabulary file has probabilities that
// add to one.
TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
  Env* const env = Env::Default();
  const string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, vocab_path, 0.8, 0, 1, 0));
  CheckProbabilitiesSumToOne();
}
173 
// 1000 draws from a FixedUnigramSampler loaded from the vocabulary file must
// match the reported probabilities to within 0.05.
TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
  Env* const env = Env::Default();
  const string vocab_path = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, vocab_path, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 9, vocab_path, 0.8, 0, 1, 0));
  constexpr int kNumSamples = 1000;
  constexpr float kTolerance = 0.05;
  CheckHistogram(kNumSamples, kTolerance);
}
// Same vocabulary file with one reserved id: value 0 must have probability 0,
// and the word with count 2^k moves to value k + 1.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
  Env* env = Env::Default();
  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  // Exact normalizer 1^0.8 + 2^0.8 + 4^0.8 + ... + 256^0.8 (~197.05), computed
  // instead of hard-coding a rounded constant.
  double normalizer = 0;
  for (int k = 0; k < 9; k++) normalizer += std::pow(2.0, k * 0.8);
  for (int i = 1; i < 10; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(2.0, (i - 1) * 0.8) / normalizer, 1e-4);
  }
}
// Same vocabulary file with two reserved ids: values 0 and 1 must have
// probability 0, and the word with count 2^k moves to value k + 2.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
  Env* env = Env::Default();
  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
  sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
  // Exact normalizer 1^0.8 + 2^0.8 + 4^0.8 + ... + 256^0.8 (~197.05), computed
  // instead of hard-coding a rounded constant.
  double normalizer = 0;
  for (int k = 0; k < 9; k++) normalizer += std::pow(2.0, k * 0.8);
  for (int i = 2; i < 11; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(2.0, (i - 2) * 0.8) / normalizer, 1e-4);
  }
}
// Builds the sampler from in-memory weights with no reserved ids. Each weight
// is raised to the 0.8 power passed to the sampler, so value i should have
// probability weights[i]^0.8 / sum_k weights[k]^0.8.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
  // Derive the normalizer (~197.05 for these weights) from the weights
  // themselves rather than a rounded, hand-computed constant.
  double normalizer = 0;
  for (const float w : weights) normalizer += std::pow(w, 0.8);
  for (int i = 0; i < 9; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(weights[i], 0.8) / normalizer, 1e-4);
  }
}
// A FixedUnigramSampler built from the weights 1, 2, 4, ..., 256 has
// probabilities that add to one.
TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
  std::vector<float> unigram_counts;
  for (int power = 0; power < 9; ++power) {
    unigram_counts.push_back(1 << power);
  }
  sampler_.reset(new FixedUnigramSampler(9, unigram_counts, 0.8, 0, 1, 0));
  CheckProbabilitiesSumToOne();
}
// 1000 draws from a FixedUnigramSampler built from the weights 1, 2, ..., 256
// must match the reported probabilities to within 0.05.
TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
  std::vector<float> unigram_counts = {1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(9, unigram_counts, 0.8, 0, 1, 0));
  constexpr int kNumSamples = 1000;
  constexpr float kTolerance = 0.05;
  CheckHistogram(kNumSamples, kTolerance);
}
// In-memory weights with one reserved id: value 0 must have probability 0, and
// weights[k] moves to value k + 1.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  // Derive the normalizer (~197.05 for these weights) from the weights
  // themselves rather than a rounded, hand-computed constant.
  double normalizer = 0;
  for (const float w : weights) normalizer += std::pow(w, 0.8);
  for (int i = 1; i < 10; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(weights[i - 1], 0.8) / normalizer, 1e-4);
  }
}
// In-memory weights with two reserved ids: values 0 and 1 must have
// probability 0, and weights[k] moves to value k + 2.
TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
  sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
  // Derive the normalizer (~197.05 for these weights) from the weights
  // themselves rather than a rounded, hand-computed constant.
  double normalizer = 0;
  for (const float w : weights) normalizer += std::pow(w, 0.8);
  for (int i = 2; i < 11; i++) {
    ASSERT_NEAR(sampler_->Probability(i),
                std::pow(weights[i - 2], 0.8) / normalizer, 1e-4);
  }
}
241 
242 // AllSampler cannot call Sample or Probability directly.
243 // We will test SampleBatchGetExpectedCount instead.
TEST_F(RangeSamplerTest,All)244 TEST_F(RangeSamplerTest, All) {
245   int batch_size = 10;
246   sampler_.reset(new AllSampler(10));
247   std::vector<int64> batch(batch_size);
248   std::vector<float> batch_expected(batch_size);
249   std::vector<int64> extras(2);
250   std::vector<float> extras_expected(2);
251   extras[0] = 0;
252   extras[1] = batch_size - 1;
253   sampler_->SampleBatchGetExpectedCount(nullptr,  // no random numbers needed
254                                         false, absl::MakeSpan(batch),
255                                         absl::MakeSpan(batch_expected), extras,
256                                         absl::MakeSpan(extras_expected));
257   for (int i = 0; i < batch_size; i++) {
258     EXPECT_EQ(i, batch[i]);
259     EXPECT_EQ(1, batch_expected[i]);
260   }
261   EXPECT_EQ(1, extras_expected[0]);
262   EXPECT_EQ(1, extras_expected[1]);
263 }
264 
TEST_F(RangeSamplerTest, Unique) {
  // We sample num_batches batches, each without replacement.
  //
  // We check that the returned expected counts roughly agree with each other
  // and with the average observed frequencies over the set of batches.
  random::PhiloxRandom philox(123, 17);
  random::SimplePhilox rnd(&philox);
  const int range = 100;
  const int batch_size = 50;
  const int num_batches = 100;
  sampler_.reset(new LogUniformSampler(range));
  std::vector<int> histogram(range);
  std::vector<int64> batch(batch_size);
  std::vector<int64> all_values(range);
  for (int i = 0; i < range; i++) {
    all_values[i] = i;
  }
  std::vector<float> expected(range);

  // Sample one batch and get the expected counts of all values.
  sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
                                        MutableArraySlice<float>(), all_values,
                                        absl::MakeSpan(expected));
  // Check that all elements are unique. EXPECT_EQ reports an ordinary test
  // failure instead of aborting the binary the way CHECK_EQ does. The cast
  // avoids comparing a signed int with the set's unsigned size.
  const std::set<int64> unique_values(batch.begin(), batch.end());
  EXPECT_EQ(static_cast<size_t>(batch_size), unique_values.size());

  for (int trial = 0; trial < num_batches; trial++) {
    std::vector<float> trial_expected(range);
    sampler_->SampleBatchGetExpectedCount(
        &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
        all_values, absl::MakeSpan(trial_expected));
    for (int i = 0; i < range; i++) {
      EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
    }
    for (int i = 0; i < batch_size; i++) {
      histogram[batch[i]]++;
    }
  }
  for (int i = 0; i < range; i++) {
    // Check that the computed expected count agrees with the average observed
    // count.
    const float average_count = static_cast<float>(histogram[i]) / num_batches;
    EXPECT_NEAR(expected[i], average_count, 0.2);
  }
}
311 
// Sampling range - 2 unique values from [0, range) while avoiding two of them
// must return exactly the remaining values.
TEST_F(RangeSamplerTest, Avoid) {
  random::PhiloxRandom philox(123, 17);
  random::SimplePhilox rnd(&philox);
  const int range = 100;
  sampler_.reset(new LogUniformSampler(range));
  std::vector<int64> avoided(2);
  avoided[0] = 17;
  avoided[1] = 23;
  std::vector<int64> batch(range - avoided.size());

  // We expect to pick all elements of [0, range) except the avoided two.
  sampler_->SampleBatchGetExpectedCountAvoid(
      &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
      ArraySlice<int64>(), MutableArraySlice<float>(), avoided);

  // The sum alone cannot tell a correct batch from one that contains
  // duplicates (or an avoided value) and happens to reach the same total, so
  // uniqueness and avoidance are checked explicitly as well.
  const std::set<int64> unique_values(batch.begin(), batch.end());
  EXPECT_EQ(batch.size(), unique_values.size());
  for (const int64 value : avoided) {
    EXPECT_EQ(0u, unique_values.count(value))
        << "sampled avoided value " << value;
  }

  // Accumulate in int64 to match the element type of `batch`.
  int64 sum = 0;
  for (const int64 val : batch) {
    sum += val;
  }
  const int64 expected_sum =
      range * (range - 1) / 2 - avoided[0] - avoided[1];
  EXPECT_EQ(expected_sum, sum);
}
333 
334 }  // namespace
335 
336 }  // namespace tensorflow
337