1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA
17
18 #define EIGEN_USE_GPU
19
20 #include "tensorflow/core/kernels/random_op.h"
21 #include "tensorflow/core/kernels/random_op_gpu.h"
22
23 #include <assert.h>
24 #include <stdio.h>
25
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/lib/random/philox_random.h"
29 #include "tensorflow/core/lib/random/random_distributions.h"
30 #include "tensorflow/core/util/cuda_kernel_helper.h"
31
32 namespace tensorflow {
33
34 class OpKernelContext;
35
36 namespace functor {
37
38 typedef Eigen::GpuDevice GPUDevice;
39
// Kernel entry point that forwards to the FillPhiloxRandomKernel worker.
//
// The worker type is chosen at compile time from the distribution:
// Distribution::kVariableSamplesPerOutput is passed as the second template
// argument of FillPhiloxRandomKernel, and all four kernel arguments are
// handed unchanged to its Run() method.
//
// __launch_bounds__(1024) declares that this kernel is never launched with
// more than 1024 threads per block.
template <class Distribution>
__global__ void __launch_bounds__(1024)
    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
                                 typename Distribution::ResultElementType* data,
                                 int64 size, Distribution dist) {
  typedef FillPhiloxRandomKernel<Distribution,
                                 Distribution::kVariableSamplesPerOutput>
      Worker;
  Worker().Run(base_gen, data, size, dist);
}
50
// GPU partial specialization of FillPhiloxRandom: launches
// FillPhiloxRandomKernelLaunch<Distribution> on the device's stream.
//
// Parameters:
//   (unnamed) OpKernelContext*  - not used by this implementation.
//   d     - GPU device; supplies the launch geometry and the stream.
//   gen   - Philox generator, passed by value to the kernel.
//   data  - output buffer handed to the kernel (presumably device memory
//           owned by the caller; verify against callers).
//   size  - forwarded to the kernel as the amount of output to produce
//           (presumably a count of elements, not bytes).
//   dist  - distribution functor, passed by value to the kernel.
//
// The launch geometry is derived from device limits only, not from `size`.
// Any error reported by CudaLaunchKernel is treated as fatal by TF_CHECK_OK.
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  // Nothing to fill: skip the launch instead of scheduling a full-device grid
  // whose threads would have no output to write.
  if (size <= 0) return;

  // One block uses the device's maximum thread count. The kernel is declared
  // with __launch_bounds__(1024), so this assumes that maximum is <= 1024
  // (TODO confirm for all supported devices).
  const int32 block_size = d.maxGpuThreadsPerBlock();
  // Number of blocks = (multiprocessors * max threads per multiprocessor)
  // divided by the block size, using integer division.
  const int32 num_blocks =
      (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
      block_size;

  // No dynamic shared memory is requested (third launch argument is 0).
  TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
                               num_blocks, block_size, 0, d.stream(), gen, data,
                               size, dist));
}
66
// Explicit instantiations of the GPU FillPhiloxRandom functor, one per
// supported (distribution, element type) pair.
// clang-format off
// The macro keeps a space between the closing angle brackets because NVCC
// cannot handle ">>" properly.
#define TF_INSTANTIATE_GPU_FILL_PHILOX(DIST, GEN, T) \
  template struct FillPhiloxRandom<GPUDevice, random::DIST<GEN, T> >

// Uniform distributions.
TF_INSTANTIATE_GPU_FILL_PHILOX(UniformDistribution, random::PhiloxRandom, Eigen::half);
TF_INSTANTIATE_GPU_FILL_PHILOX(UniformDistribution, random::PhiloxRandom, float);
TF_INSTANTIATE_GPU_FILL_PHILOX(UniformDistribution, random::PhiloxRandom, double);
TF_INSTANTIATE_GPU_FILL_PHILOX(UniformDistribution, random::PhiloxRandom, int32);
TF_INSTANTIATE_GPU_FILL_PHILOX(UniformDistribution, random::PhiloxRandom, int64);

// Normal distributions.
TF_INSTANTIATE_GPU_FILL_PHILOX(NormalDistribution, random::PhiloxRandom, Eigen::half);
TF_INSTANTIATE_GPU_FILL_PHILOX(NormalDistribution, random::PhiloxRandom, float);
TF_INSTANTIATE_GPU_FILL_PHILOX(NormalDistribution, random::PhiloxRandom, double);

// Truncated normal distributions; these draw from a SingleSampleAdapter
// wrapped around the Philox generator.
TF_INSTANTIATE_GPU_FILL_PHILOX(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               Eigen::half);
TF_INSTANTIATE_GPU_FILL_PHILOX(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               float);
TF_INSTANTIATE_GPU_FILL_PHILOX(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               double);

#undef TF_INSTANTIATE_GPU_FILL_PHILOX
// clang-format on
96
97 } // namespace functor
98 } // namespace tensorflow
99
100 #endif // GOOGLE_CUDA
101