• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
17 #define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
18 
19 /**
20  * Wrappers and helpers for CUDA device code.
21  *
22  * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
23  * backwards compatibility, see go/volta-porting for details.
24  * Provides atomic operations on types that aren't natively supported.
25  */
26 
27 #if GOOGLE_CUDA
28 
29 #include <algorithm>
30 #include <complex>
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "cuda/include/cuda.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 
37 namespace detail {
38 
// Helper for range-based for loop using 'delta' increments.
// Usage: see CudaGridRange?() functions below.
//
// begin() yields begin_, begin_ + delta_, begin_ + 2 * delta_, ...; end() is a
// sentinel iterator whose delta is zero. An iterator whose index is at or past
// the sentinel's index compares equal to it, so iteration stops at the first
// index >= end_ even when a stride steps over end_ without hitting it exactly.
template <typename T>
class CudaGridRange {
  struct Iterator {
    __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
    __device__ T operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += delta_;
      return *this;
    }
    __device__ bool operator!=(const Iterator& other) const {
      const bool is_before = index_ < other.index_;
      const bool is_after = other.index_ < index_;
      // Anything past an end iterator (delta_ == 0) is equal.
      // In range-based for loops, this optimizes to 'return is_before'.
      if (other.delta_ == 0) return is_before;
      if (delta_ == 0) return is_after;
      return is_before || is_after;
    }

   private:
    T index_;
    const T delta_;
  };

 public:
  __device__ CudaGridRange(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  // First iterator: starts at begin_ and advances by delta_.
  __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
  // Sentinel iterator: index end_, delta 0.
  __device__ Iterator end() const { return Iterator{end_, 0}; }

 private:
  T begin_;
  T delta_;
  T end_;
};
81 
82 }  // namespace detail
83 
// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
//
// The start index and the stride are computed in type T rather than in the
// 32-bit unsigned type of blockIdx/blockDim/gridDim. For a 64-bit T this
// prevents silent wraparound when gridDim.x * blockDim.x exceeds 2^32 - 1.
// For 32-bit T the arithmetic is the same as a plain unsigned computation.
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
  return detail::CudaGridRange<T>(
      /*begin=*/static_cast<T>(blockIdx.x) * blockDim.x + threadIdx.x,
      /*delta=*/static_cast<T>(gridDim.x) * blockDim.x,
      /*end=*/count);
}
93 
// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
//
// The start index and the stride are computed in type T (not in 32-bit
// unsigned arithmetic), so a 64-bit T can never observe a wrapped product.
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
  return detail::CudaGridRange<T>(
      /*begin=*/static_cast<T>(blockIdx.y) * blockDim.y + threadIdx.y,
      /*delta=*/static_cast<T>(gridDim.y) * blockDim.y,
      /*end=*/count);
}
101 
// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
//
// The start index and the stride are computed in type T (not in 32-bit
// unsigned arithmetic), so a 64-bit T can never observe a wrapped product.
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
  return detail::CudaGridRange<T>(
      /*begin=*/static_cast<T>(blockIdx.z) * blockDim.z + threadIdx.z,
      /*delta=*/static_cast<T>(gridDim.z) * blockDim.z,
      /*end=*/count);
}
109 
// Lane mask with all 32 bits set, i.e. selecting every thread of a warp.
constexpr unsigned kCudaWarpAll = 0xffffffffu;
112 
// Returns the calling thread's lane index within its warp, read from the PTX
// special register %laneid.
__device__ inline unsigned CudaLaneId() {
  unsigned int lane;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane));
  return lane;
}
119 
120 namespace detail {
// Returns true if mask is a valid parameter for __shfl*sync to return a well
// defined value, assuming the calling lane will read from src_lane as part of
// the shuffle operation.
//
// Specifically, returns true iff mask has both the calling lane's bit and the
// src_lane bit set, and src_lane calls this function with the same mask value
// (required for the two threads to wait for each other).
//
// On Volta, for some invalid masks, this function hangs or returns false
// positives, because the implementation shuffles with the same mask that
// we are validating. Run on Pascal if you suspect that the mask is incorrect.
__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
                                                   unsigned src_lane) {
  const unsigned required_bits = (1u << CudaLaneId()) | (1u << src_lane);
  // Fetch the mask value that the source lane passed in.
#if CUDA_VERSION >= 9000
  const unsigned mask_seen_by_src = __shfl_sync(mask, mask, src_lane);
#else
  const unsigned mask_seen_by_src = __shfl(mask, src_lane);
#endif
  const bool has_required_bits = (required_bits & mask) == required_bits;
  return has_required_bits && mask_seen_by_src == mask;
}
142 
// Returns the lane that a width-segmented shuffle reading from 'src_lane'
// actually reads on this thread. The result is the start of this lane's
// width-sized segment plus src_lane modulo width. Assumes width is a power
// of two.
__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
  const int lane = CudaLaneId();
  // ~(width - 1) == ~width + 1 == -width in two's complement.
  const int segment_start = lane & ~(width - 1);
  const int offset_in_segment = src_lane & (width - 1);
  return segment_start + offset_in_segment;
}
150 
// Returns the source lane for shuffle up. Lanes whose position inside their
// width-sized segment is below 'delta' read their own value; all other lanes
// read from 'delta' lanes lower.
__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
  const unsigned lane = CudaLaneId();
  const unsigned position_in_segment = lane & (width - 1);
  return position_in_segment < delta ? lane : lane - delta;
}
159 
// Returns the source lane for shuffle down. Lanes for which moving up by
// 'delta' would leave their width-sized segment read their own value; all
// other lanes read from 'delta' lanes higher.
__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
                                                     int width) {
  const unsigned lane = CudaLaneId();
  const unsigned position_in_segment = lane & (width - 1);
  const bool leaves_segment = position_in_segment + delta >= width;
  return leaves_segment ? lane : lane + delta;
}
169 
// Returns the source lane for shuffle xor. The partner is lane ^ lane_mask.
// If the partner lies beyond the end of this lane's width-sized segment, the
// lane reads its own value instead.
__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
  const int lane = CudaLaneId();
  const int partner = lane ^ lane_mask;
  const int segment_last_lane = lane | (width - 1);
  return partner > segment_last_lane ? lane : partner;
}
179 }  // namespace detail
180 
181 // For all *_sync wrappers below, it is illegal to synchronize threads from
182 // different program locations, because that is not supported before sm_70.
183 // In other words, all threads in 'mask' must call the functions in convergence.
184 // Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
185 //
186 // It is also illegal to shuffle with a mask that produces an undefined result
187 // for any of the threads. Specifically, all source threads of the shuffle
188 // must have their corresponding bit in 'mask' set.
189 
// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
// The calling lane must be part of 'mask' (checked by assert).
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
  assert((mask >> CudaLaneId()) & 1u);
#if CUDA_VERSION >= 9000
  __syncwarp(mask);
#endif
}
197 
// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// Returns a word whose bit i is set iff lane i is in 'mask' and passed a
// non-zero 'pred'.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
  assert((mask >> CudaLaneId()) & 1u);
#if CUDA_VERSION >= 9000
  return __ballot_sync(mask, pred);
#else
  // Legacy __ballot votes over every active lane; apply mask to match
  // __ballot_sync's spec.
  const unsigned votes = __ballot(pred);
  return votes & mask;
#endif
}
208 
// Wrapper for __any_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// Returns non-zero iff 'pred' is non-zero on at least one lane in 'mask'.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __any_sync(mask, pred);
#else
  // Legacy __any votes over every active lane, including lanes outside
  // 'mask'. Restrict the vote to 'mask' so the result matches __any_sync.
  return (__ballot(pred) & mask) != 0;
#endif
}
219 
// Wrapper for __all_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// Returns non-zero iff 'pred' is non-zero on every lane in 'mask'.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __all_sync(mask, pred);
#else
  // Legacy __all votes over every active lane, so a lane outside 'mask' with a
  // false predicate would make it fail. Check only the lanes in 'mask' (all of
  // which are active by this function's contract).
  return (__ballot(pred) & mask) == mask;
