• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/random_ops.cc.
17 // NOTE: If the algorithm is changed, please run the test
18 // .../python/kernel_tests:parameterized_truncated_normal_op_test
19 // commenting out the "tf.set_random_seed(seed)" lines, and using the
20 // "--runs-per-test=1000" flag. This tests the statistical correctness of the
21 // op results.
22 
23 #define EIGEN_USE_THREADS
24 
25 #include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
26 
27 #include <algorithm>
28 #include <cmath>
29 #include <memory>
30 
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/lib/random/random_distributions.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/util/guarded_philox_random.h"
38 #include "tensorflow/core/util/work_sharder.h"
39 
40 namespace tensorflow {
41 
42 typedef Eigen::ThreadPoolDevice CPUDevice;
43 typedef Eigen::GpuDevice GPUDevice;
44 
45 namespace functor {
46 using random::PhiloxRandom;
47 
48 static constexpr int kMaxIterations = 1000;
49 
50 template <typename T>
51 struct TruncatedNormalFunctor<CPUDevice, T> {
operator ()tensorflow::functor::TruncatedNormalFunctor52   void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
53                   int64 samples_per_batch, int64 num_elements,
54                   typename TTypes<T>::ConstFlat means,
55                   typename TTypes<T>::ConstFlat stddevs,
56                   typename TTypes<T>::ConstFlat minvals,
57                   typename TTypes<T>::ConstFlat maxvals,
58                   const random::PhiloxRandom& gen,
59                   typename TTypes<T>::Flat output) {
60     // The randn rejection sampling is used when the mean and at least this many
61     // standard deviations are inside the bounds.
62     // The uniform proposal samplers become less efficient as the bounds are
63     // further from the mean, the reverse is true for the randn sampler.
64     // This number was chosen by empirical benchmarking. If modified, the
65     // benchmarks in parameterized_truncated_normal_op_test should also be
66     // changed.
67     const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
68     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
69 
70     auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
71                    &minvals, &maxvals, &gen, &output,
72                    kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
73                                                           int limit_batch) {
74       // Capturing "gen" by-value would only make a copy for the _shared_
75       // lambda.  Since we want to let each worker have its own copy, we pass
76       // "gen" by reference and explicitly do a copy assignment here.
77       random::PhiloxRandom gen_copy = gen;
78       // Skip takes units of 128 bytes.  +3 is so rounding doesn't lead to
79       // us using the same state in different batches.
80       // The sample from each iteration uses 2 random numbers.
81       gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
82                     4);
83       typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
84       Uniform dist;
85       typedef random::NormalDistribution<random::PhiloxRandom, T> Normal;
86       Normal normal_dist;
87 
88       // Vectorized intermediate calculations for uniform rejection sampling.
89       // We always generate at most 4 samples.
90       Eigen::array<T, 4> z;
91       Eigen::array<T, 4> g;
92 
93       for (int64 b = start_batch; b < limit_batch; ++b) {
94         // We are passed a flat array for each of the parameter tensors.
95         // The input is either a scalar broadcasted to all batches or a vector
96         // with length num_batches, but the scalar becomes an array of length 1.
97         T mean = means((means.dimension(0) == 1) ? 0 : b);
98         T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
99         T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
100         T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);
101 
102         // The last batch can be short, if we adjusted num_batches and
103         // samples_per_batch.
104         const int64 limit_sample =
105             std::min((b + 1) * samples_per_batch, num_elements);
106         int64 sample = b * samples_per_batch;
107 
108         // On GPU, this check will just fill samples with NAN if it fails.
109         OP_REQUIRES(ctx,
110                     stddev > T(0) && minval < maxval &&
111                         (Eigen::numext::isfinite(minval) ||
112                          Eigen::numext::isfinite(maxval)),
113                     errors::InvalidArgument("Invalid parameters"));
114 
115         int numIterations = 0;
116 
117         // If possible, make one-sided bound be the lower bound, or make both
118         // bounds positive. Otherwise, the bounds are on either side of the
119         // mean.
120         if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
121           // Reverse all calculations. normMin and normMax will be flipped.
122           std::swap(minval, maxval);
123           stddev = -stddev;
124         }
125 
126         // Calculate normalized samples, then convert them.
127         const T normMin = (minval - mean) / stddev;
128         const T normMax = (maxval - mean) / stddev;
129 
130         // Determine the method to use.
131         const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
132         const T cutoff =
133             T(2) *
134             Eigen::numext::exp(T(0.5) +
135                                (normMin * (normMin - sqrtFactor)) / T(4)) /
136             (normMin + sqrtFactor);
137         const T diff = normMax - normMin;
138 
139         if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
140              (normMax >= T(0.))) ||
141             ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
142              (normMin <= T(0.)))) {
143           // If the bounds are a least 3 standard deviations from the mean
144           // on at least one side then we rejection sample by sampling
145           // from the normal distribution and rejecting samples outside
146           // the bounds.
147           // Under this condition the acceptance rate per iteration should
148           // always be ~ 50%. This sampler is more efficient (and more
149           // numerically stable when one or both bounds is far from the mean).
150 
151           while (sample < limit_sample) {
152             const auto randn_sample = normal_dist(&gen_copy);
153             const int size = randn_sample.size();
154 
155             for (int i = 0; i < size; i++) {
156               if ((randn_sample[i] >= normMin) &&
157                   (randn_sample[i] <= normMax)) {
158                 output(sample) = randn_sample[i] * stddev + mean;
159                 sample++;
160                 if (sample >= limit_sample) {
161                   break;
162                 }
163                 numIterations = 0;
164               } else {
165                 numIterations++;
166                 if (numIterations > kMaxIterations) {
167                   // This should never occur because this sampler should
168                   // (by the selection criteria above) be used if at least 3
169                   // standard deviations of one side of the distribution
170                   // is within the limits (so acceptance probability per
171                   // iterations >~ 1/2 per iteration).
172                   LOG(ERROR) << "TruncatedNormal randn rejection sampler "
173                              << "exceeded maximum iterations for "
174                              << "normMin=" << normMin << " normMax=" << normMax
175                              << " kMaxIterations=" << kMaxIterations;
176                   ctx->SetStatus(errors::Internal(
177                       "TruncatedNormal randn rejection sampler failed to accept"
178                       " a sample."));
179                   return;
180                 }
181               }
182             }
183           }
184         } else if (diff < cutoff) {
185           // Sample from a uniform distribution on [normMin, normMax].
186 
187           const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
188 
189           while (sample < limit_sample) {
190             const auto rand = dist(&gen_copy);
191             const int size = rand.size();
192             // NOTE(ringwalt): These loops seem to only generate packed AVX
193             // instructions for float32.
194             for (int i = 0; i < size; i++) {
195               z[i] = rand[i] * diff + normMin;
196             }
197             for (int i = 0; i < size; i++) {
198               g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
199             }
200 
201             const auto u = dist(&gen_copy);
202             for (int i = 0; i < size; i++) {
203               auto accept = u[i] <= Eigen::numext::exp(g[i]);
204               if (accept || numIterations + 1 >= kMaxIterations) {
205                 // Accept the sample z.
206                 // If we run out of iterations, just use the current uniform
207                 // sample, but emit a warning.
208                 // TODO(jjhunt) For small entropies (relative to the bounds),
209                 // this sampler is poor and may take many iterations since
210                 // the proposal distribution is the uniform distribution
211                 // U(lower_bound, upper_bound).
212                 if (!accept) {
213                   LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
214                              << "exceeded max iterations. Sample may contain "
215                              << "outliers.";
216                   ctx->SetStatus(errors::Internal(
217                       "TruncatedNormal uniform rejection sampler failed to "
218                       " accept a sample."));
219                   return;
220                 }
221                 output(sample) = z[i] * stddev + mean;
222                 sample++;
223                 if (sample >= limit_sample) {
224                   break;
225                 }
226                 numIterations = 0;
227               } else {
228                 numIterations++;
229               }
230             }
231           }
232         } else {
233           // Sample from an exponential distribution with alpha maximizing
234           // acceptance probability, offset by normMin from the origin.
235           // Accept only if less than normMax.
236           const T alpha =
237               (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
238               T(2);
239           while (sample < limit_sample) {
240             auto rand = dist(&gen_copy);
241             const int size = rand.size();
242             int i = 0;
243             while (i < size) {
244               const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
245               i++;
246               const T x = normMin < alpha ? alpha - z : normMin - alpha;
247               const T g = Eigen::numext::exp(-x * x / T(2.0));
248               const T u = rand[i];
249               i++;
250               auto accept = (u <= g && z < normMax);
251               if (accept || numIterations + 1 >= kMaxIterations) {
252                 if (!accept) {
253                   LOG(ERROR) << "TruncatedNormal exponential distribution "
254                              << "rejection sampler exceeds max iterations. "
255                              << "Sample may contain outliers.";
256                   ctx->SetStatus(errors::Internal(
257                       "TruncatedNormal exponential distribution rejection"
258                       " sampler failed to accept a sample."));
259                   return;
260                 }
261                 output(sample) = z * stddev + mean;
262                 sample++;
263                 if (sample >= limit_sample) {
264                   break;
265                 }
266                 numIterations = 0;
267               } else {
268                 numIterations++;
269               }
270             }
271           }
272         }
273       }
274     };
275     // The cost of the initial calculations for the batch.
276     const int64 batchInitCost =
277         // normMin, normMax
278         (Eigen::TensorOpCost::AddCost<T>() +
279          Eigen::TensorOpCost::MulCost<T>()) *
280             2
281         // sqrtFactor
282         + Eigen::TensorOpCost::AddCost<T>() +
283         Eigen::TensorOpCost::MulCost<T>() +
284         Eigen::internal::functor_traits<
285             Eigen::internal::scalar_sqrt_op<T>>::Cost
286         // cutoff
287         + Eigen::TensorOpCost::MulCost<T>() * 4 +
288         Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
289         // diff
290         + Eigen::TensorOpCost::AddCost<T>();
291     const int64 uniformSampleCost =
292         random::PhiloxRandom::kElementCost +
293         random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
294     // The cost of a single uniform sampling round.
295     const int64 uniformRejectionSamplingCost =
296         uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
297         Eigen::TensorOpCost::AddCost<T>() +
298         Eigen::TensorOpCost::MulCost<T>() * 2 +
299         Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
300         Eigen::internal::functor_traits<
301             Eigen::internal::scalar_exp_op<T>>::Cost +
302         Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
303     // Estimate the cost for an entire batch.
304     // Assume we use uniform sampling, and accept the 2nd sample on average.
305     const int64 batchCost =
306         batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
307     Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
308           batchCost, DoWork);
309   }
310 };
311 
312 }  // namespace functor
313 
314 namespace {
315 
316 // Samples from a truncated normal distribution, using the given parameters.
317 template <typename Device, typename T>
318 class ParameterizedTruncatedNormalOp : public OpKernel {
319   // Reshape batches so each batch is this size if possible.
320   static const int32 kDesiredBatchSize = 100;
321 
322  public:
ParameterizedTruncatedNormalOp(OpKernelConstruction * context)323   explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
324       : OpKernel(context) {
325     OP_REQUIRES_OK(context, generator_.Init(context));
326   }
327 
Compute(OpKernelContext * ctx)328   void Compute(OpKernelContext* ctx) override {
329     const Tensor& shape_tensor = ctx->input(0);
330     const Tensor& means_tensor = ctx->input(1);
331     const Tensor& stddevs_tensor = ctx->input(2);
332     const Tensor& minvals_tensor = ctx->input(3);
333     const Tensor& maxvals_tensor = ctx->input(4);
334 
335     OP_REQUIRES(
336         ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
337         errors::InvalidArgument("Input shape should be a vector, got shape: ",
338                                 shape_tensor.shape().DebugString()));
339     int32 num_batches = shape_tensor.flat<int32>()(0);
340 
341     int32 samples_per_batch = 1;
342     const int32 num_dims = shape_tensor.dim_size(0);
343     for (int32 i = 1; i < num_dims; i++) {
344       samples_per_batch *= shape_tensor.flat<int32>()(i);
345     }
346     const int32 num_elements = num_batches * samples_per_batch;
347 
348     // Allocate the output before fudging num_batches and samples_per_batch.
349     auto shape_vec = shape_tensor.flat<int32>();
350     TensorShape tensor_shape;
351     OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
352                             shape_vec.data(), shape_vec.size(), &tensor_shape));
353     Tensor* samples_tensor;
354     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));
355 
356     // Parameters must be 0-d or 1-d.
357     OP_REQUIRES(ctx, means_tensor.dims() <= 1,
358                 errors::InvalidArgument(
359                     "Input means should be a scalar or vector, got shape: ",
360                     means_tensor.shape().DebugString()));
361     OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
362                 errors::InvalidArgument(
363                     "Input stddevs should be a scalar or vector, got shape: ",
364                     stddevs_tensor.shape().DebugString()));
365     OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
366                 errors::InvalidArgument(
367                     "Input minvals should be a scalar or vector, got shape: ",
368                     minvals_tensor.shape().DebugString()));
369     OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
370                 errors::InvalidArgument(
371                     "Input maxvals should be a scalar or vector, got shape: ",
372                     maxvals_tensor.shape().DebugString()));
373 
374     if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
375         (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
376         minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
377       // All batches have the same parameters, so we can update the batch size
378       // to a reasonable value to improve parallelism (ensure enough batches,
379       // and no very small batches which have high overhead).
380       int32 size = num_batches * samples_per_batch;
381       int32 adjusted_samples = kDesiredBatchSize;
382       // Ensure adjusted_batches * adjusted_samples >= size.
383       int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
384       num_batches = adjusted_batches;
385       samples_per_batch = adjusted_samples;
386     } else {
387       // Parameters must be broadcastable to the shape [num_batches].
388       OP_REQUIRES(
389           ctx,
390           TensorShapeUtils::IsScalar(means_tensor.shape()) ||
391               means_tensor.dim_size(0) == 1 ||
392               means_tensor.dim_size(0) == num_batches,
393           errors::InvalidArgument(
394               "Input means should have length 1 or shape[0], got shape: ",
395               means_tensor.shape().DebugString()));
396       OP_REQUIRES(
397           ctx,
398           TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
399               stddevs_tensor.dim_size(0) == 1 ||
400               stddevs_tensor.dim_size(0) == num_batches,
401           errors::InvalidArgument(
402               "Input stddevs should have length 1 or shape[0], got shape: ",
403               stddevs_tensor.shape().DebugString()));
404       OP_REQUIRES(
405           ctx,
406           TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
407               minvals_tensor.dim_size(0) == 1 ||
408               minvals_tensor.dim_size(0) == num_batches,
409           errors::InvalidArgument(
410               "Input minvals should have length 1 or shape[0], got shape: ",
411               minvals_tensor.shape().DebugString()));
412       OP_REQUIRES(
413           ctx,
414           TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
415               maxvals_tensor.dim_size(0) == 1 ||
416               maxvals_tensor.dim_size(0) == num_batches,
417           errors::InvalidArgument(
418               "Input maxvals should have length 1 or shape[0], got shape: ",
419               maxvals_tensor.shape().DebugString()));
420     }
421 
422     auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
423     // Each worker has the fudge factor for samples_per_batch, so use it here.
424     random::PhiloxRandom rng =
425         generator_.ReserveSamples128(num_batches * 2 * functor::kMaxIterations *
426                                      (samples_per_batch + 3) / 4);
427     truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
428                  samples_per_batch, num_elements, means_tensor.flat<T>(),
429                  stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
430                  maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
431   }
432 
433  private:
434   GuardedPhiloxRandom generator_;
435 
436   TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
437 };
438 
439 }  // namespace
440 
441 #define REGISTER(TYPE)                                         \
442   REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
443                               .Device(DEVICE_CPU)              \
444                               .TypeConstraint<TYPE>("dtype"),  \
445                           ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)
446 
447 TF_CALL_half(REGISTER);
448 TF_CALL_float(REGISTER);
449 TF_CALL_double(REGISTER);
450 
451 #undef REGISTER
452 
#if GOOGLE_CUDA

// Registers the GPU kernel for each supported output dtype. Compute() reads
// the values of the "shape" input on the host, so that input is kept in host
// memory.
#define REGISTER_GPU_KERNEL(TYPE)                              \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("shape")             \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA
469 
470 }  // end namespace tensorflow
471