/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/random_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;

// The default implementation of the functor, which should never be invoked.
// We still need to provide an implementation for now so that the linker works,
// since we do not yet support all the distributions.
template <typename Device, class Distribution>
struct FillPhiloxRandom {
  typedef typename Distribution::ResultElementType T;
  void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
                  T* data, int64 size, Distribution dist) {
    LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
  }
};

// A class to fill a specified range of random groups.
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomTask;

// Specialization for distributions that take a fixed number of samples for
// each output.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  static void Run(random::PhiloxRandom gen, T* data, int64 size,
                  int64 start_group, int64 limit_group, Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

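    // For distributions with a fixed sample count per output, each output
    // group consumes exactly one invocation of the underlying generator, so
    // skipping start_group invocations positions this shard at its first
    // group.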
    gen.Skip(start_group);
    int64 offset = start_group * kGroupSize;

    // First fill all the full-size groups.
    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
    for (int64 index = start_group; index < limit_group_full; ++index) {
      auto samples = dist(&gen);
      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
      offset += kGroupSize;
    }

    // If there are any remaining elements that need to be filled, process them.
    if (limit_group_full < limit_group) {
      int64 remaining_size = size - limit_group_full * kGroupSize;
      auto samples = dist(&gen);
      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
    }
  }
};

// Specialization for distributions that take a variable number of samples for
// each output. This will be slower due to the generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  static const int64 kReservedSamplesPerOutput = 256;

  static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
                  int64 start_group, int64 limit_group, Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

    static const int kGeneratorSkipPerOutputGroup =
        kGroupSize * kReservedSamplesPerOutput /
        PhiloxRandom::kResultElementCount;
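    // For example, if kGroupSize were 4 (with PhiloxRandom producing 4
    // elements per invocation), each output group would reserve
    // 4 * 256 / 4 = 256 generator invocations, so adjacent groups never draw
    // from overlapping counter ranges no matter how many samples the
    // distribution actually consumes (up to the reserved budget).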

    int64 offset = start_group * kGroupSize;

    // First fill all the full-size groups.
    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
    int64 group_index;
    for (group_index = start_group; group_index < limit_group_full;
         ++group_index) {
      // Reset the generator to the beginning of the output group region.
      // This is necessary if we want the results to be independent of the
      // order in which the work is carried out.
      PhiloxRandom gen = base_gen;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      auto samples = dist(&single_samples);
      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
      offset += kGroupSize;
    }

    // If there are any remaining elements that need to be filled, process them.
    if (limit_group_full < limit_group) {
      PhiloxRandom gen = base_gen;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      int64 remaining_size = size - limit_group_full * kGroupSize;
      auto samples = dist(&single_samples);
      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
    }
  }
};

// Partial specialization for CPU to fill the entire region with random values.
// It splits the work into several tasks and runs them in parallel.
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  const int kGroupSize = Distribution::kResultElementCount;

  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

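  // Round up: the final group may be only partially filled.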
  int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;

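  // A rough per-group cost estimate (in cycles) handed to Shard so it can
  // pick a sensible partitioning granularity.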
  const int kGroupCost =
      random::PhiloxRandom::kResultElementCount *
      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
  Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
        kGroupCost,
        [&gen, data, size, dist](int64 start_group, int64 limit_group) {
          FillPhiloxRandomTask<
              Distribution,
              Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
                                                            start_group,
                                                            limit_group, dist);
        });
}

}  // namespace functor

namespace {

static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
                                      int index, Tensor** output) {
  TensorShape tensor_shape;
  TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape));
  return ctx->allocate_output(index, tensor_shape, output);
}

// For now, use the same interface as RandomOp, so we can choose either one
// at run time.
template <typename Device, class Distribution>
class PhiloxRandomOp : public OpKernel {
 public:
  typedef typename Distribution::ResultElementType T;
  explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<T>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(),
        // Multiplier 256 is the same as in FillPhiloxRandomTask; keep the two
        // in sync if it ever changes.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), Distribution());
  }

 private:
  GuardedPhiloxRandom generator_;
};

template <typename Device, class IntType>
class RandomUniformIntOp : public OpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    const Tensor& minval = ctx->input(1);
    const Tensor& maxval = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Allocate the output, and exit early if possible.
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    if (output->NumElements() == 0) return;

    // Verify that minval < maxval. This check intentionally happens after the
    // early exit for empty output. Zero impossible things are fine.
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build the distribution.
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
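    // A UniformDistribution constructed with (lo, hi) produces integers in
    // the half-open interval [lo, hi).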
    Distribution dist(lo, hi);

    auto output_flat = output->flat<IntType>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(),
        // Multiplier 256 is the same as in FillPhiloxRandomTask; keep the two
        // in sync if it ever changes.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), dist);
  }

 private:
  GuardedPhiloxRandom generator_;
};

// Samples from one or more gamma distributions. All internal computations are
// done with double precision for numerical stability.
template <typename T>
class RandomGammaOp : public OpKernel {
 public:
  explicit RandomGammaOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    const Tensor& alpha_t = ctx->input(1);

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(shape_t.shape()) &&
                    (shape_t.dtype() == DataType::DT_INT32 ||
                     shape_t.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "shape must be a vector of {int32,int64}, got shape: ",
                    shape_t.DebugString()));
    TensorShape samples_shape;
    if (shape_t.dtype() == DataType::DT_INT32) {
      auto vec = shape_t.flat<int32>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    } else if (shape_t.dtype() == DataType::DT_INT64) {
      auto vec = shape_t.flat<int64>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    }
    const int64 num_samples = samples_shape.num_elements();

    samples_shape.AppendShape(alpha_t.shape());
    // Allocate the output samples.
    Tensor* samples_t = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));

    if (num_samples == 0) return;

    using random::PhiloxRandom;

    typedef random::NormalDistribution<PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<PhiloxRandom, double> Uniform;
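    // UNIFORM(X) declares a local double X holding the next cached uniform
    // random value, refilling the cached Uniform result block whenever it
    // runs out.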
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(&gen);                   \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

    // Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
    // sample.
    static constexpr int kReservedSamplesPerOutput = 256;

    const auto alpha_flat = alpha_t.flat<T>().data();
    const int64 num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));
    auto samples_flat = samples_t->flat<T>().data();
    PhiloxRandom rng = generator_.ReserveRandomOutputs(
        num_samples * num_alphas, kReservedSamplesPerOutput);

    // We partition work first across alphas and then across samples-per-alpha
    // to avoid a couple of flops which can be done on a per-alpha basis.

    auto DoWork = [num_samples, num_alphas, &rng, samples_flat, alpha_flat](
        int start_output, int limit_output) {
      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::pow;

      // Capturing "rng" by value would only make a copy for the _shared_
      // lambda. Since we want each worker to have its own copy, we pass
      // "rng" by reference and explicitly do a copy assignment.

      Normal normal;
      Uniform uniform;
      typename Normal::ResultType norm_result;
      typename Uniform::ResultType uniform_result;
      for (int64 output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64 alpha_idx = output_idx / num_samples;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == double(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64 sample_idx = output_idx % num_samples;
               sample_idx < num_samples && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            short uniform_remaining = 0;
            UNIFORM(u);
            const double res = -log(1.0 - u);
            samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
          }  // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha (~1),
          // and higher acceptance rates for larger alpha, so the runtime is
          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha < 1, we add one to d = alpha - 1/3, and multiply the
          // final result by uniform()^(1/alpha)
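          // (this is Marsaglia & Tsang's adjustment for alpha < 1). Hence d
          // below is alpha + 2/3 rather than alpha - 1/3 when alpha < 1, and
          // c = 1 / (3 * sqrt(d)) in either case.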
          const bool alpha_less_than_one = alpha < 1;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);

          // Compute the rest of the samples for the current alpha value.
          for (int64 sample_idx = output_idx % num_samples;
               sample_idx < num_samples && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            short norm_remaining = 0;
            short uniform_remaining = 0;

            // Keep trying until we don't reject a sample. In practice, we will
            // only reject ~5% at worst, for low alpha near 1.
            while (true) {
              if (norm_remaining == 0) {
                norm_remaining = Normal::kResultElementCount;
                norm_result = normal(&gen);
              }
              norm_remaining--;
              const double x = norm_result[norm_remaining];
              double v = 1 + c * x;
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              UNIFORM(u);
              // The first option in the if is a "squeeze" short-circuit to
              // dodge the two logs. Magic constant sourced from the paper
              // linked above. Upward of .91 of the area covered by the log
              // inequality is covered by the squeeze as well (larger coverage
              // for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
                if (alpha_less_than_one) {
                  UNIFORM(b);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork
#undef UNIFORM
    // Two calls to log only occur for ~10% of the samples reaching the log
    // line.
    //   2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
    // each = ~60.
    // All of this /0.95 due to the rejection possibility = ~85.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * PhiloxRandom::kElementCost;
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
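    // Shard splits the [0, num_alphas * num_samples) output-index range across
    // the worker pool; each DoWork invocation handles one half-open sub-range.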
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * num_samples, kElementCost, DoWork);
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
};

}  // namespace

#define REGISTER(TYPE)                                                         \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;     \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::NormalDistribution<random::PhiloxRandom, TYPE>>;      \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice,                                                               \
      random::TruncatedNormalDistribution<                                     \
          random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>;           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          CPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      RandomGammaOp<TYPE>)

#define REGISTER_INT(IntType)                                                  \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>;  \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                             \
                              .Device(DEVICE_CPU)                              \
                              .HostMemory("shape")                             \
                              .HostMemory("minval")                            \
                              .HostMemory("maxval")                            \
                              .TypeConstraint<IntType>("Tout"),                \
                          RandomUniformIntOp<CPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          GPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_INT(IntType)                                                  \
  template struct functor::FillPhiloxRandom<                                   \
      GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>;  \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                             \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("shape")                             \
                              .HostMemory("minval")                            \
                              .HostMemory("maxval")                            \
                              .TypeConstraint<int32>("T")                      \
                              .TypeConstraint<IntType>("Tout"),                \
                          RandomUniformIntOp<GPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL

namespace functor {

using namespace cl;

template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
                                        sycl::access::target::global_buffer>;

  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
                         Distribution& dist)
      : data_(data), gen_(gen), dist_(dist) {}

  void operator()(sycl::nd_item<1> item) {
    const size_t kGroupSize = Distribution::kResultElementCount;

    const size_t item_id = item.get_global(0);
    const size_t total_item_count = item.get_global_range();
    size_t offset = item_id * kGroupSize;
    gen_.Skip(item_id);

    const size_t size = data_.get_size() / sizeof(T);
    T* data = ConvertToActualTypeSycl(T, data_);

    while (offset + kGroupSize <= size) {
      const typename Distribution::ResultType samples = dist_(&gen_);
      for (size_t i = 0; i < kGroupSize; ++i) {
        data[offset + i] = samples[i];
      }

      offset += (total_item_count - 1) * kGroupSize;
      gen_.Skip(total_item_count - 1);
    }

    const typename Distribution::ResultType samples = dist_(&gen_);
    for (size_t i = 0; i < kGroupSize; ++i) {
      if (offset >= size) {
        return;
      }
      data[offset] = samples[i];
      ++offset;
    }
  }

 private:
  write_accessor data_;
  random::PhiloxRandom gen_;
  Distribution dist_;
};

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
                                        sycl::access::target::global_buffer>;

  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
                         Distribution& dist)
      : data_(data), gen_(gen), dist_(dist) {}

  void operator()(sycl::nd_item<1> item) {
    using random::PhiloxRandom;
    using random::SingleSampleAdapter;

    const size_t kReservedSamplesPerOutput = 256;
    const size_t kGroupSize = Distribution::kResultElementCount;
    const size_t kGeneratorSkipPerOutputGroup =
        kGroupSize * kReservedSamplesPerOutput /
        PhiloxRandom::kResultElementCount;

    const size_t item_id = item.get_global(0);
    const size_t total_item_count = item.get_global_range();
    size_t group_index = item_id;
    size_t offset = group_index * kGroupSize;

    T* data = ConvertToActualTypeSycl(T, data_);
    const size_t size = data_.get_size() / sizeof(T);

    while (offset < size) {
      // Since each output takes a variable number of samples, we need to
      // realign the generator to the beginning of the current output group.
      PhiloxRandom gen = gen_;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      const typename Distribution::ResultType samples = dist_(&single_samples);

      for (size_t i = 0; i < kGroupSize; ++i) {
        if (offset >= size) {
          return;
        }
        data[offset] = samples[i];
        ++offset;
      }

      offset += (total_item_count - 1) * kGroupSize;
      group_index += total_item_count;
    }
  }

 private:
  write_accessor data_;
  random::PhiloxRandom gen_;
  Distribution dist_;
};

template <typename T>
class FillRandomKernel;
// Partial specialization for SYCL to fill the entire region with random
// values. It splits the work into several tasks and runs them in parallel.
template <class Distribution>
void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
    OpKernelContext* context, const SYCLDevice& device,
    random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
    int64 size, Distribution dist) {
  const size_t group_size = device.maxSyclThreadsPerBlock();
  const size_t group_count = (size + group_size - 1) / group_size;
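  // Launch at least one work-item per output element, rounded up to a whole
  // work-group of group_size items.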

  auto buffer = device.get_sycl_buffer(data);

  device.sycl_queue().submit([&](sycl::handler& cgh) {
    auto access = buffer.template get_access<sycl::access::mode::write>(cgh);

    FillPhiloxRandomKernel<Distribution,
                           Distribution::kVariableSamplesPerOutput>
        task(access, gen, dist);
    cgh.parallel_for<class FillRandomKernel<Distribution>>(
        sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
                          sycl::range<1>(group_size)),
        task);
  });
}

}  // namespace functor

#define REGISTER(TYPE)                                                         \
  template struct functor::FillPhiloxRandom<                                   \
      SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<SYCLDevice, random::UniformDistribution<                  \
                                     random::PhiloxRandom, TYPE>>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<SYCLDevice,                                               \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          SYCLDevice,                                                          \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_INT(IntType)                                                  \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                             \
                              .Device(DEVICE_SYCL)                             \
                              .HostMemory("shape")                             \
                              .HostMemory("minval")                            \
                              .HostMemory("maxval")                            \
                              .TypeConstraint<IntType>("Tout"),                \
                          RandomUniformIntOp<SYCLDevice, IntType>);

TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow