/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/random_op_gpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/util/cuda_launch_config.h"

namespace tensorflow {

using random::PhiloxRandom;

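// Device-global counter used to detect which thread finishes last across the
// whole grid, so that exactly one thread writes the advanced Philox state
// back to the variable. The host resets it to zero before every launch.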
__device__ int thread_counter;

template <typename Distribution>
__global__ void FillKernel(
    Distribution dist, int64 state_size, int64 output_size,
    StateElementType* state_data,
    typename Distribution::ResultElementType* output_data) {
  // Threads in this block share `philox`. Thread 0 is responsible for
  // initializing it.
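  // PhiloxRandom lives in a raw byte buffer here, likely because __shared__
  // variables may not have dynamic initialization (PhiloxRandom has a
  // user-provided constructor); thread 0 initializes it in place instead.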
  __shared__ char philox_raw[sizeof(PhiloxRandom)];
  auto philox = reinterpret_cast<PhiloxRandom*>(philox_raw);
  if (threadIdx.x == 0) {
    *philox = GetPhiloxRandomFromMem(state_data);
  }
  __syncthreads();
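  // Every thread now sees the same initialized generator; each fills its own
  // strided portion of the output.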
  functor::FillPhiloxRandomKernel<Distribution,
                                  Distribution::kVariableSamplesPerOutput>()
      .Run(*philox, output_data, output_size, dist);
  // The thread that finishes last (detected via the atomic counter) updates
  // the state. The launch configuration is one-dimensional, so
  // gridDim.x * blockDim.x counts every thread in the grid.
  auto total_thread_count = gridDim.x * blockDim.x;
  auto old_counter_value = atomicAdd(&thread_counter, 1);
  if (old_counter_value == total_thread_count - 1) {
    UpdateMemWithPhiloxRandom(*philox, output_size, state_data);
  }
}

template <typename Distribution>
void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
    OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
    int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
    typename Distribution::ResultElementType* output_data) {
  OP_REQUIRES(
      ctx, alg_tag_skip == 0,
      errors::InvalidArgument(
          "GPU kernel doesn't support reading algorithm from state variable, "
          "so alg_tag_skip must be 0; got ",
          alg_tag_skip));
  auto state_tensor_flat = state_tensor->flat<StateElementType>();
  auto state_size = state_tensor_flat.size();
  auto state_data = state_tensor_flat.data();

  // Maximize occupancy: each Philox invocation yields kGroupSize result
  // elements, so size the launch for ceil(output_size / kGroupSize) work
  // elements.
  const int kGroupSize = Distribution::kResultElementCount;
  int work_element_count = (output_size + kGroupSize - 1) / kGroupSize;
  CudaLaunchConfig cfg = GetCudaLaunchConfig(work_element_count, d,
                                             FillKernel<Distribution>, 0, 0);

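  // Reset the grid-wide completion counter so the kernel can detect its last
  // finishing thread.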
  int zero = 0;
  cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
  TF_CHECK_OK(CudaLaunchKernel(FillKernel<Distribution>, cfg.block_count,
                               cfg.thread_per_block, 0, d.stream(),
                               Distribution(), state_size, output_size,
                               state_data, output_data));
}
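
// A minimal sketch of a call site (for illustration only; `ctx`, `n`,
// `state_tensor`, and `output_ptr` are hypothetical names):
//
//   using Dist = random::NormalDistribution<random::PhiloxRandom, float>;
//   UpdateVariableAndFill_Philox<GPUDevice, Dist>()(
//       ctx, ctx->eigen_device<GPUDevice>(), /*output_size=*/n,
//       /*alg_tag_skip=*/0, /*not_used=*/nullptr, &state_tensor, output_ptr);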

// Explicit instantiations of the GPU distribution functors.

// clang-format off
// NVCC cannot parse the ">>" closing nested template arguments properly, so
// keep the space before the final angle bracket.
template struct UpdateVariableAndFill_Philox<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
template struct UpdateVariableAndFill_Philox<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
template struct UpdateVariableAndFill_Philox<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
// clang-format on

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA