/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_

#if defined(__CUDACC__)

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

namespace tensorflow {

namespace functor {

template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
                              Distribution dist);
};

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
                              int64 size, Distribution dist);
};

// Fallback copier: copies the elements from the array to buf one at a time.
template <typename T, int ElementCount>
class SampleCopier {
 public:
  inline __device__ void operator()(
      T* buf, const tensorflow::random::Array<T, ElementCount>& array) const {
#pragma unroll
    for (int i = 0; i < ElementCount; i++) {
      buf[i] = array[i];
    }
  }
};

template <>
class SampleCopier<float, 4> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      float* buf, const tensorflow::random::Array<float, 4>& array) const {
    // NOTE(ringwalt): It's not safe to cast &array[0] to a float4, because
    // the array is only 32-bit aligned while float4 requires 128-bit
    // alignment. There seems to be no performance loss from assigning each
    // element to a vector instead.
    float4 vec;
    vec.x = array[0];
    vec.y = array[1];
    vec.z = array[2];
    vec.w = array[3];
    float4* buf_vector = reinterpret_cast<float4*>(buf);
    *buf_vector = vec;
  }
};

template <>
class SampleCopier<int32, 4> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      int32* buf, const tensorflow::random::Array<int32, 4>& array) const {
    int4 vec;
    vec.x = array[0];
    vec.y = array[1];
    vec.z = array[2];
    vec.w = array[3];
    int4* buf_vector = reinterpret_cast<int4*>(buf);
    *buf_vector = vec;
  }
};

template <>
class SampleCopier<double, 2> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      double* buf, const tensorflow::random::Array<double, 2>& array) const {
    double2 vec;
    vec.x = array[0];
    vec.y = array[1];
    double2* buf_vector = reinterpret_cast<double2*>(buf);
    *buf_vector = vec;
  }
};

template <>
class SampleCopier<int64, 2> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      int64* buf, const tensorflow::random::Array<int64, 2>& array) const {
    longlong2 vec;
    vec.x = array[0];
    vec.y = array[1];
    longlong2* buf_vector = reinterpret_cast<longlong2*>(buf);
    *buf_vector = vec;
  }
};
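
// The copier specializations above rely on the CUDA built-in vector types
// being 128 bits wide and 128-bit aligned. A minimal compile-time sanity
// check of that assumption (an illustrative addition, not part of the
// original header) could look like this:
static_assert(sizeof(float4) == 16 && alignof(float4) == 16,
              "float4 is expected to be a 128-bit-aligned 128-bit type");
static_assert(sizeof(int4) == 16 && alignof(int4) == 16,
              "int4 is expected to be a 128-bit-aligned 128-bit type");
static_assert(sizeof(double2) == 16 && alignof(double2) == 16,
              "double2 is expected to be a 128-bit-aligned 128-bit type");
static_assert(sizeof(longlong2) == 16 && alignof(longlong2) == 16,
              "longlong2 is expected to be a 128-bit-aligned 128-bit type");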

// A CUDA kernel to fill the data with random numbers from the specified
// distribution. Each output takes a fixed number of samples.
template <class Distribution>
PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, false>::Run(
    random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
  const int kGroupSize = Distribution::kResultElementCount;

  const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int32 total_thread_count = gridDim.x * blockDim.x;
  int32 offset = thread_id * kGroupSize;
  // Jump ahead so that each thread draws from a disjoint part of the Philox
  // stream.
  gen.Skip(thread_id);

  const SampleCopier<T, kGroupSize> copier;
  // Grid-stride loop: each thread fills one kGroupSize-element group per
  // iteration, then advances by the total number of threads.
  while (offset + kGroupSize <= size) {
    const typename Distribution::ResultType samples = dist(&gen);
    copier(&data[offset], samples);

    offset += total_thread_count * kGroupSize;
    gen.Skip(total_thread_count - 1);
  }

  // Handle the tail of the output element by element, without writing past
  // `size`.
  typename Distribution::ResultType samples = dist(&gen);
  for (int i = 0; i < kGroupSize; ++i) {
    if (offset >= size) {
      return;
    }
    data[offset] = samples[i];
    ++offset;
  }
}

// A CUDA kernel to fill the data with random numbers from the specified
// distribution. Each output takes a variable number of samples.
template <class Distribution>
PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, true>::Run(
    const random::PhiloxRandom& base_gen, T* data, int64 size,
    Distribution dist) {
  using random::PhiloxRandom;
  using random::SingleSampleAdapter;

  const int kReservedSamplesPerOutput = 256;
  const int kGroupSize = Distribution::kResultElementCount;
  // Each output reserves kReservedSamplesPerOutput single samples, and one
  // PhiloxRandom invocation yields kResultElementCount of them, so this is
  // how far the generator must be skipped per output group.
  const int kGeneratorSkipPerOutputGroup = kGroupSize *
                                           kReservedSamplesPerOutput /
                                           PhiloxRandom::kResultElementCount;

  const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int32 total_thread_count = gridDim.x * blockDim.x;
  int64 group_index = thread_id;
  int64 offset = group_index * kGroupSize;

  while (offset < size) {
    // Since each output takes a variable number of samples, we need to
    // realign the generator to the start of the region reserved for the
    // current output group.
    PhiloxRandom gen = base_gen;
    gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
    SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

    typename Distribution::ResultType samples = dist(&single_samples);

    for (int i = 0; i < kGroupSize; ++i) {
      if (offset >= size) {
        return;
      }
      data[offset] = samples[i];
      ++offset;
    }

    offset += (total_thread_count - 1) * kGroupSize;
    group_index += total_thread_count;
  }
}
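
// A minimal sketch (an illustrative addition, not part of the original
// header) of how the two specializations above might be driven from a single
// __global__ entry point. The wrapper name and the use of
// Distribution::kVariableSamplesPerOutput to select the specialization are
// assumptions about the surrounding TensorFlow code, not a guaranteed API.
template <class Distribution>
__global__ void FillPhiloxRandomKernelLaunch(
    random::PhiloxRandom base_gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  FillPhiloxRandomKernel<Distribution,
                         Distribution::kVariableSamplesPerOutput>()
      .Run(base_gen, data, size, dist);
}
// Host-side usage would then be an ordinary kernel launch, e.g.
//   FillPhiloxRandomKernelLaunch<Dist>
//       <<<block_count, thread_count, 0, stream>>>(gen, data, size, dist);
// with block_count and thread_count chosen to suit the device.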

}  // namespace functor
}  // namespace tensorflow

#endif  // defined(__CUDACC__)

#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_