/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests/random:random_binomial_test
// commenting out the "tf.set_random_seed(seed)" lines, and using the
// "--runs-per-test=1000" flag. This tests the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/random_binomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

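// UNIFORM(X) binds X to one double-precision uniform variate from the Philox
// generator, refilling the local buffer only when the previous batch of
// Uniform::kResultElementCount results has been consumed.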
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(gen);                    \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

// Binomial inversion: given prob, sum geometric random variables until the
// sum exceeds count. The number of variables drawn is binomially distributed,
// which is equivalent to inverting the binomial CDF.
double binomial_inversion(double count, double prob,
                          random::PhiloxRandom* gen) {
  using Eigen::numext::ceil;
  using Eigen::numext::log;
  using Eigen::numext::log1p;

  double geom_sum = 0;
  int num_geom = 0;

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16 uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    double geom = ceil(log(u) / log1p(-prob));
    geom_sum += geom;
    if (geom_sum > count) {
      break;
    }
    ++num_geom;
  }
  return num_geom;
}

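// Returns the tail of Stirling's approximation to log(k!):
//   log(k!) = (k + 1/2) * log(k + 1) - (k + 1) + 0.5 * log(2 * pi) + tail(k).
// Values for k <= 9 are precomputed; larger k use the leading terms of the
// asymptotic series in powers of 1 / (k + 1).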
inline double stirling_approx_tail(double k) {
  static double kTailValues[] = {0.0810614667953272,  0.0413406959554092,
                                 0.0276779256849983,  0.02079067210376509,
                                 0.0166446911898211,  0.0138761288230707,
                                 0.0118967099458917,  0.0104112652619720,
                                 0.00925546218271273, 0.00833056343336287};
  if (k <= 9) {
    return kTailValues[static_cast<int>(k)];
  }
  double kp1sq = (k + 1) * (k + 1);
  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}

// Binomial sampling via Hormann's transformed-rejection (BTRS) algorithm,
// which draws pairs of uniform random variables.
// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
inline double btrs(double count, double prob, random::PhiloxRandom* gen) {
  using Eigen::numext::abs;
  using Eigen::numext::floor;
  using Eigen::numext::log;
  using Eigen::numext::log1p;
  using Eigen::numext::sqrt;

  // This is spq in the paper.
  const double stddev = sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const double b = 1.15 + 2.53 * stddev;
  const double a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const double c = count * prob + 0.5;
  const double v_r = 0.92 - 4.2 / b;
  const double r = prob / (1 - prob);

  const double alpha = (2.83 + 5.1 / b) * stddev;
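  // floor((count + 1) * prob) is the mode of the Binomial(count, prob)
  // distribution; it appears in the acceptance test below.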
  const double m = floor((count + 1) * prob);

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16 uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    UNIFORM(v);
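    // Recenter u on [-0.5, 0.5]; us is its distance from the nearer boundary
    // and k is the proposed sample from the transformed (u, v) pair.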
    u = u - 0.5;
    double us = 0.5 - abs(u);
    double k = floor((2 * a / us + b) * u + c);

    // Region for which the box is tight, so we can return our calculated
    // value immediately. This happens with probability ~0.86 * v_r. In the
    // limit as n * p becomes large, the acceptance rate converges to ~79%
    // (and in the lower regime it is ~24%).
    if (us >= 0.07 && v <= v_r) {
      return k;
    }
    // Reject nonsensical answers.
    if (k < 0 || k > count) {
      continue;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-reject ratio.
    v = log(v * alpha / (a / (us * us) + b));
    double upperbound =
        ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
         stirling_approx_tail(k) - stirling_approx_tail(count - k));
    if (v <= upperbound) {
      return k;
    }
  }
}

}  // namespace

namespace functor {

template <typename T, typename U>
struct RandomBinomialFunctor<CPUDevice, T, U> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  const BCast& bcast, typename TTypes<T>::ConstFlat counts,
                  typename TTypes<T>::ConstFlat probs,
                  const random::PhiloxRandom& gen,
                  typename TTypes<U>::Flat output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The output layout is [B1, ... Bk, H1, ... Hm], where [B1, ... Bk] is the
    // sample shape and [H1, ... Hm] is the batch shape of the samples. We need
    // B1 * ... * Bk samples per batch member.
    auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
                   &gen, &output](int start_output, int limit_output) {
      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;
      const bool should_bcast = bcast.IsBroadcastingRequired();
      const auto& counts_batch_indices = bcast.x_batch_indices();
      const auto& probs_batch_indices = bcast.y_batch_indices();
      auto output_flat = output.data();

      // We partition work across batches (count, prob) and then across samples
      // per batch member, to avoid extra work.
      for (int64 output_idx = start_output; output_idx < limit_output;
           // output_idx is incremented with the inner loops below.
      ) {
        int64 batch_idx = output_idx / samples_per_batch;
        U* const output_batch_offset = output_flat + batch_idx;
        // Generate batch counts from BCast, as it has the right indices to loop
        // over.
        T count, prob;
        if (should_bcast) {
          count = counts(counts_batch_indices[batch_idx]);
          prob = probs(probs_batch_indices[batch_idx]);
        } else {
          count = counts(batch_idx);
          prob = probs(batch_idx);
        }

        // Calculate normalized samples, then convert them.
        // Determine the method to use.
        double dcount = static_cast<double>(count);
        if (dcount <= 0.0 || prob <= T(0.0)) {
          for (int64 sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(0.0);
          }
        } else if (prob >= T(1.0)) {
          for (int64 sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] =
                static_cast<U>(dcount);
          }
        } else if (prob <= T(0.5)) {
          double dp = static_cast<double>(prob);
          if (count * prob >= T(10)) {
            for (int64 sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
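              // Each output element gets its own block of 256 Philox results,
              // so the rejection loop below draws from a disjoint substream.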
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(btrs(dcount, dp, &gen_copy));
            }
          } else {
            for (int64 sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, the mean and variance are both <= 10,
              // so on average we need at most 10 uniform samples, and covering
              // 10 standard deviations requires 42 samples. We reserve that
              // much.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(binomial_inversion(dcount, dp, &gen_copy));
            }
          }
        } else if (prob > T(0.5)) {
          T q = T(1) - prob;
          double dcount = static_cast<double>(count);
          double dq = static_cast<double>(q);
          if (count * q >= T(10)) {
            for (int64 sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(dcount - btrs(dcount, dq, &gen_copy));
            }
          } else {
            for (int64 sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, the mean and variance are both <= 10,
              // so on average we need at most 10 uniform samples, and covering
              // 10 standard deviations requires 42 samples. We reserve that
              // much.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] = static_cast<U>(
                  dcount - binomial_inversion(dcount, dq, &gen_copy));
            }
          }
        } else {  // prob is NaN
          // TODO(srvasude): What should happen if prob is NaN but the output
          // type is an integer (which doesn't have a sentinel for NaN)?  Fail
          // the whole batch sample?  Return a specialized sentinel like -1?
          for (int64 sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(NAN);
          }
        }
      }
    };

    // This will depend on count * p (or count * q).
    // For n * p < 10, on average, O(n * p) calls to uniform are needed, with
    // that many multiplies. ~10 uniform calls on average with ~200 cost op
    // calls.
    //
    // Very roughly, for rate >= 10, the four calls to log occur for ~72
    // percent of samples.
    // 4 x 100 (64-bit cycles per log) * 0.72 = ~288
    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
    // 40 * .72 = ~25.
    //
    // Finally, there are several other ops that are done every loop, along
    // with 2 uniform generations and 5 other ops at 3-6 cycles each.
    // ~15 / .89 = ~16
    //
    // In total this (rate >= 10) should be ~329 + 2 * Uniform::kElementCost.
    // We assume that half the tensor has rate < 10, so on average 6 uniforms
    // will be needed. We will upper bound the other op cost by the one for
    // rate > 10.
    static const int kElementCost = 329 + 6 * Uniform::kElementCost +
                                    6 * random::PhiloxRandom::kElementCost;
    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
          kElementCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a binomial distribution, using the given parameters.
template <typename Device, typename T, typename U>
class RandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32 kDesiredBatchSize = 100;

 public:
  explicit RandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& alg_tensor = ctx->input(1);
    const Tensor& shape_tensor = ctx->input(2);
    const Tensor& counts_tensor = ctx->input(3);
    const Tensor& probs_tensor = ctx->input(4);

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
                                                      &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));
    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                errors::InvalidArgument("algorithm must be of shape [], not ",
                                        alg_tensor.shape().DebugString()));
    Algorithm alg = alg_tensor.flat<Algorithm>()(0);

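    // The leading dimensions of `shape` (those beyond the broadcasted batch
    // shape of counts/probs) are sample dimensions; the trailing dimensions
    // index the batch members.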
    int64 samples_per_batch = 1;
    const int64 num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64 i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64 num_batches = 1;
    for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64 num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));

    core::RefCountPtr<Var> var;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));

    Tensor* var_tensor = var->tensor();
    OP_REQUIRES(
        ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
        errors::InvalidArgument("dtype of RNG state variable must be ",
                                DataTypeString(STATE_ELEMENT_DTYPE), ", not ",
                                DataTypeString(var_tensor->dtype())));
    OP_REQUIRES(ctx, var_tensor->dims() == 1,
                errors::InvalidArgument(
                    "RNG state must have one and only one dimension, not ",
                    var_tensor->dims()));
    auto var_tensor_flat = var_tensor->flat<StateElementType>();
    OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    static_assert(std::is_same<StateElementType, int64>::value,
                  "StateElementType must be int64");
    static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                  "PhiloxRandom::ResultElementType must be uint32");
    OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE,
                errors::InvalidArgument(
                    "For Philox algorithm, the size of state must be at least ",
                    PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size()));

    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
                            ctx, var_tensor, var->copy_on_read_mode.load()));
    auto var_data = var_tensor_flat.data();
    auto philox = GetPhiloxRandomFromMem(var_data);
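    // Advance the counter stored in the RNG state variable so later kernels
    // do not reuse the random numbers consumed here; the count below is a
    // conservative upper bound on how many Philox outputs this op may draw.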
    UpdateMemWithPhiloxRandom(
        philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);

    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp);
};

}  // namespace

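// Registers the CPU kernel for each combination of output dtype (RTYPE) and
// counts/probs dtype (T).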
#define REGISTER(RTYPE, TYPE)                                 \
  REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial")      \
                              .Device(DEVICE_CPU)             \
                              .HostMemory("resource")         \
                              .HostMemory("algorithm")        \
                              .HostMemory("shape")            \
                              .HostMemory("counts")           \
                              .HostMemory("probs")            \
                              .TypeConstraint<RTYPE>("dtype") \
                              .TypeConstraint<TYPE>("T"),     \
                          RandomBinomialOp<CPUDevice, TYPE, RTYPE>)

#define REGISTER_ALL(RTYPE)     \
  REGISTER(RTYPE, Eigen::half); \
  REGISTER(RTYPE, float);       \
  REGISTER(RTYPE, double);

REGISTER_ALL(Eigen::half);
REGISTER_ALL(float);
REGISTER_ALL(double);
REGISTER_ALL(int32);
REGISTER_ALL(int64);

#undef REGISTER
#undef REGISTER_ALL

}  // end namespace tensorflow