/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
#define TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda_fp16.h"
#endif
#include "tensorflow/core/util/gpu_cuda_alias.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_launch_config.h"

#if GOOGLE_CUDA
#define TF_RED_WARPSIZE 32
#elif TENSORFLOW_USE_ROCM
#define TF_RED_WARPSIZE 64
#endif

// Deprecated, use 'for(int i : GpuGridRangeX(n))' instead.
#define GPU_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))
#define CUDA_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))

// Deprecated, use 'for(int i : GpuGridRange?(n))' instead, where '?' is the
// grid axis: X, Y or Z.
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
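
// Example of the recommended grid-stride loop (a minimal sketch;
// 'ScaleKernel', 'in' and 'out' are hypothetical):
//
//   __global__ void ScaleKernel(const float* in, float* out, int n) {
//     // Covers all 'n' elements regardless of the launch geometry.
//     for (int i : GpuGridRangeX<int>(n)) {
//       out[i] = 2.0f * in[i];
//     }
//   }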

#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#endif

// macro wrapper to declare dynamic shared memory
#if GOOGLE_CUDA

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  extern __shared__ __align__(ALIGN) TYPE NAME[]

#elif TENSORFLOW_USE_ROCM

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  HIP_DYNAMIC_SHARED(TYPE, NAME)

#endif
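
// Example use of the dynamic shared memory macro (a minimal sketch;
// 'PartialReduce' and its arguments are hypothetical). The buffer size is
// whatever shared memory byte count is supplied at kernel launch:
//
//   __global__ void PartialReduce(const float* in, float* out, int n) {
//     GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(float), float, shared_buf);
//     shared_buf[threadIdx.x] = (threadIdx.x < n) ? in[threadIdx.x] : 0.0f;
//     __syncthreads();
//     // ... reduce shared_buf and write the block's result to out ...
//   }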

namespace tensorflow {

#if GOOGLE_CUDA
// cudaGetErrorString is available to both host and device
__host__ __device__ inline const char* GpuGetErrorString(cudaError_t error) {
  return cudaGetErrorString(error);
}
#elif TENSORFLOW_USE_ROCM
// hipGetErrorString is available on host side only
inline const char* GpuGetErrorString(hipError_t error) {
  return hipGetErrorString(error);
}
#endif

// Returns a raw reference to the current GPU stream. Required by a
// number of kernel calls (for which StreamInterface* does not work),
// e.g. CUB and certain cuBLAS primitives.
inline const gpuStream_t& GetGpuStream(OpKernelContext* context) {
  const gpuStream_t* ptr = CHECK_NOTNULL(
      reinterpret_cast<const gpuStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));
  return *ptr;
}
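
// Example (a minimal sketch inside an OpKernel::Compute; 'context' is the
// OpKernelContext*). The raw stream can then be handed to CUB or cuBLAS:
//
//   const gpuStream_t& stream = GetGpuStream(context);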

// Launches a GPU kernel through cudaLaunchKernel in a CUDA environment, or
// hipLaunchKernelGGL in a ROCm environment, with the given arguments.
//
// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
template <typename... Ts, typename... Args>
Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
                       size_t shared_memory_size_bytes, gpuStream_t stream,
                       Args... arguments) {
  static_assert(detail::NoneIsReference<Ts...>(),
                "Kernels with reference arguments have undefined behaviour.");
#if GOOGLE_CUDA
  auto func_ptr = absl::bit_cast<const void*>(function);
  // Cast arguments and forward them as an array of pointers.
  auto args_tuple = std::tuple<Ts...>(arguments...);
  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
                                 shared_memory_size_bytes, stream);
  if (result != cudaSuccess) {
    return errors::Internal(cudaGetErrorString(result));
  }
#elif TENSORFLOW_USE_ROCM
  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
                     stream, std::forward<Args>(arguments)...);
#endif
  return Status::OK();
}
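
// Example (a minimal sketch; 'ScaleKernel', 'in_ptr' and 'out_ptr' are
// hypothetical, and 'device' is the Eigen GPU device obtained from the
// OpKernelContext):
//
//   GpuLaunchConfig config = GetGpuLaunchConfig(n, device);
//   TF_CHECK_OK(GpuLaunchKernel(ScaleKernel, config.block_count,
//                               config.thread_per_block,
//                               /*shared_memory_size_bytes=*/0,
//                               device.stream(), in_ptr, out_ptr, n));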

// Perfect forwarding to make CudaLaunchKernel available to both ROCm and CUDA
// builds
template <typename... Args>
auto CudaLaunchKernel(Args&&... args)
    -> decltype(GpuLaunchKernel(std::forward<Args>(args)...)) {
  return GpuLaunchKernel(std::forward<Args>(args)...);
}

__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
    const tensorflow::bfloat16* address) {
  tensorflow::bfloat16 return_value;
  return_value.value = GpuLdg(reinterpret_cast<const uint16_t*>(address));
  return return_value;
}
// Already aliased in gpu_device_functions.h

template <typename T>
__host__ __device__ inline T ldg(const T* ptr) {
  return GpuLdg(ptr);
}

template <typename T>
__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
  return x < y ? x : y;
}

template <typename T>
__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
  return x < y ? y : x;
}

// Overloads of the above functions for float and double.
__host__ __device__ inline float tf_min(float x, float y) {
  return fminf(x, y);
}
__host__ __device__ inline double tf_min(double x, double y) {
  return fmin(x, y);
}
__host__ __device__ inline float tf_max(float x, float y) {
  return fmaxf(x, y);
}
__host__ __device__ inline double tf_max(double x, double y) {
  return fmax(x, y);
}

// ROCm TODO: re-enable these overloads after adding fp16 support logic.
#if GOOGLE_CUDA
__device__ inline Eigen::half GpuShuffleSync(unsigned mask, Eigen::half value,
                                             int src_lane,
                                             int width = warpSize) {
  return Eigen::half(
      GpuShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleUpSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleDownSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleXorSync(
    unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
  return Eigen::half(
      GpuShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
}
// Aliased in gpu_device_functions.h
#endif
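
// Example use of the Eigen::half shuffle overloads above (a minimal sketch,
// CUDA only; 'BroadcastHalf' and 'out' are hypothetical, and all 32 lanes of
// the warp are assumed active):
//
//   __global__ void BroadcastHalf(Eigen::half* out) {
//     Eigen::half v = Eigen::half(static_cast<float>(threadIdx.x));
//     // Every lane receives lane 0's value.
//     out[threadIdx.x] = GpuShuffleSync(0xffffffff, v, /*src_lane=*/0);
//   }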

namespace gpu_helper {

// Returns the index of the first element in the sorted range
// [first, first + count) that compares greater than 'val', i.e. the
// device-side analogue of std::upper_bound.
template <typename T, typename OutType = int32>
__device__ OutType upper_bound(const T* first, OutType count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  OutType step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}

// Returns the index of the first element in the sorted range
// [first, first + count) that does not compare less than 'val', i.e. the
// device-side analogue of std::lower_bound.
template <typename T, typename OutType = int32>
__device__ OutType lower_bound(const T* first, OutType count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  OutType step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (*it < val) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}
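
// Example (a minimal sketch; 'FindBucket', 'boundaries' and 'num_boundaries'
// are hypothetical). For a sorted array of bucket boundaries, upper_bound
// yields the index of the bucket that 'value' falls into:
//
//   __global__ void FindBucket(const float* boundaries, int num_boundaries,
//                              float value, int* bucket) {
//     *bucket = gpu_helper::upper_bound<float>(boundaries, num_boundaries,
//                                              value);
//   }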

}  // namespace gpu_helper

#ifndef TENSORFLOW_USE_ROCM
namespace cuda_helper = gpu_helper;
#endif

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_