/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_

/**
 * Wrappers and helpers for CUDA device code.
 *
 * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
 * backwards compatibility, see go/volta-porting for details.
 * Provides atomic operations on types that aren't natively supported.
 */

#if GOOGLE_CUDA

#include <algorithm>
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace detail {

// Helper for range-based for loop using 'delta' increments.
// Usage: see CudaGridRange?() functions below.
template <typename T>
class CudaGridRange {
  struct Iterator {
    __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
    __device__ T operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += delta_;
      return *this;
    }
    __device__ bool operator!=(const Iterator& other) const {
      bool greater = index_ > other.index_;
      bool less = index_ < other.index_;
      // Anything past an end iterator (delta_ == 0) is equal.
      // In range-based for loops, this optimizes to 'return less'.
      if (!other.delta_) {
        return less;
      }
      if (!delta_) {
        return greater;
      }
      return less || greater;
    }

   private:
    T index_;
    const T delta_;
  };

 public:
  __device__ CudaGridRange(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
  __device__ Iterator end() const { return Iterator{end_, 0}; }

 private:
  T begin_;
  T delta_;
  T end_;
};

}  // namespace detail

// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
  return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
                                  gridDim.x * blockDim.x, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
  return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
                                  gridDim.y * blockDim.y, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
  return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
                                  gridDim.z * blockDim.z, count);
}
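
// Example (illustrative sketch, not part of the original header): a
// grid-stride kernel that scales 'count' floats in place using
// CudaGridRangeX. The kernel name and launch configuration are hypothetical.
//
//   __global__ void ScaleKernel(int count, float* data, float alpha) {
//     for (int i : CudaGridRangeX(count)) {
//       data[i] *= alpha;
//     }
//   }
//
//   // Host side, with an arbitrary 1-D launch configuration:
//   // ScaleKernel<<<(count + 255) / 256, 256>>>(count, data, 2.0f);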

// Mask for all 32 threads in a warp.
const unsigned kCudaWarpAll = 0xffffffff;

// Returns the warp lane ID of the calling thread
__device__ inline unsigned CudaLaneId() {
  unsigned int lane_id;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}

namespace detail {
// Returns true if mask is a valid parameter for __shfl*sync to return a well
// defined value, assuming the calling lane will read from src_lane as part of
// the shuffle operation.
//
// Specifically, returns true iff mask has the calling lane bit and the src_lane
// bit set, and the src_lane calls this function with the same mask value
// (required for the two threads to wait for each other).
//
// On Volta, for some invalid masks, this function hangs or returns false
// positives, because the implementation shuffles with the same mask that
// we are validating. Run on Pascal if you suspect that the mask is incorrect.
__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
                                                   unsigned src_lane) {
  unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane;
#if CUDA_VERSION >= 9000
  unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
#else
  unsigned src_lane_mask = __shfl(mask, src_lane);
#endif
  return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
}

// Returns the actual source lane for shuffle.
__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
  int lane_id = CudaLaneId();
  // '~width + 1' is '-width', so this rounds lane_id down to a multiple of
  // width (the start of its sub-warp); width is a power of two.
  int lane_base = lane_id & ~width + 1;
  int lane_offset = src_lane & width - 1;  // src_lane mod width.
  return lane_base + lane_offset;
}

// Returns the source lane for shuffle up.
__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) < delta) {
    return lane_id;
  }
  return lane_id - delta;
}

// Returns the source lane for shuffle down.
__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
                                                     int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) + delta >= width) {
    return lane_id;
  }
  return lane_id + delta;
}

// Returns the source lane for shuffle xor.
__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
  int lane_id = CudaLaneId();
  int src_lane = lane_id ^ lane_mask;
  if (src_lane > (lane_id | width - 1)) {
    return lane_id;
  }
  return src_lane;
}
}  // namespace detail

// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// In other words, all threads in 'mask' must call the functions in convergence.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
//
// It is also illegal to shuffle with a mask that produces an undefined result
// for any of the threads. Specifically, all source threads of the shuffle
// must have their corresponding bit in 'mask' set.

// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  __syncwarp(mask);
#endif
}

// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __ballot_sync(mask, pred);
#else
  return __ballot(pred) & mask;  // Apply mask to match __ballot_sync's spec.
#endif
}
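
// Example (hedged sketch): counting how many lanes of the warp satisfy a
// predicate by combining CudaBallotSync with the __popc population-count
// intrinsic. 'value' and 'threshold' are illustrative names.
//
//   unsigned ballot = CudaBallotSync(kCudaWarpAll, value > threshold);
//   int num_true = __popc(ballot);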

// Wrapper for __any_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __any_sync(mask, pred);
#else
  return __any(pred);
#endif
}

// Wrapper for __all_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __all_sync(mask, pred);
#else
  return __all(pred);
#endif
}

// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
                             int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}
245
246 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
247 // instead of float for lo and hi (which is incorrect with ftz, for example).
248 // See b/69446944.
249 __device__ inline double CudaShuffleSync(unsigned mask, double value,
250 int src_lane, int width = warpSize) {
251 unsigned lo, hi;
252 asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
253 hi = CudaShuffleSync(mask, hi, src_lane, width);
254 lo = CudaShuffleSync(mask, lo, src_lane, width);
255 asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
256 return value;
257 }
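
// Example (hedged sketch): broadcasting the value held by lane 0 to every
// lane of the warp, assuming all 32 lanes reach this call in convergence.
// 'value' is an illustrative per-thread variable.
//
//   float broadcast = CudaShuffleSync(kCudaWarpAll, value, /*src_lane=*/0);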

// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
                                      int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_up_sync(mask, value, delta, width);
#else
  return __shfl_up(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
                                           unsigned delta,
                                           int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleUpSync(mask, hi, delta, width);
  lo = CudaShuffleUpSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
                                        int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_down_sync(mask, value, delta, width);
#else
  return __shfl_down(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
                                             unsigned delta,
                                             int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleDownSync(mask, hi, delta, width);
  lo = CudaShuffleDownSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
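
// Example (hedged sketch): a full-warp sum reduction built on
// CudaShuffleDownSync. After the loop, lane 0 holds the sum of 'value' over
// all 32 lanes; every lane must participate.
//
//   float sum = value;
//   for (int offset = warpSize / 2; offset > 0; offset /= 2) {
//     sum += CudaShuffleDownSync(kCudaWarpAll, sum, offset);
//   }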

// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
                                int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, lane_mask, width);
#else
  return __shfl_xor(value, lane_mask, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
                                            int lane_mask,
                                            int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
  lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
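
// Example (hedged sketch): a butterfly (xor) reduction, which leaves the
// warp-wide sum in every lane rather than only in lane 0.
//
//   float sum = value;
//   for (int lanes = warpSize / 2; lanes > 0; lanes /= 2) {
//     sum += CudaShuffleXorSync(kCudaWarpAll, sum, lanes);
//   }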

// Wrapper for __ldg.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}

__host__ __device__ inline bool CudaLdg(const bool* address) {
  return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
}

__host__ __device__ inline std::complex<float> CudaLdg(
    const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
  float2 mem = __ldg(reinterpret_cast<const float2*>(address));
  return std::complex<float>(mem.x, mem.y);
#else
  return *address;
#endif
}

__host__ __device__ inline std::complex<double> CudaLdg(
    const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
  double2 mem = __ldg(reinterpret_cast<const double2*>(address));
  return std::complex<double>(mem.x, mem.y);
#else
  return *address;
#endif
}
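
// Example (hedged sketch): reading inputs through the read-only data cache in
// a kernel whose input never aliases its output. The kernel itself is
// illustrative and not part of this header.
//
//   __global__ void CopyKernel(int count, const float* in, float* out) {
//     for (int i : CudaGridRangeX(count)) {
//       out[i] = CudaLdg(in + i);
//     }
//   }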

// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = T(0);
  }
}

// Helper to set all tensor entries to a specific value.
template <typename T>
__global__ void SetToValue(const int count, T* ptr, T value) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = value;
  }
}
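
// Example (hedged sketch): launching the kernels above over a 1-D grid, which
// is what their asserts require. The stream and block/grid sizes are
// arbitrary, and 'ptr' is assumed to point to 'count' floats.
//
//   int threads = 256;
//   int blocks = (count + threads - 1) / threads;
//   SetZero<<<blocks, threads, 0, stream>>>(count, ptr);
//   SetToValue<<<blocks, threads, 0, stream>>>(count, ptr, 42.0f);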

namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
  T old = *ptr;
  T assumed;
  do {
    assumed = old;
    old = atomicCAS(ptr, assumed, accumulate(assumed));
  } while (assumed != old);
  return old;
}

// Overload for floating point (using integer comparison to handle NaN
// correctly).
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
  // The CAS runs on the 32-bit integer representation; the returned old bits
  // must be reinterpreted (not value-converted) back to float.
  return __int_as_float(
      CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
        return __float_as_int(accumulate(__int_as_float(a)));
      }));
}
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
  return __longlong_as_double(CudaAtomicCasHelper(
      reinterpret_cast<tensorflow::uint64*>(ptr),
      [accumulate](tensorflow::uint64 a) {
        return __double_as_longlong(accumulate(__longlong_as_double(a)));
      }));
}

// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// This version is going to be very slow under high concurrency, since most
// threads will be spinning on failing their compare-and-swap tests. (The fact
// that we get false sharing on the neighboring fp16 makes this even worse.)
// If you are doing a large reduction, you are much better off with doing the
// intermediate steps in fp32 and then switching to fp16 as late as you can in
// the calculations.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
  namespace half_impl = Eigen::half_impl;
  intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
  assert(!(intptr & 0x1));  // should be 2-aligned.
  if (intptr & 0x2) {
    // The half is in the second part of the uint32 (upper 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr - 2);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short high = static_cast<unsigned short>(arg >> 16);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
      return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
  } else {
    // The half is in the first part of the uint32 (lower 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short low = static_cast<unsigned short>(arg & 0xffff);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
      return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
  }
}

template <typename From, typename To>
using ToTypeIfConvertible =
    typename std::enable_if<std::is_convertible<From, To>::value, To>::type;

}  // namespace detail

// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
  return atomicAdd(ptr, value);
}

__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a + value; });
}

#if __CUDA_ARCH__ < 600
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(ptr,
                                     [value](double a) { return a + value; });
}
#elif __clang__
// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// see https://reviews.llvm.org/D39638
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  double result;
  asm volatile("atom.add.f64 %0, [%1], %2;"
               : "=d"(result)
               : "l"(ptr), "d"(value)
               : "memory");
  return result;
}
#endif

// CudaAtomicAdd
// Specializations of CudaAtomicAdd for complex types, which CudaAtomicAdd does
// not support. We treat a std::complex<T>* as a T* (the C++ standard section
// 26.4.4 allows this explicitly) and atomic add the real and imaginary
// components individually. The operation as a whole is not atomic, but we can
// safely treat the components independently for the purpose of accumulating.
__device__ inline std::complex<float> CudaAtomicAdd(std::complex<float>* ptr,
                                                    std::complex<float> value) {
  auto ptr_scalar = reinterpret_cast<float*>(ptr);
  return std::complex<float>(CudaAtomicAdd(ptr_scalar, value.real()),
                             CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}

__device__ inline std::complex<double> CudaAtomicAdd(
    std::complex<double>* ptr, std::complex<double> value) {
  auto ptr_scalar = reinterpret_cast<double*>(ptr);
  return std::complex<double>(CudaAtomicAdd(ptr_scalar, value.real()),
                              CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}

// CudaAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
  return atomicSub(ptr, value);
}

// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline double CudaAtomicSub(double* ptr, double value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a - value; });
}

// CudaAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
  return atomicMax(ptr, value);
}

__device__ inline float CudaAtomicMax(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return max(a, value); });
}

__device__ inline double CudaAtomicMax(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return max(a, value); });
}

__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return max(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return max(a, value); });
}
#endif

// CudaAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMin(T* ptr, U value) {
  return atomicMin(ptr, value);
}

__device__ inline float CudaAtomicMin(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return min(a, value); });
}

__device__ inline double CudaAtomicMin(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return min(a, value); });
}

__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return min(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return min(a, value); });
}
#endif

// CudaAtomicMul
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
}

// CudaAtomicDiv
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
}
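
// Example (hedged sketch): accumulating per-thread contributions into global
// memory with the wrappers above. The kernel and its output pointers are
// hypothetical.
//
//   __global__ void AccumulateKernel(int count, const float* in,
//                                    float* sum, float* max_val) {
//     for (int i : CudaGridRangeX(count)) {
//       CudaAtomicAdd(sum, in[i]);
//       CudaAtomicMax(max_val, in[i]);
//     }
//   }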

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
#endif  // TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_