1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_
17 #define TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_
18 
19 /**
20  * Wrappers and helpers for CUDA device code.
21  *
22  * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
23  * backwards compatibility; see go/volta-porting for details.
24  * Provides atomic operations on types that aren't natively supported.
25  * Defines a number of macros and types providing a shared interface
26  * to either CUDA or ROCm APIs, depending on the build.
27  */
28 
29 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30 
31 #include <algorithm>
32 #include <complex>
33 
34 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
35 #if GOOGLE_CUDA
36 #include "third_party/gpus/cuda/include/cuComplex.h"
37 #include "third_party/gpus/cuda/include/cuda.h"
38 #else
39 #include "rocm/include/hip/hip_complex.h"
40 #endif
41 
42 #include "tensorflow/core/platform/types.h"
43 #include "tensorflow/core/util/gpu_cuda_alias.h"
44 
45 #if GOOGLE_CUDA
46 using gpuFloatComplex = cuFloatComplex;
47 using gpuDoubleComplex = cuDoubleComplex;
48 using gpuStream_t = cudaStream_t;
49 using gpuEvent_t = cudaEvent_t;
50 #define gpuEventRecord cudaEventRecord
51 #define gpuEventSynchronize cudaEventSynchronize
52 #define gpuEventDestroy cudaEventDestroy
53 #define gpuEventCreate cudaEventCreate
54 #define gpuEventCreateWithFlags cudaEventCreateWithFlags
55 #define gpuEventDisableTiming cudaEventDisableTiming
56 #define gpuDeviceSynchronize cudaDeviceSynchronize
57 #define gpuFree cudaFree
58 #elif TENSORFLOW_USE_ROCM
59 using gpuFloatComplex = hipFloatComplex;
60 using gpuDoubleComplex = hipDoubleComplex;
61 using gpuStream_t = hipStream_t;
62 using gpuEvent_t = hipEvent_t;
63 using cudaError = int;
64 using cudaError_t = int;
65 #define cudaSuccess 0
66 #define cudaGetLastError hipGetLastError
67 #define gpuEventRecord hipEventRecord
68 #define gpuEventDestroy hipEventDestroy
69 #define gpuEventSynchronize hipEventSynchronize
70 #define gpuEventCreate hipEventCreate
71 #define gpuEventCreateWithFlags hipEventCreateWithFlags
72 #define gpuEventDisableTiming hipEventDisableTiming
73 #define gpuDeviceSynchronize hipDeviceSynchronize
74 #define gpuFree hipFree
75 static std::string cudaGetErrorString(int err) { return std::to_string(err); }
76 #endif
77 
78 #define TF_RETURN_IF_CUDA_ERROR(result)                   \
79   do {                                                    \
80     cudaError_t error(result);                            \
81     if (!SE_PREDICT_TRUE(error == cudaSuccess)) {         \
82       return errors::Internal("Cuda call failed with ",   \
83                               cudaGetErrorString(error)); \
84     }                                                     \
85   } while (0)
86 
87 #define TF_OP_REQUIRES_CUDA_SUCCESS(context, result)                   \
88   do {                                                                 \
89     cudaError_t error(result);                                         \
90     if (!SE_PREDICT_TRUE(error == cudaSuccess)) {                      \
91       context->SetStatus(errors::Internal("Cuda call failed with ",    \
92                                           cudaGetErrorString(error))); \
93       return;                                                          \
94     }                                                                  \
95   } while (0)
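
// Usage sketch (editor's illustration, not part of the original header). Both
// macros wrap a raw CUDA runtime call and turn a failure into a
// tensorflow::Status; the function below and its arguments are hypothetical.
//
//   Status ZeroDeviceBuffer(void* ptr, size_t bytes, gpuStream_t stream) {
//     TF_RETURN_IF_CUDA_ERROR(cudaMemsetAsync(ptr, 0, bytes, stream));
//     return Status::OK();
//   }
//
//   // Inside OpKernel::Compute(OpKernelContext* context), the second macro
//   // sets the context status and returns instead:
//   //   TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());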
96 
97 namespace tensorflow {
98 // According to the HIP developer guide at
99 // https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#assert
100 // assert is not supported by HIP. While we are waiting for assert support in
101 // HIP kernels, the assert call is macroed to a no-op so that it does not
102 // block us from creating a debug build.
103 #if TENSORFLOW_USE_ROCM
104 #undef assert
105 #define assert(x) \
106   {}
107 #endif
108 
109 namespace detail {
110 
111 // Helper for range-based for loop using 'delta' increments.
112 // Usage: see GpuGridRange?() functions below.
113 template <typename T>
114 class GpuGridRange {
115   struct Iterator {
116     __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
117     __device__ T operator*() const { return index_; }
118     __device__ Iterator& operator++() {
119       index_ += delta_;
120       return *this;
121     }
122     __device__ bool operator!=(const Iterator& other) const {
123       bool greater = index_ > other.index_;
124       bool less = index_ < other.index_;
125       // Anything past an end iterator (delta_ == 0) is equal.
126       // In range-based for loops, this optimizes to 'return less'.
127       if (!other.delta_) {
128         return less;
129       }
130       if (!delta_) {
131         return greater;
132       }
133       return less || greater;
134     }
135 
136    private:
137     T index_;
138     const T delta_;
139   };
140 
141  public:
142   __device__ GpuGridRange(T begin, T delta, T end)
143       : begin_(begin), delta_(delta), end_(end) {}
144 
145   __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
146   __device__ Iterator end() const { return Iterator{end_, 0}; }
147 
148  private:
149   T begin_;
150   T delta_;
151   T end_;
152 };
153 
154 #ifndef TENSORFLOW_USE_ROCM
155 template <typename... T>
156 using CudaGridRange = GpuGridRange<T...>;
157 #endif
158 }  // namespace detail
159 
160 // Helper to visit indices in the range 0 <= i < count, using the x-coordinate
161 // of the global thread index. That is, each index i is visited by all threads
162 // with the same x-coordinate.
163 // Usage: for(int i : GpuGridRangeX(count)) { visit(i); }
164 template <typename T>
165 __device__ detail::GpuGridRange<T> GpuGridRangeX(T count) {
166   return detail::GpuGridRange<T>(
167       /*begin=*/blockIdx.x * blockDim.x + threadIdx.x,
168       /*delta=*/gridDim.x * blockDim.x, /*end=*/count);
169 }
170 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeX, CudaGridRangeX);
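
// Usage sketch (editor's illustration; 'Scale' and its parameters are
// hypothetical): a grid-stride kernel that visits each index in [0, count)
// exactly once, independent of the launch dimensions chosen.
//
//   template <typename T>
//   __global__ void Scale(int count, const T* in, T* out, T factor) {
//     for (int i : GpuGridRangeX(count)) {
//       out[i] = in[i] * factor;
//     }
//   }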
171 
172 // Helper to visit indices in the range 0 <= i < count using the y-coordinate.
173 // Usage: for(int i : GpuGridRangeY(count)) { visit(i); }
174 template <typename T>
175 __device__ detail::GpuGridRange<T> GpuGridRangeY(T count) {
176   return detail::GpuGridRange<T>(
177       /*begin=*/blockIdx.y * blockDim.y + threadIdx.y,
178       /*delta=*/gridDim.y * blockDim.y, /*end=*/count);
179 }
180 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeY, CudaGridRangeY);
181 
182 // Helper to visit indices in the range 0 <= i < count using the z-coordinate.
183 // Usage: for(int i : GpuGridRangeZ(count)) { visit(i); }
184 template <typename T>
185 __device__ detail::GpuGridRange<T> GpuGridRangeZ(T count) {
186   return detail::GpuGridRange<T>(
187       /*begin=*/blockIdx.z * blockDim.z + threadIdx.z,
188       /*delta=*/gridDim.z * blockDim.z, /*end=*/count);
189 }
190 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeZ, CudaGridRangeZ);
191 
192 // Mask for all 32 threads in a warp.
193 __device__ const unsigned kCudaWarpAll = 0xffffffff;
194 // ROCM TODO add ROCM implementation
195 // Mask for all 64 threads in a wavefront.
196 __device__ const unsigned kGpuWarpAll = 0xffffffff;
197 
198 // Returns the warp lane ID of the calling thread
199 __device__ inline unsigned GpuLaneId() {
200   unsigned int lane_id;
201 #if GOOGLE_CUDA
202 #if __clang__
203   return __nvvm_read_ptx_sreg_laneid();
204 #else   // __clang__
205   asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
206 #endif  // __clang__
207 #elif TENSORFLOW_USE_ROCM
208   lane_id = __lane_id();
209 #endif
210   return lane_id;
211 }
212 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLaneId, CudaLaneId);
213 
214 namespace detail {
215 // Returns true if mask is a valid parameter for __shfl*sync to return a well
216 // defined value, assuming the calling lane will read from src_lane as part of
217 // the shuffle operation.
218 //
219 // Specifically, returns true iff mask has the calling lane bit and the src_lane
220 // bit set, and the src_lane calls this function with the same mask value
221 // (required for the two threads to wait for each other).
222 //
223 // On Volta, for some invalid masks, this function hangs or returns false
224 // positives, because the implementation shuffles with the same mask that
225 // we are validating. Run on Pascal if you suspect that the mask is incorrect.
226 __device__ inline bool GpuValidateShuffleSyncMask(unsigned mask,
227                                                   unsigned src_lane) {
228   unsigned src_dst_mask = 1u << GpuLaneId() | 1u << src_lane;
229 #if CUDA_VERSION >= 9000
230   unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
231 #else
232 #if GOOGLE_CUDA
233   unsigned src_lane_mask = __shfl(mask, src_lane);
234 #elif TENSORFLOW_USE_ROCM
235   unsigned src_lane_mask =
236       __shfl(static_cast<int>(mask), static_cast<int>(src_lane));
237 #endif
238 #endif
239   return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
240 }
241 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuValidateShuffleSyncMask,
242                                   CudaValidateShuffleSyncMask);
243 
244 // Returns the actual source lane for shuffle.
245 __device__ inline unsigned GpuShuffleGetSrcLane(int src_lane, int width) {
246   int lane_id = GpuLaneId();
247   int lane_base = lane_id & ~width + 1;
248   int lane_offset = src_lane & width - 1;
249   return lane_base + lane_offset;
250 }
251 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleGetSrcLane, CudaShuffleGetSrcLane);
252 
253 // Returns the source lane for shuffle up.
254 __device__ inline unsigned GpuShuffleUpGetSrcLane(unsigned delta, int width) {
255   unsigned lane_id = GpuLaneId();
256   if ((lane_id & width - 1) < delta) {
257     return lane_id;
258   }
259   return lane_id - delta;
260 }
261 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpGetSrcLane,
262                                   CudaShuffleUpGetSrcLane);
263 
264 // Returns the source lane for shuffle down.
265 __device__ inline unsigned GpuShuffleDownGetSrcLane(unsigned delta, int width) {
266   unsigned lane_id = GpuLaneId();
267   if ((lane_id & width - 1) + delta >= width) {
268     return lane_id;
269   }
270   return lane_id + delta;
271 }
272 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownGetSrcLane,
273                                   CudaShuffleDownGetSrcLane);
274 
275 // Returns the source lane for shuffle xor.
276 __device__ inline unsigned GpuShuffleXorGetSrcLane(int lane_mask, int width) {
277   int lane_id = GpuLaneId();
278   int src_lane = lane_id ^ lane_mask;
279   if (src_lane > (lane_id | width - 1)) {
280     return lane_id;
281   }
282   return src_lane;
283 }
284 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorGetSrcLane,
285                                   CudaShuffleXorGetSrcLane);
286 }  // namespace detail
287 
288 // For all *_sync wrappers below, it is illegal to synchronize threads from
289 // different program locations, because that is not supported before sm_70.
290 // In other words, all threads in 'mask' must call the functions in convergence.
291 // Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
292 //
293 // It is also illegal to shuffle with a mask that produces an undefined result
294 // for any of the threads. Specifically, all source threads of the shuffle
295 // must have their corresponding bit in 'mask' set.
296 
297 // Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
298 __device__ inline void GpuSyncWarp(unsigned mask = kCudaWarpAll) {
299   assert(mask & 1u << GpuLaneId());
300 #if CUDA_VERSION >= 9000
301   __syncwarp(mask);
302 #endif
303 }
304 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuSyncWarp, CudaSyncWarp);
305 
306 // Wrapper for __ballot_sync. All threads in 'mask' must call this function in
307 // convergence, see comment above for details.
308 __device__ inline unsigned GpuBallotSync(unsigned mask, int pred) {
309   assert(mask & 1u << GpuLaneId());
310 #if CUDA_VERSION >= 9000
311   return __ballot_sync(mask, pred);
312 #else
313   return __ballot(pred) & mask;  // Apply mask to match __ballot_sync's spec.
314 #endif
315 }
316 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuBallotSync, CudaBallotSync);
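
// Example sketch (editor's illustration; 'value' and 'threshold' are
// hypothetical): counting how many lanes of a fully converged warp satisfy a
// predicate by taking the population count of the ballot mask.
//
//   unsigned votes = GpuBallotSync(kCudaWarpAll, value > threshold);
//   int lanes_above = __popc(votes);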
317 
318 // Wrapper for __any_sync. All threads in 'mask' must call this function in
319 // convergence, see comment above for details.
320 __device__ inline int GpuAnySync(unsigned mask, int pred) {
321   assert(mask & 1u << GpuLaneId());
322 #if CUDA_VERSION >= 9000
323   return __any_sync(mask, pred);
324 #else
325   return __any(pred);
326 #endif
327 }
328 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAnySync, CudaAnySync);
329 
330 // Wrapper for __all_sync. All threads in 'mask' must call this function in
331 // convergence, see comment above for details.
332 __device__ inline int GpuAllSync(unsigned mask, int pred) {
333   assert(mask & 1u << GpuLaneId());
334 #if CUDA_VERSION >= 9000
335   return __all_sync(mask, pred);
336 #else
337   return __all(pred);
338 #endif
339 }
340 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAllSync, CudaAllSync);
341 
342 // Wrapper for __shfl_sync. All threads in 'mask' must call this function in
343 // convergence, see comment above for details.
344 template <typename T>
345 __device__ T GpuShuffleSync(unsigned mask, T value, int src_lane,
346                             int width = warpSize) {
347   assert(!(width & width - 1));
348   assert(detail::GpuValidateShuffleSyncMask(
349       mask, detail::GpuShuffleGetSrcLane(src_lane, width)));
350 #if CUDA_VERSION >= 9000
351   return __shfl_sync(mask, value, src_lane, width);
352 #else
353   return __shfl(value, src_lane, width);
354 #endif
355 }
356 
357 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
358 // instead of float for lo and hi (which is incorrect with ftz, for example).
359 // See b/69446944.
360 __device__ inline double GpuShuffleSync(unsigned mask, double value,
361                                         int src_lane, int width = warpSize) {
362 #if GOOGLE_CUDA
363   auto tmp = __double_as_longlong(value);
364   auto lo = static_cast<unsigned>(tmp);
365   auto hi = static_cast<unsigned>(tmp >> 32);
366   hi = GpuShuffleSync(mask, hi, src_lane, width);
367   lo = GpuShuffleSync(mask, lo, src_lane, width);
368   return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
369 #elif TENSORFLOW_USE_ROCM
370   auto tmp = static_cast<uint64_t>(value);
371   auto lo = static_cast<unsigned>(tmp);
372   auto hi = static_cast<unsigned>(tmp >> 32);
373   hi = __shfl(static_cast<int>(hi), src_lane, width);
374   lo = __shfl(static_cast<int>(lo), src_lane, width);
375   return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
376                              static_cast<uint64_t>(lo));
377 #endif
378 }
379 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleSync, CudaShuffleSync);
380 
381 // Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
382 // convergence, see comment above for details.
383 template <typename T>
384 __device__ inline T GpuShuffleUpSync(unsigned mask, T value, unsigned delta,
385                                      int width = warpSize) {
386   assert(!(width & width - 1));
387   assert(detail::GpuValidateShuffleSyncMask(
388       mask, detail::GpuShuffleUpGetSrcLane(delta, width)));
389 #if CUDA_VERSION >= 9000
390   return __shfl_up_sync(mask, value, delta, width);
391 #else
392   return __shfl_up(value, delta, width);
393 #endif
394 }
395 
396 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
397 // instead of float for lo and hi (which is incorrect with ftz, for example).
398 // See b/69446944.
399 __device__ inline double GpuShuffleUpSync(unsigned mask, double value,
400                                           unsigned delta,
401                                           int width = warpSize) {
402 #if GOOGLE_CUDA
403   auto tmp = __double_as_longlong(value);
404   auto lo = static_cast<unsigned>(tmp);
405   auto hi = static_cast<unsigned>(tmp >> 32);
406   hi = GpuShuffleUpSync(mask, hi, delta, width);
407   lo = GpuShuffleUpSync(mask, lo, delta, width);
408   return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
409 #elif TENSORFLOW_USE_ROCM
410   auto tmp = static_cast<uint64_t>(value);
411   auto lo = static_cast<unsigned>(tmp);
412   auto hi = static_cast<unsigned>(tmp >> 32);
413   hi = __shfl_up(static_cast<int>(hi), delta, width);
414   lo = __shfl_up(static_cast<int>(lo), delta, width);
415   return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
416                              static_cast<uint64_t>(lo));
417 #endif
418 }
419 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpSync, CudaShuffleUpSync);
420 
421 // Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
422 // in convergence, see comment above for details.
423 template <typename T>
424 __device__ inline T GpuShuffleDownSync(unsigned mask, T value, unsigned delta,
425                                        int width = warpSize) {
426   assert(!(width & width - 1));
427   assert(detail::GpuValidateShuffleSyncMask(
428       mask, detail::GpuShuffleDownGetSrcLane(delta, width)));
429 #if CUDA_VERSION >= 9000
430   return __shfl_down_sync(mask, value, delta, width);
431 #else
432   return __shfl_down(value, delta, width);
433 #endif
434 }
435 
436 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
437 // instead of float for lo and hi (which is incorrect with ftz, for example).
438 // See b/69446944.
439 __device__ inline double GpuShuffleDownSync(unsigned mask, double value,
440                                             unsigned delta,
441                                             int width = warpSize) {
442 #if GOOGLE_CUDA
443   auto tmp = __double_as_longlong(value);
444   auto lo = static_cast<unsigned>(tmp);
445   auto hi = static_cast<unsigned>(tmp >> 32);
446   hi = GpuShuffleDownSync(mask, hi, delta, width);
447   lo = GpuShuffleDownSync(mask, lo, delta, width);
448   return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
449 #elif TENSORFLOW_USE_ROCM
450   auto tmp = static_cast<uint64_t>(value);
451   auto lo = static_cast<unsigned>(tmp);
452   auto hi = static_cast<unsigned>(tmp >> 32);
453   hi = __shfl_down(static_cast<int>(hi), delta, width);
454   lo = __shfl_down(static_cast<int>(lo), delta, width);
455   return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
456                              static_cast<uint64_t>(lo));
457 #endif
458 }
459 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownSync, CudaShuffleDownSync);
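
// Example sketch (editor's illustration, assuming a 32-lane CUDA warp with all
// lanes converged and participating): a warp-wide sum built on
// GpuShuffleDownSync. After the loop, lane 0 holds the total.
//
//   __device__ float WarpSum(float val) {
//     for (int offset = warpSize / 2; offset > 0; offset /= 2) {
//       val += GpuShuffleDownSync(kCudaWarpAll, val, offset);
//     }
//     return val;
//   }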
460 
461 // Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
462 // convergence, see comment above for details.
463 template <typename T>
464 __device__ T GpuShuffleXorSync(unsigned mask, T value, int lane_mask,
465                                int width = warpSize) {
466   assert(!(width & width - 1));
467   assert(detail::GpuValidateShuffleSyncMask(
468       mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width)));
469 #if GOOGLE_CUDA
470 #if CUDA_VERSION >= 9000
471   return __shfl_xor_sync(mask, value, lane_mask, width);
472 #else
473   return __shfl_xor(value, lane_mask, width);
474 #endif
475 #elif TENSORFLOW_USE_ROCM
476   // ROCM TODO: check if HIP should be changed to cope with more types
477   return __shfl_xor(static_cast<int>(value), lane_mask, width);
478 #endif
479 }
480 
481 #if TENSORFLOW_USE_ROCM
482 __device__ inline Eigen::half GpuShuffleXorSync(unsigned mask,
483                                                 Eigen::half value,
484                                                 int lane_mask,
485                                                 int width = warpSize) {
486   assert(!(width & width - 1));
487   assert(detail::GpuValidateShuffleSyncMask(
488       mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width)));
489   // TODO(rocm): This doesn't preserve NaN payload and flushes denorms to zero,
490   // maybe this should be implemented differently?
491   return static_cast<Eigen::half>(
492       __shfl_xor(static_cast<float>(value), lane_mask, width));
493 }
494 #endif
495 
496 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
497 // instead of float for lo and hi (which is incorrect with ftz, for example).
498 // See b/69446944.
499 __device__ inline double GpuShuffleXorSync(unsigned mask, double value,
500                                            int lane_mask,
501                                            int width = warpSize) {
502 #if GOOGLE_CUDA
503   auto tmp = __double_as_longlong(value);
504   auto lo = static_cast<unsigned>(tmp);
505   auto hi = static_cast<unsigned>(tmp >> 32);
506   hi = GpuShuffleXorSync(mask, hi, lane_mask, width);
507   lo = GpuShuffleXorSync(mask, lo, lane_mask, width);
508   return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
509 #elif TENSORFLOW_USE_ROCM
510   auto tmp = static_cast<uint64_t>(value);
511   auto lo = static_cast<unsigned>(tmp);
512   auto hi = static_cast<unsigned>(tmp >> 32);
513   hi = __shfl_xor(static_cast<int>(hi), lane_mask, width);
514   lo = __shfl_xor(static_cast<int>(lo), lane_mask, width);
515   return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
516                              static_cast<uint64_t>(lo));
517 #endif
518 }
519 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorSync, CudaShuffleXorSync);
520 
521 // Wrapper for __ldg.
522 template <typename T>
523 __host__ __device__ T GpuLdg(const T* address) {
524 #if __CUDA_ARCH__ >= 350
525   return __ldg(address);
526 #else
527   return *address;
528 #endif
529 }
530 
531 __host__ __device__ inline bool GpuLdg(const bool* address) {
532   return GpuLdg(reinterpret_cast<const char*>(address)) != 0;
533 }
534 
535 __host__ __device__ inline std::complex<float> GpuLdg(
536     const std::complex<float>* address) {
537 #if __CUDA_ARCH__ >= 350
538   float2 mem = __ldg(reinterpret_cast<const float2*>(address));
539   return std::complex<float>(mem.x, mem.y);
540 #else
541   return *address;
542 #endif
543 }
544 
545 __host__ __device__ inline std::complex<double> GpuLdg(
546     const std::complex<double>* address) {
547 #if __CUDA_ARCH__ >= 350
548   double2 mem = __ldg(reinterpret_cast<const double2*>(address));
549   return std::complex<double>(mem.x, mem.y);
550 #else
551   return *address;
552 #endif
553 }
554 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLdg, CudaLdg);
555 
556 // Zeroes count elements starting at ptr using all threads of a 1-D grid.
557 // Note: this function does not synchronize, and therefore the memory range is
558 // not guaranteed to be zero until the next kernel launch.
559 template <typename T>
560 __global__ void SetZero(const int count, T* __restrict__ ptr) {
561   // Check that the grid is one dimensional and index doesn't overflow.
562   assert(blockDim.y == 1);
563   assert(blockDim.z == 1);
564   assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
565   for (int i : GpuGridRangeX(count)) {
566     ptr[i] = T(0);
567   }
568 }
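
// Launch sketch (editor's illustration; 'count', 'ptr', and 'stream' are
// hypothetical): a plain 1-D launch that satisfies the asserts above. In
// TensorFlow kernels the block/thread counts would normally come from the
// launch-config helpers rather than being hard-coded.
//
//   const int threads = 256;
//   const int blocks = (count + threads - 1) / threads;
//   SetZero<float><<<blocks, threads, 0, stream>>>(count, ptr);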
569 
570 // Helper to set all tensor entries to a specific value.
571 template <typename T>
572 __global__ void SetToValue(const int count, T* __restrict__ ptr, T value) {
573   // Check that the grid is one dimensional and index doesn't overflow.
574   assert(blockDim.y == 1);
575   assert(blockDim.z == 1);
576   assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
577   for (int i : GpuGridRangeX(count)) {
578     ptr[i] = value;
579   }
580 }
581 
582 namespace detail {
583 // Helper function for atomic accumulation implemented as CAS.
584 template <typename T, typename F>
585 __device__ T GpuAtomicCasHelper(T* ptr, F accumulate) {
586   T old = *ptr;
587   T assumed;
588   do {
589     assumed = old;
590     old = atomicCAS(ptr, assumed, accumulate(assumed));
591   } while (assumed != old);
592   return old;
593 }
594 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicCasHelper, CudaAtomicCasHelper);
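
// Example sketch (editor's illustration; 'ptr' and 'v' are hypothetical): any
// read-modify-write becomes atomic by passing the update as the 'accumulate'
// functor. GpuAtomicMul and GpuAtomicDiv later in this header are built
// exactly this way; a "keep the larger magnitude" update would look like:
//
//   detail::GpuAtomicCasHelper(ptr, [v](float a) {
//     return fabsf(v) > fabsf(a) ? v : a;
//   });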
595 
596 // Overload for floating point (using integer comparison to handle NaN
597 // correctly).
598 template <typename F>
599 __device__ float GpuAtomicCasHelper(float* ptr, F accumulate) {
600   return __int_as_float(
601       GpuAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
602         return __float_as_int(accumulate(__int_as_float(a)));
603       }));
604 }
605 template <typename F>
606 __device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
607 #if TENSORFLOW_USE_ROCM
608   // FIXME: remove the workaround below once the bug is fixed.
609   // HIP has a bug in the implementation of __longlong_as_double,
610   // so work around it by using reinterpret_cast<double*>.
611   uint64_t result =
612       GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
613                          [accumulate](tensorflow::uint64 a) {
614                            return __double_as_longlong(
615                                accumulate(*(reinterpret_cast<double*>(&a))));
616                          });
617   return *(reinterpret_cast<double*>(&result));
618 #else
619   return __longlong_as_double(GpuAtomicCasHelper(
620       reinterpret_cast<unsigned long long*>(ptr),
621       [accumulate](tensorflow::uint64 a) {
622         return __double_as_longlong(accumulate(__longlong_as_double(a)));
623       }));
624 #endif
625 }
626 
627 // Overload of above function for half. Note that we don't have
628 // atomicCAS() for anything less than 32 bits, so we need to include the
629 // other 16 bits in the operation.
630 //
631 // This version is going to be very slow
632 // under high concurrency, since most threads will be spinning on failing
633 // their compare-and-swap tests. (The fact that we get false sharing on the
634 // neighboring fp16 makes this even worse.) If you are doing a large reduction,
635 // you are much better off with doing the intermediate steps in fp32 and then
636 // switching to fp16 as late as you can in the calculations.
637 //
638 // Note: Assumes little endian.
639 template <typename F>
640 __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
641 #if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
642   static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
643 #endif
644   namespace half_impl = Eigen::half_impl;
645   intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
646   assert(!(intptr & 0x1));  // should be 2-aligned.
647   if (intptr & 0x2) {
648     // The half is in the second part of the uint32 (upper 16 bits).
649     uint32* address = reinterpret_cast<uint32*>(intptr - 2);
650     uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
651       unsigned short high = static_cast<unsigned short>(arg >> 16);
652       Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
653       return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
654     });
655     return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
656   } else {
657     // The half is in the first part of the uint32 (lower 16 bits).
658     uint32* address = reinterpret_cast<uint32*>(intptr);
659     uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
660       unsigned short low = static_cast<unsigned short>(arg & 0xffff);
661       Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
662       return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
663     });
664     return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
665   }
666 }
667 
668 template <typename F>
669 __device__ long long GpuAtomicCasHelper(long long* ptr, F accumulate) {
670   return static_cast<long long>(
671       GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
672                          [accumulate](unsigned long long a) {
673                            return static_cast<unsigned long long>(
674                                accumulate(static_cast<long long>(a)));
675                          }));
676 }
677 
678 template <typename From, typename To>
679 using ToTypeIfConvertible =
680     typename std::enable_if<std::is_convertible<From, To>::value, To>::type;
681 
682 template <typename T>
683 struct CudaSupportedTypeImpl {
684   using type = T;
685 };
686 
687 template <>
688 struct CudaSupportedTypeImpl<long long> {
689   using type = unsigned long long;
690 };
691 
692 template <>
693 struct CudaSupportedTypeImpl<unsigned long> {
694   using type =
695       typename std::conditional<sizeof(unsigned long) == sizeof(unsigned int),
696                                 unsigned int, unsigned long long>::type;
697 };
698 
699 template <>
700 struct CudaSupportedTypeImpl<long> {
701   // This cast should be safe since modular addition works fine. However,
702   // signed overflow is not handled correctly since it's undefined behavior.
703   using type = typename CudaSupportedTypeImpl<unsigned long>::type;
704 };
705 
706 template <typename T>
707 using CudaSupportedType = typename CudaSupportedTypeImpl<T>::type;
708 
709 template <typename T>
710 __device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {
711   return reinterpret_cast<CudaSupportedType<T>*>(ptr);
712 }
713 
714 }  // namespace detail
715 
716 // CUDA provides atomic ops, but not for all types.  We provide wrappers
717 // for some ops and provide implementation for all reasonable types.
718 
719 template <typename T, typename U>
720 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
721   return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
722 }
723 
724 __device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
725                                            Eigen::half value) {
726   return detail::GpuAtomicCasHelper(
727       ptr, [value](Eigen::half a) { return a + value; });
728 }
729 
730 #if (__CUDA_ARCH__ < 600) || TENSORFLOW_USE_ROCM
731 __device__ inline double GpuAtomicAdd(double* ptr, double value) {
732   return detail::GpuAtomicCasHelper(ptr,
733                                     [value](double a) { return a + value; });
734 }
735 #endif
736 
737 // GpuAtomicAdd
738 // Specializations of GpuAtomicAdd for complex types, which GpuAtomicAdd does
739 // not support. We treat a std::complex<T>* as a T* (the C++ standard section
740 // 26.4.4 allows this explicitly) and atomic add the real and imaginary
741 // components individually. The operation as a whole is not atomic, but we can
742 // safely treat the components independently for the purpose of accumulating.
743 
744 // ROCM TODO support GpuAtomicAdd for std::complex<>
745 #if GOOGLE_CUDA
746 __device__ inline std::complex<float> GpuAtomicAdd(std::complex<float>* ptr,
747                                                    std::complex<float> value) {
748   auto ptr_scalar = reinterpret_cast<float*>(ptr);
749   return std::complex<float>(GpuAtomicAdd(ptr_scalar, value.real()),
750                              GpuAtomicAdd(ptr_scalar + 1, value.imag()));
751 }
752 
753 __device__ inline std::complex<double> GpuAtomicAdd(
754     std::complex<double>* ptr, std::complex<double> value) {
755   auto ptr_scalar = reinterpret_cast<double*>(ptr);
756   return std::complex<double>(GpuAtomicAdd(ptr_scalar, value.real()),
757                               GpuAtomicAdd(ptr_scalar + 1, value.imag()));
758 }
759 #endif
760 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicAdd, CudaAtomicAdd);
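
// Usage sketch (editor's illustration; 'Histogram' and its parameters are
// hypothetical): combining GpuGridRangeX with GpuAtomicAdd for a scatter-style
// accumulation where several threads may update the same output element.
//
//   __global__ void Histogram(int count, const int* bin_of, float* bins) {
//     for (int i : GpuGridRangeX(count)) {
//       GpuAtomicAdd(bins + bin_of[i], 1.0f);
//     }
//   }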
761 
762 // GpuAtomicSub
763 template <typename T, typename U>
764 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicSub(T* ptr, U value) {
765   return atomicSub(ptr, value);
766 }
767 
768 // Specializations of subtraction which add the negative value.
769 __device__ inline float GpuAtomicSub(float* ptr, float value) {
770   return GpuAtomicAdd(ptr, -value);
771 }
772 
773 __device__ inline double GpuAtomicSub(double* ptr, double value) {
774   return GpuAtomicAdd(ptr, -value);
775 }
776 
777 __device__ inline tensorflow::int64 GpuAtomicSub(tensorflow::int64* ptr,
778                                                  tensorflow::int64 value) {
779   return GpuAtomicAdd(ptr, -value);
780 }
781 
782 __device__ inline tensorflow::uint64 GpuAtomicSub(tensorflow::uint64* ptr,
783                                                   tensorflow::uint64 value) {
784   return GpuAtomicAdd(ptr, -static_cast<tensorflow::int64>(value));
785 }
786 
787 __device__ inline Eigen::half GpuAtomicSub(Eigen::half* ptr,
788                                            Eigen::half value) {
789   return detail::GpuAtomicCasHelper(
790       ptr, [value](Eigen::half a) { return a - value; });
791 }
792 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);
793 
794 // GpuAtomicMax
795 template <typename T, typename U>
796 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
797   return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
798 }
799 
800 #if TENSORFLOW_USE_ROCM
801 
802 /*
803  * CUDA runtime headers have the following defined
804  *   __device__  int max(int, int)
805  *   __device__  float max(float, float)
806  *   __device__  double max(double, double)
807  *
808  * and many others, whereas HIP runtime headers only have the "int" version.
809  *
810  * Therefore we need to special-case the ROCm build to call the correct
811  * underlying routines for the float and double types.
812  *
813  */
814 
815 __device__ inline float GpuAtomicMax(float* ptr, float value) {
816   return detail::GpuAtomicCasHelper(
817       ptr, [value](float a) { return fmaxf(a, value); });
818 }
819 
820 __device__ inline double GpuAtomicMax(double* ptr, double value) {
821   return detail::GpuAtomicCasHelper(
822       ptr, [value](double a) { return fmax(a, value); });
823 }
824 
825 #else
826 
827 __device__ inline float GpuAtomicMax(float* ptr, float value) {
828   return detail::GpuAtomicCasHelper(ptr,
829                                     [value](float a) { return max(a, value); });
830 }
831 
832 __device__ inline double GpuAtomicMax(double* ptr, double value) {
833   return detail::GpuAtomicCasHelper(
834       ptr, [value](double a) { return max(a, value); });
835 }
836 
837 #endif
838 
839 __device__ inline Eigen::half GpuAtomicMax(Eigen::half* ptr,
840                                            Eigen::half value) {
841   return detail::GpuAtomicCasHelper(
842       ptr, [value](Eigen::half a) { return max(a, value); });
843 }
844 
845 #if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
846 __device__ inline tensorflow::uint64 GpuAtomicMax(tensorflow::uint64* ptr,
847                                                   tensorflow::uint64 value) {
848   return detail::GpuAtomicCasHelper(
849       detail::ToCudaSupportedPtr(ptr),
850       [value](tensorflow::uint64 a) { return max(a, value); });
851 }
852 
853 __device__ inline int64 GpuAtomicMax(int64* ptr, int64 value) {
854   return detail::GpuAtomicCasHelper(detail::ToCudaSupportedPtr(ptr),
855                                     [value](int64 a) { return max(a, value); });
856 }
857 #endif
858 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);
859 
860 // GpuAtomicMin
861 template <typename T, typename U>
862 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
863   return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
864 }
865 
866 #if TENSORFLOW_USE_ROCM
867 
868 /*
869  * CUDA runtime headers have the following defined
870  *   __device__  int min(int, int)
871  *   __device__  float min(float, float)
872  *   __device__  double min(double, double)
873  *
874  * and many others, whereas HIP runtime headers only have the "int" version.
875  *
876  * Therefore we need to special-case the ROCm build to call the correct
877  * underlying routines for the float and double types.
878  *
879  */
880 
881 __device__ inline float GpuAtomicMin(float* ptr, float value) {
882   return detail::GpuAtomicCasHelper(
883       ptr, [value](float a) { return fminf(a, value); });
884 }
885 
886 __device__ inline double GpuAtomicMin(double* ptr, double value) {
887   return detail::GpuAtomicCasHelper(
888       ptr, [value](double a) { return fmin(a, value); });
889 }
890 
891 #else
892 
893 __device__ inline float GpuAtomicMin(float* ptr, float value) {
894   return detail::GpuAtomicCasHelper(ptr,
895                                     [value](float a) { return min(a, value); });
896 }
897 
898 __device__ inline double GpuAtomicMin(double* ptr, double value) {
899   return detail::GpuAtomicCasHelper(
900       ptr, [value](double a) { return min(a, value); });
901 }
902 
903 #endif
904 
905 __device__ inline Eigen::half GpuAtomicMin(Eigen::half* ptr,
906                                            Eigen::half value) {
907   return detail::GpuAtomicCasHelper(
908       ptr, [value](Eigen::half a) { return min(a, value); });
909 }
910 
911 #if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
912 __device__ inline tensorflow::uint64 GpuAtomicMin(tensorflow::uint64* ptr,
913                                                   tensorflow::uint64 value) {
914   return detail::GpuAtomicCasHelper(
915       detail::ToCudaSupportedPtr(ptr),
916       [value](tensorflow::uint64 a) { return min(a, value); });
917 }
918 
919 __device__ inline int64 GpuAtomicMin(int64* ptr, int64 value) {
920   return detail::GpuAtomicCasHelper(detail::ToCudaSupportedPtr(ptr),
921                                     [value](int64 a) { return min(a, value); });
922 }
923 #endif
924 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMin, CudaAtomicMin);
925 
926 // GpuAtomicMul
927 template <typename T, typename U>
928 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMul(T* ptr, U value) {
929   return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a * value; });
930 }
931 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMul, CudaAtomicMul);
932 
933 // GpuAtomicDiv
934 template <typename T, typename U>
935 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicDiv(T* ptr, U value) {
936   return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a / value; });
937 }
938 CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicDiv, CudaAtomicDiv);
939 
940 // Operator overloads for complex numbers.
941 #if GOOGLE_CUDA
942 __device__ inline std::complex<float> operator+(const std::complex<float>& a,
943                                                 const std::complex<float>& b) {
944   auto result = cuCaddf(make_cuComplex(a.real(), a.imag()),
945                         make_cuComplex(b.real(), b.imag()));
946   return std::complex<float>(result.x, result.y);
947 }
948 
949 __device__ inline std::complex<float> operator-(const std::complex<float>& a,
950                                                 const std::complex<float>& b) {
951   auto result = cuCsubf(make_cuComplex(a.real(), a.imag()),
952                         make_cuComplex(b.real(), b.imag()));
953   return std::complex<float>(result.x, result.y);
954 }
955 
956 __device__ inline std::complex<float> operator*(const std::complex<float>& a,
957                                                 const std::complex<float>& b) {
958   auto result = cuCmulf(make_cuComplex(a.real(), a.imag()),
959                         make_cuComplex(b.real(), b.imag()));
960   return std::complex<float>(result.x, result.y);
961 }
962 
963 __device__ inline std::complex<float> operator/(const std::complex<float>& a,
964                                                 const std::complex<float>& b) {
965   auto result = cuCdivf(make_cuComplex(a.real(), a.imag()),
966                         make_cuComplex(b.real(), b.imag()));
967   return std::complex<float>(result.x, result.y);
968 }
969 
970 __device__ inline std::complex<double> operator+(
971     const std::complex<double>& a, const std::complex<double>& b) {
972   auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()),
973                        make_cuDoubleComplex(b.real(), b.imag()));
974   return std::complex<double>(result.x, result.y);
975 }
976 
977 __device__ inline std::complex<double> operator-(
978     const std::complex<double>& a, const std::complex<double>& b) {
979   auto result = cuCsub(make_cuDoubleComplex(a.real(), a.imag()),
980                        make_cuDoubleComplex(b.real(), b.imag()));
981   return std::complex<double>(result.x, result.y);
982 }
983 
984 __device__ inline std::complex<double> operator*(
985     const std::complex<double>& a, const std::complex<double>& b) {
986   auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()),
987                        make_cuDoubleComplex(b.real(), b.imag()));
988   return std::complex<double>(result.x, result.y);
989 }
990 
991 __device__ inline std::complex<double> operator/(
992     const std::complex<double>& a, const std::complex<double>& b) {
993   auto result = cuCdiv(make_cuDoubleComplex(a.real(), a.imag()),
994                        make_cuDoubleComplex(b.real(), b.imag()));
995   return std::complex<double>(result.x, result.y);
996 }
997 #endif  // GOOGLE_CUDA
998 
999 namespace functor {
1000 // ROCm hcc (clang) has severe difficulties dealing with std::complex directly
1001 // due to a header issue. This template assists in casting std::complex into
1002 // the corresponding internal ROCm types.
1003 template <class T>
1004 struct MapComplexToHipComplex {
1005   typedef T TM;
1006 };
1007 
1008 #if TENSORFLOW_USE_ROCM
1009 template <>
1010 struct MapComplexToHipComplex<std::complex<float> > {
1011   typedef hipFloatComplex TM;
1012 };
1013 
1014 template <>
1015 struct MapComplexToHipComplex<std::complex<double> > {
1016   typedef hipDoubleComplex TM;
1017 };
1018 #endif
1019 }  // namespace functor
1020 
1021 }  // namespace tensorflow
1022 
1023 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1024 #endif  // TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_
1025