/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_

/**
 * Wrappers and helpers for CUDA device code.
 *
 * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
 * backwards compatibility, see go/volta-porting for details.
 * Provides atomic operations on types that aren't natively supported.
 * Defines a number of macros and types providing a shared interface
 * to either CUDA or ROCm APIs, depending on the build.
 */

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#else
#include "rocm/include/hip/hip_complex.h"
#endif

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_cuda_alias.h"

#if GOOGLE_CUDA
using gpuStream_t = cudaStream_t;
using gpuEvent_t = cudaEvent_t;
#define gpuEventRecord cudaEventRecord
#define gpuEventSynchronize cudaEventSynchronize
#define gpuEventDestroy cudaEventDestroy
#define gpuEventCreate cudaEventCreate
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuDeviceSynchronize cudaDeviceSynchronize
#define gpuFree cudaFree
#elif TENSORFLOW_USE_ROCM
using gpuStream_t = hipStream_t;
using gpuEvent_t = hipEvent_t;
using cudaError = int;
using cudaError_t = int;
#define cudaSuccess 0
#define cudaGetLastError hipGetLastError
#define gpuEventRecord hipEventRecord
#define gpuEventDestroy hipEventDestroy
#define gpuEventSynchronize hipEventSynchronize
#define gpuEventCreate hipEventCreate
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuDeviceSynchronize hipDeviceSynchronize
#define gpuFree hipFree
static std::string cudaGetErrorString(int err) { return std::to_string(err); }
#endif

#define TF_RETURN_IF_CUDA_ERROR(result)                   \
  do {                                                    \
    cudaError_t error(result);                            \
    if (!SE_PREDICT_TRUE(error == cudaSuccess)) {         \
      return errors::Internal("Cuda call failed with ",   \
                              cudaGetErrorString(error)); \
    }                                                     \
  } while (0)

#define TF_OP_REQUIRES_CUDA_SUCCESS(context, result)                   \
  do {                                                                 \
    cudaError_t error(result);                                         \
    if (!SE_PREDICT_TRUE(error == cudaSuccess)) {                      \
      context->SetStatus(errors::Internal("Cuda call failed with ",    \
                                          cudaGetErrorString(error))); \
      return;                                                          \
    }                                                                  \
  } while (0)
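
// Illustrative usage sketch for the macros above (`MyKernel` and the
// surrounding function are hypothetical, not part of this header; launch
// parameters are placeholders):
//
//   Status LaunchMyKernel(gpuStream_t stream, int count, float* out) {
//     MyKernel<<<blocks, threads, 0, stream>>>(count, out);
//     TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
//     return Status::OK();
//   }
//
// TF_OP_REQUIRES_CUDA_SUCCESS is the analogous check for code that has an
// OpKernelContext and returns void: it records the error on the context via
// SetStatus and returns early instead of propagating a Status.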

namespace tensorflow {
// According to HIP developer guide at
// https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#assert
// assert is not supported by HIP. While we are waiting for assert support in
// hip kernels, the assert call should be macroed to NOP so that it does not
// block us from creating a debug build
#if TENSORFLOW_USE_ROCM
#undef assert
#define assert(x) \
  {}
#endif

namespace detail {

// Helper for range-based for loop using 'delta' increments.
// Usage: see GpuGridRange?() functions below.
template <typename T>
class GpuGridRange {
  struct Iterator {
    __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
    __device__ T operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += delta_;
      return *this;
    }
    __device__ bool operator!=(const Iterator& other) const {
      bool greater = index_ > other.index_;
      bool less = index_ < other.index_;
      // Anything past an end iterator (delta_ == 0) is equal.
      // In range-based for loops, this optimizes to 'return less'.
      if (!other.delta_) {
        return less;
      }
      if (!delta_) {
        return greater;
      }
      return less || greater;
    }

   private:
    T index_;
    const T delta_;
  };

 public:
  __device__ GpuGridRange(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
  __device__ Iterator end() const { return Iterator{end_, 0}; }

 private:
  T begin_;
  T delta_;
  T end_;
};

#ifndef TENSORFLOW_USE_ROCM
template <typename... T>
using CudaGridRange = GpuGridRange<T...>;
#endif
}  // namespace detail

// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : GpuGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::GpuGridRange<T> GpuGridRangeX(T count) {
  return detail::GpuGridRange<T>(
      /*begin=*/blockIdx.x * blockDim.x + threadIdx.x,
      /*delta=*/gridDim.x * blockDim.x, /*end=*/count);
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeX, CudaGridRangeX);

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : GpuGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::GpuGridRange<T> GpuGridRangeY(T count) {
  return detail::GpuGridRange<T>(
      /*begin=*/blockIdx.y * blockDim.y + threadIdx.y,
      /*delta=*/gridDim.y * blockDim.y, /*end=*/count);
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeY, CudaGridRangeY);

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : GpuGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::GpuGridRange<T> GpuGridRangeZ(T count) {
  return detail::GpuGridRange<T>(
      /*begin=*/blockIdx.z * blockDim.z + threadIdx.z,
      /*delta=*/gridDim.z * blockDim.z, /*end=*/count);
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeZ, CudaGridRangeZ);
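
// Illustrative sketch (the kernel and its arguments are hypothetical): the x/y
// helpers compose into a grid-stride loop over a 2-D index space.
//
//   template <typename T>
//   __global__ void ScaleMatrixKernel(int rows, int cols, T alpha, T* data) {
//     for (int row : GpuGridRangeY(rows)) {
//       for (int col : GpuGridRangeX(cols)) {
//         data[row * cols + col] *= alpha;
//       }
//     }
//   }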

// Mask for all 32 threads in a warp.
__device__ const unsigned kCudaWarpAll = 0xffffffff;
// ROCM TODO add ROCM implementation
// Mask for all 64 threads in a wavefront.
__device__ const unsigned kGpuWarpAll = 0xffffffff;

// Returns the warp lane ID of the calling thread
__device__ inline unsigned GpuLaneId() {
  unsigned int lane_id;
#if GOOGLE_CUDA
#if __clang__
  return __nvvm_read_ptx_sreg_laneid();
#else   // __clang__
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
#endif  // __clang__
#elif TENSORFLOW_USE_ROCM
  lane_id = __lane_id();
#endif
  return lane_id;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLaneId, CudaLaneId);

namespace detail {
// Returns true if mask is a valid parameter for __shfl*sync to return a well
// defined value, assuming the calling lane will read from src_lane as part of
// the shuffle operation.
//
// Specifically, returns true iff mask has the calling lane bit and the src_lane
// bit set, and the src_lane calls this function with the same mask value
// (required for the two threads to wait for each other).
//
// On Volta, for some invalid masks, this function hangs or returns false
// positives, because the implementation shuffles with the same mask that
// we are validating. Run on Pascal if you suspect that the mask is incorrect.
__device__ inline bool GpuValidateShuffleSyncMask(unsigned mask,
                                                  unsigned src_lane) {
  unsigned src_dst_mask = 1u << GpuLaneId() | 1u << src_lane;
#if CUDA_VERSION >= 9000
  unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
#else
#if GOOGLE_CUDA
  unsigned src_lane_mask = __shfl(mask, src_lane);
#elif TENSORFLOW_USE_ROCM
  unsigned src_lane_mask =
      __shfl(static_cast<int>(mask), static_cast<int>(src_lane));
#endif
#endif
  return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuValidateShuffleSyncMask,
                                  CudaValidateShuffleSyncMask);

// Returns the actual source lane for shuffle.
__device__ inline unsigned GpuShuffleGetSrcLane(int src_lane, int width) {
  int lane_id = GpuLaneId();
  int lane_base = lane_id & ~width + 1;
  int lane_offset = src_lane & width - 1;
  return lane_base + lane_offset;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleGetSrcLane, CudaShuffleGetSrcLane);

// Returns the source lane for shuffle up.
__device__ inline unsigned GpuShuffleUpGetSrcLane(unsigned delta, int width) {
  unsigned lane_id = GpuLaneId();
  if ((lane_id & width - 1) < delta) {
    return lane_id;
  }
  return lane_id - delta;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpGetSrcLane,
                                  CudaShuffleUpGetSrcLane);

// Returns the source lane for shuffle down.
__device__ inline unsigned GpuShuffleDownGetSrcLane(unsigned delta, int width) {
  unsigned lane_id = GpuLaneId();
  if ((lane_id & width - 1) + delta >= width) {
    return lane_id;
  }
  return lane_id + delta;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownGetSrcLane,
                                  CudaShuffleDownGetSrcLane);

// Returns the source lane for shuffle xor.
__device__ inline unsigned GpuShuffleXorGetSrcLane(int lane_mask, int width) {
  int lane_id = GpuLaneId();
  int src_lane = lane_id ^ lane_mask;
  if (src_lane > (lane_id | width - 1)) {
    return lane_id;
  }
  return src_lane;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorGetSrcLane,
                                  CudaShuffleXorGetSrcLane);
}  // namespace detail

// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// In other words, all threads in 'mask' must call the functions in convergence.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
//
// It is also illegal to shuffle with a mask that produces an undefined result
// for any of the threads. Specifically, all source threads of the shuffle
// must have their corresponding bit in 'mask' set.

// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
__device__ inline void GpuSyncWarp(unsigned mask = kCudaWarpAll) {
  assert(mask & 1u << GpuLaneId());
#if CUDA_VERSION >= 9000
  __syncwarp(mask);
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuSyncWarp, CudaSyncWarp);

// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline unsigned GpuBallotSync(unsigned mask, int pred) {
  assert(mask & 1u << GpuLaneId());
#if CUDA_VERSION >= 9000
  return __ballot_sync(mask, pred);
#else
  return __ballot(pred) & mask;  // Apply mask to match __ballot_sync's spec.
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuBallotSync, CudaBallotSync);

// Wrapper for __any_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int GpuAnySync(unsigned mask, int pred) {
  assert(mask & 1u << GpuLaneId());
#if CUDA_VERSION >= 9000
  return __any_sync(mask, pred);
#else
  return __any(pred);
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAnySync, CudaAnySync);

// Wrapper for __all_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int GpuAllSync(unsigned mask, int pred) {
  assert(mask & 1u << GpuLaneId());
#if CUDA_VERSION >= 9000
  return __all_sync(mask, pred);
#else
  return __all(pred);
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAllSync, CudaAllSync);
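
// Illustrative sketch of the vote wrappers, assuming all 32 lanes of the warp
// are active and `pred` is a hypothetical per-thread predicate:
//
//   unsigned votes = GpuBallotSync(kCudaWarpAll, pred);
//   int num_true = __popc(votes);  // how many lanes voted true
//   bool any_true = GpuAnySync(kCudaWarpAll, pred);
//   bool all_true = GpuAllSync(kCudaWarpAll, pred);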

// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T GpuShuffleSync(unsigned mask, T value, int src_lane,
                            int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::GpuValidateShuffleSyncMask(
      mask, detail::GpuShuffleGetSrcLane(src_lane, width)));
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double GpuShuffleSync(unsigned mask, double value,
                                        int src_lane, int width = warpSize) {
#if GOOGLE_CUDA
  auto tmp = __double_as_longlong(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = GpuShuffleSync(mask, hi, src_lane, width);
  lo = GpuShuffleSync(mask, lo, src_lane, width);
  return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
  auto tmp = static_cast<uint64_t>(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = __shfl(static_cast<int>(hi), src_lane, width);
  lo = __shfl(static_cast<int>(lo), src_lane, width);
  return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
                             static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleSync, CudaShuffleSync);
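
// Illustrative sketch: broadcasting lane 0's value to every lane of a fully
// active warp (`local` is a hypothetical per-thread variable).
//
//   float broadcast = GpuShuffleSync(kCudaWarpAll, local, /*src_lane=*/0);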

// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ inline T GpuShuffleUpSync(unsigned mask, T value, unsigned delta,
                                     int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::GpuValidateShuffleSyncMask(
      mask, detail::GpuShuffleUpGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_up_sync(mask, value, delta, width);
#else
  return __shfl_up(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double GpuShuffleUpSync(unsigned mask, double value,
                                          unsigned delta,
                                          int width = warpSize) {
#if GOOGLE_CUDA
  auto tmp = __double_as_longlong(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = GpuShuffleUpSync(mask, hi, delta, width);
  lo = GpuShuffleUpSync(mask, lo, delta, width);
  return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
  auto tmp = static_cast<uint64_t>(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = __shfl_up(static_cast<int>(hi), delta, width);
  lo = __shfl_up(static_cast<int>(lo), delta, width);
  return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
                             static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpSync, CudaShuffleUpSync);
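
// Illustrative sketch, assuming a fully active 32-lane warp: an inclusive warp
// prefix sum built on GpuShuffleUpSync (`local` is a hypothetical per-thread
// value). Lanes whose source would be out of range read back their own value,
// so the addition is guarded by the lane id.
//
//   float scan = local;
//   for (int delta = 1; delta < 32; delta *= 2) {
//     float up = GpuShuffleUpSync(kCudaWarpAll, scan, delta);
//     if (GpuLaneId() >= delta) scan += up;
//   }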

// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
template <typename T>
__device__ inline T GpuShuffleDownSync(unsigned mask, T value, unsigned delta,
                                       int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::GpuValidateShuffleSyncMask(
      mask, detail::GpuShuffleDownGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_down_sync(mask, value, delta, width);
#else
  return __shfl_down(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double GpuShuffleDownSync(unsigned mask, double value,
                                            unsigned delta,
                                            int width = warpSize) {
#if GOOGLE_CUDA
  auto tmp = __double_as_longlong(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = GpuShuffleDownSync(mask, hi, delta, width);
  lo = GpuShuffleDownSync(mask, lo, delta, width);
  return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
  auto tmp = static_cast<uint64_t>(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = __shfl_down(static_cast<int>(hi), delta, width);
  lo = __shfl_down(static_cast<int>(lo), delta, width);
  return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
                             static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownSync, CudaShuffleDownSync);
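
// Illustrative sketch, assuming a fully active 32-lane warp: a warp sum
// reduction using GpuShuffleDownSync (`local` is a hypothetical per-thread
// value); lane 0 ends up holding the sum of all 32 lanes.
//
//   float sum = local;
//   for (int delta = 16; delta > 0; delta /= 2) {
//     sum += GpuShuffleDownSync(kCudaWarpAll, sum, delta);
//   }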

// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T GpuShuffleXorSync(unsigned mask, T value, int lane_mask,
                               int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::GpuValidateShuffleSyncMask(
      mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width)));
#if GOOGLE_CUDA
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, lane_mask, width);
#else
  return __shfl_xor(value, lane_mask, width);
#endif
#elif TENSORFLOW_USE_ROCM
  // ROCM TODO: check if HIP should be changed to cope with more types
  return __shfl_xor(static_cast<int>(value), lane_mask, width);
#endif
}

#if TENSORFLOW_USE_ROCM
__device__ inline Eigen::half GpuShuffleXorSync(unsigned mask,
                                                Eigen::half value,
                                                int lane_mask,
                                                int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::GpuValidateShuffleSyncMask(
      mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width)));
  // TODO(rocm): This doesn't preserve NaN payload and flushes denorms to zero,
  // maybe this should be implemented differently?
  return static_cast<Eigen::half>(
      __shfl_xor(static_cast<float>(value), lane_mask, width));
}
#endif

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double GpuShuffleXorSync(unsigned mask, double value,
                                           int lane_mask,
                                           int width = warpSize) {
#if GOOGLE_CUDA
  auto tmp = __double_as_longlong(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = GpuShuffleXorSync(mask, hi, lane_mask, width);
  lo = GpuShuffleXorSync(mask, lo, lane_mask, width);
  return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
  auto tmp = static_cast<uint64_t>(value);
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);
  hi = __shfl_xor(static_cast<int>(hi), lane_mask, width);
  lo = __shfl_xor(static_cast<int>(lo), lane_mask, width);
  return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
                             static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorSync, CudaShuffleXorSync);
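
// Illustrative sketch, assuming a fully active 32-lane warp: a butterfly
// reduction with GpuShuffleXorSync leaves the full sum in every lane, not just
// lane 0 (`local` is a hypothetical per-thread value).
//
//   float sum = local;
//   for (int lanes = 16; lanes > 0; lanes /= 2) {
//     sum += GpuShuffleXorSync(kCudaWarpAll, sum, lanes);
//   }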

// Wrapper for __ldg.
template <typename T>
__host__ __device__ T GpuLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}

__host__ __device__ inline bool GpuLdg(const bool* address) {
  return GpuLdg(reinterpret_cast<const char*>(address)) != 0;
}

__host__ __device__ inline std::complex<float> GpuLdg(
    const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
  float2 mem = __ldg(reinterpret_cast<const float2*>(address));
  return std::complex<float>(mem.x, mem.y);
#else
  return *address;
#endif
}

__host__ __device__ inline std::complex<double> GpuLdg(
    const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
  double2 mem = __ldg(reinterpret_cast<const double2*>(address));
  return std::complex<double>(mem.x, mem.y);
#else
  return *address;
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLdg, CudaLdg);
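
// Illustrative sketch (the kernel is hypothetical): GpuLdg reads through the
// read-only data cache where available and degrades to a plain load otherwise,
// so it is suitable for any input that the kernel itself does not write.
//
//   template <typename T>
//   __global__ void CopyKernel(int count, const T* in, T* out) {
//     for (int i : GpuGridRangeX(count)) {
//       out[i] = GpuLdg(in + i);
//     }
//   }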

// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* __restrict__ ptr) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : GpuGridRangeX(count)) {
    ptr[i] = T(0);
  }
}

// Helper to set all tensor entries to a specific value.
template <typename T, typename Tvalue = T>
__global__ void SetToValue(const int count, T* __restrict__ ptr, Tvalue value) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : GpuGridRangeX(count)) {
    ptr[i] = static_cast<T>(value);
  }
}
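
// Illustrative launch sketch (grid/block sizes are placeholders; in practice
// they would typically come from a launch-config helper such as the ones in
// gpu_launch_config.h):
//
//   SetZero<float><<<grid_count, threads_per_block, 0, stream>>>(count, ptr);
//   SetToValue<float><<<grid_count, threads_per_block, 0, stream>>>(count, ptr,
//                                                                   1.0f);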

namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T GpuAtomicCasHelper(T* ptr, F accumulate) {
  T old = *ptr;
  T assumed;
  do {
    assumed = old;
    old = atomicCAS(ptr, assumed, accumulate(assumed));
  } while (assumed != old);
  return old;
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicCasHelper, CudaAtomicCasHelper);
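
// Illustrative sketch: any read-modify-write can be routed through the CAS
// helper by passing the update as a lambda; the GpuAtomicMul/GpuAtomicDiv
// wrappers below are built exactly this way. A hypothetical atomic
// "max of absolute values" for float could look like:
//
//   __device__ inline float GpuAtomicMaxAbs(float* ptr, float value) {
//     return GpuAtomicCasHelper(
//         ptr, [value](float a) { return fmaxf(fabsf(a), fabsf(value)); });
//   }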

// Overload for floating point (using integer comparison to handle NaN
// correctly).
template <typename F>
__device__ float GpuAtomicCasHelper(float* ptr, F accumulate) {
  return __int_as_float(
      GpuAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
        return __float_as_int(accumulate(__int_as_float(a)));
      }));
}
template <typename F>
__device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
#if TENSORFLOW_USE_ROCM
  // FIXME: remove the workaround below once the bug is fixed.
  // HIP has a bug in the implementation of __longlong_as_double,
  // so work around it by using reinterpret_cast<double*>.
  uint64_t result =
      GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
                         [accumulate](tensorflow::uint64 a) {
                           return __double_as_longlong(
                               accumulate(*(reinterpret_cast<double*>(&a))));
                         });
  return *(reinterpret_cast<double*>(&result));
#else
  return __longlong_as_double(GpuAtomicCasHelper(
      reinterpret_cast<unsigned long long*>(ptr),
      [accumulate](tensorflow::uint64 a) {
        return __double_as_longlong(accumulate(__longlong_as_double(a)));
      }));
#endif
}

// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// This version is going to be very slow under high concurrency, since most
// threads will be spinning on failing their compare-and-swap tests. (The fact
// that we get false sharing on the neighboring fp16 makes this even worse.) If
// you are doing a large reduction, you are much better off with doing the
// intermediate steps in fp32 and then switching to fp16 as late as you can in
// the calculations.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
  intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
  assert(!(intptr & 0x1));  // should be 2-aligned.
  if (intptr & 0x2) {
    // The half is in the second part of the uint32 (upper 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr - 2);
    uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short high = static_cast<unsigned short>(arg >> 16);
      Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
      return (static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc)) << 16) |
             (arg & 0xffff);
    });
    return Eigen::numext::bit_cast<Eigen::half>(
        static_cast<uint16>(result >> 16));
  } else {
    // The half is in the first part of the uint32 (lower 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr);
    uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short low = static_cast<unsigned short>(arg & 0xffff);
      Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(low));
      return (arg & 0xffff0000) |
             static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc));
    });
    return Eigen::numext::bit_cast<Eigen::half>(
        static_cast<uint16>(result & 0xffff));
  }
}

template <typename F>
__device__ long long GpuAtomicCasHelper(long long* ptr, F accumulate) {
  return static_cast<long long>(
      GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
                         [accumulate](unsigned long long a) {
                           return static_cast<unsigned long long>(
                               accumulate(static_cast<long long>(a)));
                         }));
}

template <typename From, typename To>
using ToTypeIfConvertible =
    typename std::enable_if<std::is_convertible<From, To>::value, To>::type;

template <typename T>
struct CudaSupportedTypeImpl {
  using type = T;
};

template <>
struct CudaSupportedTypeImpl<long long> {
  using type = unsigned long long;
};

template <>
struct CudaSupportedTypeImpl<unsigned long> {
  using type =
      typename std::conditional<sizeof(unsigned long) == sizeof(unsigned int),
                                unsigned int, unsigned long long>::type;
};

template <>
struct CudaSupportedTypeImpl<long> {
  // This cast should be safe since modular (wrap-around) addition works fine.
  // However, signed overflow is not handled correctly since it's undefined
  // behavior.
  using type = typename CudaSupportedTypeImpl<unsigned long>::type;
};

template <typename T>
using CudaSupportedType = typename CudaSupportedTypeImpl<T>::type;

template <typename T>
__device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {
  return reinterpret_cast<CudaSupportedType<T>*>(ptr);
}

}  // namespace detail

// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
  return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
}

__device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
                                           Eigen::half value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a + value; });
}

#if (__CUDA_ARCH__ < 600) || TENSORFLOW_USE_ROCM
__device__ inline double GpuAtomicAdd(double* ptr, double value) {
  return detail::GpuAtomicCasHelper(ptr,
                                    [value](double a) { return a + value; });
}
#endif

// GpuAtomicAdd
// Specializations of GpuAtomicAdd for complex types, which GpuAtomicAdd does
// not support. We treat a std::complex<T>* as a T* (the C++ standard section
// 26.4.4 allows this explicitly) and atomic add the real and imaginary
// components individually. The operation as a whole is not atomic, but we can
// safely treat the components independently for the purpose of accumulating.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
__device__ inline std::complex<float> GpuAtomicAdd(std::complex<float>* ptr,
                                                   std::complex<float> value) {
  auto ptr_scalar = reinterpret_cast<float*>(ptr);
  return std::complex<float>(GpuAtomicAdd(ptr_scalar, value.real()),
                             GpuAtomicAdd(ptr_scalar + 1, value.imag()));
}

__device__ inline std::complex<double> GpuAtomicAdd(
    std::complex<double>* ptr, std::complex<double> value) {
  auto ptr_scalar = reinterpret_cast<double*>(ptr);
  return std::complex<double>(GpuAtomicAdd(ptr_scalar, value.real()),
                              GpuAtomicAdd(ptr_scalar + 1, value.imag()));
}
#endif
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicAdd, CudaAtomicAdd);
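
// Illustrative sketch (kernel and buffers are hypothetical): a scatter-add
// style accumulation where several threads may write to the same output slot,
// relying on the complex overloads above.
//
//   __global__ void ScatterAddKernel(int count, const int* indices,
//                                    const std::complex<float>* updates,
//                                    std::complex<float>* out) {
//     for (int i : GpuGridRangeX(count)) {
//       GpuAtomicAdd(out + indices[i], GpuLdg(updates + i));
//     }
//   }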

// GpuAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicSub(T* ptr, U value) {
  return atomicSub(ptr, value);
}

// Specializations of subtraction which add the negative value.
__device__ inline float GpuAtomicSub(float* ptr, float value) {
  return GpuAtomicAdd(ptr, -value);
}

__device__ inline double GpuAtomicSub(double* ptr, double value) {
  return GpuAtomicAdd(ptr, -value);
}

__device__ inline int64_t GpuAtomicSub(int64_t* ptr, int64_t value) {
  return GpuAtomicAdd(ptr, -value);
}

__device__ inline tensorflow::uint64 GpuAtomicSub(tensorflow::uint64* ptr,
                                                  tensorflow::uint64 value) {
  return GpuAtomicAdd(ptr, -static_cast<int64_t>(value));
}

__device__ inline Eigen::half GpuAtomicSub(Eigen::half* ptr,
                                           Eigen::half value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a - value; });
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);

// GpuAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
  return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
}

#if TENSORFLOW_USE_ROCM

/*
 * CUDA runtime headers have the following defined
 * __device__ int max(int, int)
 * __device__ float max(float, float)
 * __device__ double max(double, double)
 *
 * and many others, whereas HIP runtime headers only have the "int" version.
 *
 * Therefore we need to special-case the ROCm build to call the correct
 * underlying routines for float and double types.
 *
 */

__device__ inline float GpuAtomicMax(float* ptr, float value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](float a) { return fmaxf(a, value); });
}

__device__ inline double GpuAtomicMax(double* ptr, double value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](double a) { return fmax(a, value); });
}

#else

__device__ inline float GpuAtomicMax(float* ptr, float value) {
  return detail::GpuAtomicCasHelper(ptr,
                                    [value](float a) { return max(a, value); });
}

__device__ inline double GpuAtomicMax(double* ptr, double value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](double a) { return max(a, value); });
}

#endif

__device__ inline Eigen::half GpuAtomicMax(Eigen::half* ptr,
                                           Eigen::half value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](Eigen::half a) { return max(a, value); });
}

#if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
__device__ inline tensorflow::uint64 GpuAtomicMax(tensorflow::uint64* ptr,
                                                  tensorflow::uint64 value) {
  return detail::GpuAtomicCasHelper(
      detail::ToCudaSupportedPtr(ptr),
      [value](tensorflow::uint64 a) { return max(a, value); });
}

__device__ inline int64_t GpuAtomicMax(int64_t* ptr, int64_t value) {
  return detail::GpuAtomicCasHelper(
      detail::ToCudaSupportedPtr(ptr),
      [value](int64_t a) { return max(a, value); });
}
#endif
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);

// GpuAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
  return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
}

#if TENSORFLOW_USE_ROCM

/*
 * CUDA runtime headers have the following defined
 * __device__ int min(int, int)
 * __device__ float min(float, float)
 * __device__ double min(double, double)
 *
 * and many others, whereas HIP runtime headers only have the "int" version.
 *
 * Therefore we need to special-case the ROCm build to call the correct
 * underlying routines for float and double types.
 *
 */

__device__ inline float GpuAtomicMin(float* ptr, float value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](float a) { return fminf(a, value); });
}

__device__ inline double GpuAtomicMin(double* ptr, double value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](double a) { return fmin(a, value); });
}

#else

__device__ inline float GpuAtomicMin(float* ptr, float value) {
  return detail::GpuAtomicCasHelper(ptr,
                                    [value](float a) { return min(a, value); });
}

__device__ inline double GpuAtomicMin(double* ptr, double value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](double a) { return min(a, value); });
}

#endif

__device__ inline Eigen::half GpuAtomicMin(Eigen::half* ptr,
                                           Eigen::half value) {
  return detail::GpuAtomicCasHelper(
      ptr, [value](Eigen::half a) { return min(a, value); });
}

#if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
__device__ inline tensorflow::uint64 GpuAtomicMin(tensorflow::uint64* ptr,
                                                  tensorflow::uint64 value) {
  return detail::GpuAtomicCasHelper(
      detail::ToCudaSupportedPtr(ptr),
      [value](tensorflow::uint64 a) { return min(a, value); });
}

__device__ inline int64_t GpuAtomicMin(int64_t* ptr, int64_t value) {
  return detail::GpuAtomicCasHelper(
      detail::ToCudaSupportedPtr(ptr),
      [value](int64_t a) { return min(a, value); });
}
#endif
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMin, CudaAtomicMin);

// GpuAtomicMul
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMul(T* ptr, U value) {
  return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a * value; });
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMul, CudaAtomicMul);

// GpuAtomicDiv
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicDiv(T* ptr, U value) {
  return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a / value; });
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicDiv, CudaAtomicDiv);

// Import all specialized std::complex device operators in namespace tensorflow.
#if GOOGLE_CUDA && defined(EIGEN_USING_STD_COMPLEX_OPERATORS)
EIGEN_USING_STD_COMPLEX_OPERATORS
#endif  // GOOGLE_CUDA

namespace functor {
// Import all specialized std::complex device operators in namespace functor.
#if GOOGLE_CUDA && defined(EIGEN_USING_STD_COMPLEX_OPERATORS)
EIGEN_USING_STD_COMPLEX_OPERATORS
#endif  // GOOGLE_CUDA

// ROCm hcc(clang) has severe difficulties dealing with std::complex directly
// due to a header issue. This template assists in casting std::complex into the
// corresponding internal ROCm types.
template <class T>
struct MapComplexToHipComplex {
  typedef T TM;
};

#if TENSORFLOW_USE_ROCM
template <>
struct MapComplexToHipComplex<std::complex<float> > {
  typedef hipFloatComplex TM;
};

template <>
struct MapComplexToHipComplex<std::complex<double> > {
  typedef hipDoubleComplex TM;
};
#endif
}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_