#endif
}
230 
// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// 'width' must be a power of two; asserts also check that 'mask' is valid for
// the lane this thread actually reads from.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
                             int width = warpSize) {
  assert((width & (width - 1)) == 0);
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

// double overload: splits the value bit-for-bit into two 32-bit words, shuffles
// each word, and reassembles them.
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleSync(unsigned mask, double value,
                                         int src_lane, int width = warpSize) {
  unsigned low_word, high_word;
  asm volatile("mov.b64 {%0,%1}, %2;"
               : "=r"(low_word), "=r"(high_word)
               : "d"(value));
  high_word = CudaShuffleSync(mask, high_word, src_lane, width);
  low_word = CudaShuffleSync(mask, low_word, src_lane, width);
  double shuffled;
  asm volatile("mov.b64 %0, {%1,%2};"
               : "=d"(shuffled)
               : "r"(low_word), "r"(high_word));
  return shuffled;
}
258 
// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// 'width' must be a power of two; asserts also check that 'mask' is valid for
// the lane this thread actually reads from.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
                                      int width = warpSize) {
  assert((width & (width - 1)) == 0);
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_up_sync(mask, value, delta, width);
#else
  return __shfl_up(value, delta, width);
#endif
}

// double overload: splits the value bit-for-bit into two 32-bit words, shuffles
// each word, and reassembles them.
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
                                           unsigned delta,
                                           int width = warpSize) {
  unsigned low_word, high_word;
  asm volatile("mov.b64 {%0,%1}, %2;"
               : "=r"(low_word), "=r"(high_word)
               : "d"(value));
  high_word = CudaShuffleUpSync(mask, high_word, delta, width);
  low_word = CudaShuffleUpSync(mask, low_word, delta, width);
  double shuffled;
  asm volatile("mov.b64 %0, {%1,%2};"
               : "=d"(shuffled)
               : "r"(low_word), "r"(high_word));
  return shuffled;
}
287 
// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
// 'width' must be a power of two; asserts also check that 'mask' is valid for
// the lane this thread actually reads from.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
                                        int width = warpSize) {
  assert((width & (width - 1)) == 0);
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_down_sync(mask, value, delta, width);
#else
  return __shfl_down(value, delta, width);
#endif
}

// double overload: splits the value bit-for-bit into two 32-bit words, shuffles
// each word, and reassembles them.
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
                                             unsigned delta,
                                             int width = warpSize) {
  unsigned low_word, high_word;
  asm volatile("mov.b64 {%0,%1}, %2;"
               : "=r"(low_word), "=r"(high_word)
               : "d"(value));
  high_word = CudaShuffleDownSync(mask, high_word, delta, width);
  low_word = CudaShuffleDownSync(mask, low_word, delta, width);
  double shuffled;
  asm volatile("mov.b64 %0, {%1,%2};"
               : "=d"(shuffled)
               : "r"(low_word), "r"(high_word));
  return shuffled;
}
316 
// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
// 'width' must be a power of two; asserts also check that 'mask' is valid for
// the lane this thread actually reads from.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
                                int width = warpSize) {
  assert((width & (width - 1)) == 0);
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, lane_mask, width);
#else
  return __shfl_xor(value, lane_mask, width);
#endif
}

// double overload: splits the value bit-for-bit into two 32-bit words, shuffles
// each word, and reassembles them.
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
                                            int lane_mask,
                                            int width = warpSize) {
  unsigned low_word, high_word;
  asm volatile("mov.b64 {%0,%1}, %2;"
               : "=r"(low_word), "=r"(high_word)
               : "d"(value));
  high_word = CudaShuffleXorSync(mask, high_word, lane_mask, width);
  low_word = CudaShuffleXorSync(mask, low_word, lane_mask, width);
  double shuffled;
  asm volatile("mov.b64 %0, {%1,%2};"
               : "=d"(shuffled)
               : "r"(low_word), "r"(high_word));
  return shuffled;
}
345 
// Wrapper for __ldg. In device code compiled for sm_35 or newer, the load goes
// through __ldg; otherwise (older architectures and host code) it is a plain
// dereference.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return address[0];
#endif
}
355 
// bool overload: loads the underlying byte as a char and maps any non-zero
// byte to true.
__host__ __device__ inline bool CudaLdg(const bool* address) {
  const char raw_byte = CudaLdg(reinterpret_cast<const char*>(address));
  return raw_byte != 0;
}
359 
// std::complex<float> overload. On sm_35+ device code, both parts are loaded
// with one float2 __ldg; otherwise *address is copied directly.
__host__ __device__ inline std::complex<float> CudaLdg(
    const std::complex<float>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  const float2 parts = __ldg(reinterpret_cast<const float2*>(address));
  return std::complex<float>(parts.x, parts.y);
#else
  return address[0];
#endif
}
369 
// std::complex<double> overload. On sm_35+ device code, both parts are loaded
// with one double2 __ldg; otherwise *address is copied directly.
__host__ __device__ inline std::complex<double> CudaLdg(
    const std::complex<double>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  const double2 parts = __ldg(reinterpret_cast<const double2*>(address));
  return std::complex<double>(parts.x, parts.y);
#else
  return address[0];
#endif
}
379 
// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
  // The block must be one dimensional, and blockDim.x * gridDim.x must not
  // wrap around in 32-bit unsigned arithmetic.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  const T zero = T(0);
  for (int index : CudaGridRangeX(count)) {
    ptr[index] = zero;
  }
}
392 
// Helper to set all tensor entries to a specific value: writes 'value' to
// ptr[0 .. count) using all threads of a 1-D grid. Like SetZero, it does not
// synchronize.
template <typename T>
__global__ void SetToValue(const int count, T* ptr, T value) {
  // The block must be one dimensional, and blockDim.x * gridDim.x must not
  // wrap around in 32-bit unsigned arithmetic.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int index : CudaGridRangeX(count)) {
    ptr[index] = value;
  }
}
403 
404 namespace detail {
// Helper function for atomic accumulation implemented as CAS. It repeatedly
// tries to replace *ptr with accumulate(current) until the compare-and-swap
// finds *ptr unchanged since it was last read. Returns the value *ptr held
// immediately before the successful swap.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
  T observed = *ptr;
  for (;;) {
    const T expected = observed;
    observed = atomicCAS(ptr, expected, accumulate(expected));
    if (observed == expected) {
      return observed;
    }
  }
}
416 
// Overload for floating point (using integer comparison to handle NaN
// correctly). The CAS loop compares bit patterns, so a NaN stored at *ptr
// still matches itself and the loop terminates.
// Returns the float value *ptr held immediately before the successful update.
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
  const int32 old_bits = CudaAtomicCasHelper(
      reinterpret_cast<int32*>(ptr), [accumulate](int32 bits) {
        return __float_as_int(accumulate(__int_as_float(bits)));
      });
  // The CAS loop returns the previous *bit pattern*; reinterpret it as a float.
  // (__float_as_int here would numerically convert the int32 to float and
  // return a meaningless value.)
  return __int_as_float(old_bits);
}
// Overload for double: runs the CAS loop on the 64-bit bit pattern (so NaN is
// handled by integer comparison) and reinterprets the returned previous bits
// as a double.
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
  auto accumulate_bits = [accumulate](tensorflow::uint64 bits) {
    return __double_as_longlong(accumulate(__longlong_as_double(bits)));
  };
  return __longlong_as_double(CudaAtomicCasHelper(
      reinterpret_cast<tensorflow::uint64*>(ptr), accumulate_bits));
}
434 
// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// This version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
  namespace half_impl = Eigen::half_impl;
  const intptr_t half_addr = reinterpret_cast<intptr_t>(ptr);
  assert(!(half_addr & 0x1));  // should be 2-aligned.
  if (!(half_addr & 0x2)) {
    // The half occupies the low 16 bits of a 4-byte-aligned 32-bit word.
    uint32* word = reinterpret_cast<uint32*>(half_addr);
    const uint32 old_word =
        CudaAtomicCasHelper(word, [accumulate](uint32 bits) {
          const unsigned short low_half =
              static_cast<unsigned short>(bits & 0xffff);
          const Eigen::half updated =
              accumulate(half_impl::raw_uint16_to_half(low_half));
          // Keep the neighboring (high) half untouched.
          return (bits & 0xffff0000) | static_cast<uint32>(updated.x);
        });
    return half_impl::raw_uint16_to_half(
        static_cast<uint16>(old_word & 0xffff));
  }
  // The half occupies the high 16 bits of the 32-bit word that starts two
  // bytes earlier.
  uint32* word = reinterpret_cast<uint32*>(half_addr - 2);
  const uint32 old_word = CudaAtomicCasHelper(word, [accumulate](uint32 bits) {
    const unsigned short high_half = static_cast<unsigned short>(bits >> 16);
    const Eigen::half updated =
        accumulate(half_impl::raw_uint16_to_half(high_half));
    // Keep the neighboring (low) half untouched.
    return (static_cast<uint32>(updated.x) << 16) | (bits & 0xffff);
  });
  return half_impl::raw_uint16_to_half(static_cast<uint16>(old_word >> 16));
}
475 
// Names 'To' when 'From' is implicitly convertible to 'To', and is ill-formed
// otherwise. Used as a SFINAE return type so the generic CudaAtomic* templates
// only participate in overload resolution for convertible argument types.
template <typename From, typename To>
using ToTypeIfConvertible = typename std::enable_if<
    std::is_convertible<From, To>::value, To>::type;
479 
480 }  // namespace detail
481 
482 // CUDA provides atomic ops, but not for all types.  We provide wrappers
483 // for some ops and provide implementation for all reasonable types.
484 
// Generic CudaAtomicAdd: forwards to CUDA's atomicAdd. Enabled only when
// 'value' is implicitly convertible to the pointee type T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
  return atomicAdd(ptr, value);
}
489 
// Eigen::half overload: performs the addition inside the CAS loop of
// detail::CudaAtomicCasHelper.
__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
                                            Eigen::half value) {
  auto add_value = [value](Eigen::half current) { return current + value; };
  return detail::CudaAtomicCasHelper(ptr, add_value);
}
495 
496 
#if __CUDA_ARCH__ < 600
// Below sm_60 (and in the host pass, where __CUDA_ARCH__ is undefined), add to
// a double via the CAS loop of detail::CudaAtomicCasHelper.
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  auto add_value = [value](double current) { return current + value; };
  return detail::CudaAtomicCasHelper(ptr, add_value);
}
#elif __clang__
// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// see https://reviews.llvm.org/D39638
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  double previous;
  asm volatile("atom.add.f64 %0, [%1], %2;"
               : "=d"(previous)
               : "l"(ptr), "d"(value)
               : "memory");
  return previous;
}
#endif
// Otherwise (nvcc targeting sm_60+) no overload is declared here and the
// generic CudaAtomicAdd template forwards to atomicAdd(double*, double).
// CudaAtomicAdd for complex types, which atomicAdd does not support natively.
// We treat a std::complex<T>* as a T* (the C++ standard explicitly allows
// this, [complex.numbers]/4) and atomically add the real and imaginary
// components individually. The operation as a whole is not atomic, but we can
// safely treat the components independently for the purpose of accumulating.
// Returns the previous real and imaginary components.
__device__ inline std::complex<float> CudaAtomicAdd(std::complex<float>* ptr,
                                                    std::complex<float> value) {
  float* const components = reinterpret_cast<float*>(ptr);
  const float old_real = CudaAtomicAdd(components, value.real());
  const float old_imag = CudaAtomicAdd(components + 1, value.imag());
  return std::complex<float>(old_real, old_imag);
}
526 
// std::complex<double> overload: atomically adds the real and imaginary parts
// as two independent doubles (the pair as a whole is not atomic). Returns the
// previous components.
__device__ inline std::complex<double> CudaAtomicAdd(
    std::complex<double>* ptr, std::complex<double> value) {
  double* const components = reinterpret_cast<double*>(ptr);
  const double old_real = CudaAtomicAdd(components, value.real());
  const double old_imag = CudaAtomicAdd(components + 1, value.imag());
  return std::complex<double>(old_real, old_imag);
}
533 
// CudaAtomicSub
// Generic version: forwards to CUDA's atomicSub. Enabled only when 'value' is
// implicitly convertible to the pointee type T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
  return atomicSub(ptr, value);
}

// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
  const float negated = -value;
  return CudaAtomicAdd(ptr, negated);
}

__device__ inline double CudaAtomicSub(double* ptr, double value) {
  const double negated = -value;
  return CudaAtomicAdd(ptr, negated);
}

__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  // Unsigned negation wraps modulo 2^64, so adding it subtracts 'value'.
  const tensorflow::uint64 negated = -value;
  return CudaAtomicAdd(ptr, negated);
}

// Eigen::half: subtract inside the CAS loop of detail::CudaAtomicCasHelper.
__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
                                            Eigen::half value) {
  auto subtract_value = [value](Eigen::half current) {
    return current - value;
  };
  return detail::CudaAtomicCasHelper(ptr, subtract_value);
}
559 
// CudaAtomicMax
// Generic version: forwards to CUDA's atomicMax. Enabled only when 'value' is
// implicitly convertible to the pointee type T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
  return atomicMax(ptr, value);
}

// float / double / half: apply max() inside the CAS loop of
// detail::CudaAtomicCasHelper.
__device__ inline float CudaAtomicMax(float* ptr, float value) {
  auto take_max = [value](float current) { return max(current, value); };
  return detail::CudaAtomicCasHelper(ptr, take_max);
}

__device__ inline double CudaAtomicMax(double* ptr, double value) {
  auto take_max = [value](double current) { return max(current, value); };
  return detail::CudaAtomicCasHelper(ptr, take_max);
}

__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr,
                                            Eigen::half value) {
  auto take_max = [value](Eigen::half current) {
    return max(current, value);
  };
  return detail::CudaAtomicCasHelper(ptr, take_max);
}

#if __CUDA_ARCH__ < 320
// Below sm_32 the 64-bit unsigned max is emulated with a CAS loop (presumably
// because the native 64-bit atomicMax is unavailable there).
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  auto take_max = [value](tensorflow::uint64 current) {
    return max(current, value);
  };
  return detail::CudaAtomicCasHelper(ptr, take_max);
}
#endif
589 
// CudaAtomicMin
// Generic version: forwards to CUDA's atomicMin. Enabled only when 'value' is
// implicitly convertible to the pointee type T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMin(T* ptr, U value) {
  return atomicMin(ptr, value);
}

// float / double / half: apply min() inside the CAS loop of
// detail::CudaAtomicCasHelper.
__device__ inline float CudaAtomicMin(float* ptr, float value) {
  auto take_min = [value](float current) { return min(current, value); };
  return detail::CudaAtomicCasHelper(ptr, take_min);
}

__device__ inline double CudaAtomicMin(double* ptr, double value) {
  auto take_min = [value](double current) { return min(current, value); };
  return detail::CudaAtomicCasHelper(ptr, take_min);
}

__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr,
                                            Eigen::half value) {
  auto take_min = [value](Eigen::half current) {
    return min(current, value);
  };
  return detail::CudaAtomicCasHelper(ptr, take_min);
}

#if __CUDA_ARCH__ < 320
// Below sm_32 the 64-bit unsigned min is emulated with a CAS loop (presumably
// because the native 64-bit atomicMin is unavailable there).
__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  auto take_min = [value](tensorflow::uint64 current) {
    return min(current, value);
  };
  return detail::CudaAtomicCasHelper(ptr, take_min);
}
#endif
619 
// CudaAtomicMul: multiplies *ptr by 'value' inside the CAS loop of
// detail::CudaAtomicCasHelper (CUDA has no native atomic multiply). Enabled
// only when 'value' is implicitly convertible to T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
  auto multiply_by = [value](T current) { return current * value; };
  return detail::CudaAtomicCasHelper(ptr, multiply_by);
}

// CudaAtomicDiv: divides *ptr by 'value' inside the CAS loop of
// detail::CudaAtomicCasHelper. Enabled only when 'value' is implicitly
// convertible to T.
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
  auto divide_by = [value](T current) { return current / value; };
  return detail::CudaAtomicCasHelper(ptr, divide_by);
}
631 
632 }  // namespace tensorflow
633 
634 #endif  // GOOGLE_CUDA
#endif  // TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
636