/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests:parameterized_truncated_normal_op_test
// commenting out the "tf.set_random_seed(seed)" lines, and using the
// "--runs-per-test=1000" flag. This tests the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
using random::PhiloxRandom;

static constexpr int kMaxIterations = 1000;

template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // The randn rejection sampler is used when the mean and at least this
    // many standard deviations are inside the bounds.
    // The uniform proposal samplers become less efficient as the bounds move
    // further from the mean; the reverse is true for the randn sampler.
    // This number was chosen by empirical benchmarking. If modified, the
    // benchmarks in parameterized_truncated_normal_op_test should also be
    // changed.
    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                   &minvals, &maxvals, &gen, &output,
                   kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
                                                          int limit_batch) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
      // us using the same state in different batches.
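      // (Each 128-bit Philox increment yields four 32-bit outputs, and a
      // batch consumes on the order of 2 * kMaxIterations * samples_per_batch
      // random values in the worst case, so the offset below is roughly that
      // worst case divided by 4; the "+ 3" over-allocates slightly so the
      // integer division can never make adjacent batches share state.)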
      // The sample from each iteration uses 2 random numbers.
      gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
                    4);
      typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
      Uniform dist;
      typedef random::NormalDistribution<random::PhiloxRandom, T> Normal;
      Normal normal_dist;

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      for (int64 b = start_batch; b < limit_batch; ++b) {
        // We are passed a flat array for each of the parameter tensors.
        // The input is either a scalar broadcasted to all batches or a vector
        // with length num_batches, but the scalar becomes an array of length 1.
        T mean = means((means.dimension(0) == 1) ? 0 : b);
        T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
        T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
        T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);

        // The last batch can be short, if we adjusted num_batches and
        // samples_per_batch.
        const int64 limit_sample =
            std::min((b + 1) * samples_per_batch, num_elements);
        int64 sample = b * samples_per_batch;

        // On GPU, this check will just fill samples with NAN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int numIterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;

        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMax >= T(0.))) ||
            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMin <= T(0.)))) {
          // If the bounds extend at least
          // kStdDevsInsideBoundsToUseRandnSampler standard deviations past
          // the mean on one side and reach the mean from the other side,
          // rejection sample by drawing from the normal distribution and
          // rejecting samples outside the bounds.
          // Under this condition the acceptance rate per iteration should
          // always be ~50%. This sampler is more efficient (and more
          // numerically stable when one or both bounds are far from the mean).
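          // (Under the branch condition above, the normalized interval
          // [normMin, normMax] always contains the stretch between the mean
          // and kStdDevsInsideBoundsToUseRandnSampler standard deviations on
          // one side, so each standard normal draw lands inside the bounds
          // with probability at least Phi(1.3) - Phi(0) ~= 0.40, which is
          // where the ~50% per-iteration estimate above comes from.)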

          while (sample < limit_sample) {
            const auto randn_sample = normal_dist(&gen_copy);
            const int size = randn_sample.size();

            for (int i = 0; i < size; i++) {
              if ((randn_sample[i] >= normMin) &&
                  (randn_sample[i] <= normMax)) {
                output(sample) = randn_sample[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
                if (numIterations > kMaxIterations) {
                  // This should never occur because this sampler is only
                  // selected (by the criteria above) when the bounds cover at
                  // least kStdDevsInsideBoundsToUseRandnSampler standard
                  // deviations on one side of the mean, so the acceptance
                  // probability per iteration is >~ 40%.
                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
                             << "exceeded maximum iterations for "
                             << "normMin=" << normMin << " normMax=" << normMax
                             << " kMaxIterations=" << kMaxIterations;
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal randn rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
              }
            }
          }
        } else if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;

          while (sample < limit_sample) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
            }
            for (int i = 0; i < size; i++) {
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              auto accept = u[i] <= Eigen::numext::exp(g[i]);
              if (accept || numIterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample, but emit a warning.
                // TODO(jjhunt) For small entropies (relative to the bounds),
                // this sampler is poor and may take many iterations since
                // the proposal distribution is the uniform distribution
                // U(lower_bound, upper_bound).
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
                             << "exceeded max iterations. Sample may contain "
                             << "outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal uniform rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
                output(sample) = z[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha chosen to
          // maximize acceptance probability, offset by normMin from the
          // origin. Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
          while (sample < limit_sample) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
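              // (This is the standard exponential-proposal rejection step for
              // a one-sided truncated normal: z above is an Exp(alpha) draw
              // shifted to start at normMin, and exp(-x * x / 2) below is its
              // acceptance probability. Since alpha - normMin =
              // (sqrt(normMin * normMin + 4) - normMin) / 2 > 0, the first
              // branch of the ternary above is the one taken in practice.)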
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              auto accept = (u <= g && z < normMax);
              if (accept || numIterations + 1 >= kMaxIterations) {
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal exponential distribution "
                             << "rejection sampler exceeded max iterations. "
                             << "Sample may contain outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal exponential distribution rejection"
                      " sampler failed to accept a sample."));
                  return;
                }
                output(sample) = z * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64 batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64 uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64 uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64 batchCost =
        batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
          batchCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class ParameterizedTruncatedNormalOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32 kDesiredBatchSize = 100;

 public:
  explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& means_tensor = ctx->input(1);
    const Tensor& stddevs_tensor = ctx->input(2);
    const Tensor& minvals_tensor = ctx->input(3);
    const Tensor& maxvals_tensor = ctx->input(4);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    int32 num_batches = shape_tensor.flat<int32>()(0);

    int32 samples_per_batch = 1;
    const int32 num_dims = shape_tensor.dim_size(0);
    for (int32 i = 1; i < num_dims; i++) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    const int32 num_elements = num_batches * samples_per_batch;

    // Allocate the output before fudging num_batches and samples_per_batch.
    auto shape_vec = shape_tensor.flat<int32>();
    TensorShape tensor_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            shape_vec.data(), shape_vec.size(), &tensor_shape));
    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));

    // Parameters must be 0-d or 1-d.
    OP_REQUIRES(ctx, means_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input means should be a scalar or vector, got shape: ",
                    means_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input stddevs should be a scalar or vector, got shape: ",
                    stddevs_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input minvals should be a scalar or vector, got shape: ",
                    minvals_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input maxvals should be a scalar or vector, got shape: ",
                    maxvals_tensor.shape().DebugString()));

    if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
        (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
        minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
      // All batches have the same parameters, so we can update the batch size
      // to a reasonable value to improve parallelism (ensure enough batches,
      // and no very small batches which have high overhead).
      int32 size = num_batches * samples_per_batch;
      int32 adjusted_samples = kDesiredBatchSize;
      // Ensure adjusted_batches * adjusted_samples >= size.
      int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
      num_batches = adjusted_batches;
      samples_per_batch = adjusted_samples;
    } else {
      // Parameters must be broadcastable to the shape [num_batches].
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(means_tensor.shape()) ||
              means_tensor.dim_size(0) == 1 ||
              means_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input means should have length 1 or shape[0], got shape: ",
              means_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
              stddevs_tensor.dim_size(0) == 1 ||
              stddevs_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input stddevs should have length 1 or shape[0], got shape: ",
              stddevs_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
              minvals_tensor.dim_size(0) == 1 ||
              minvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input minvals should have length 1 or shape[0], got shape: ",
              minvals_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
              maxvals_tensor.dim_size(0) == 1 ||
              maxvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input maxvals should have length 1 or shape[0], got shape: ",
              maxvals_tensor.shape().DebugString()));
    }

    auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
    // Each worker has the fudge factor for samples_per_batch, so use it here.
    random::PhiloxRandom rng =
        generator_.ReserveSamples128(num_batches * 2 * functor::kMaxIterations *
                                     (samples_per_batch + 3) / 4);
    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};

}  // namespace

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("shape")             \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow