/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UTIL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UTIL_CUH_
#include <cuda_fp16.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

#define kThreadsPerBlock (256)
#define kBlocksPerGrid(n) (((n) + kThreadsPerBlock - 1) / kThreadsPerBlock)

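// Usage sketch (illustrative only; the kernel name and arguments are hypothetical):
//   FooKernel<<<kBlocksPerGrid(n), kThreadsPerBlock, 0, stream>>>(dev_ptr, n);
// kBlocksPerGrid rounds the block count up, so all n elements are covered even
// when n is not a multiple of kThreadsPerBlock.
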
namespace atomic {
constexpr size_t OneByte = 1;
constexpr size_t TwoByte = 2;
constexpr size_t FourByte = 4;
constexpr size_t EightByte = 8;

template <typename Func, typename T, size_t Bytes = sizeof(T)>
struct MsAtomicBinaryOpImpl;

template <typename Func, typename T>
struct MsAtomicBinaryOpImpl<Func, T, OneByte> {
  __device__ __forceinline__ T operator()(T *address, T val) {
    // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
    // implement the one-byte atomic binary op. An unsigned char may not be 4 byte
    // aligned, but an unsigned int* must be 4 byte aligned. This variable contains
    // the offset, in bytes, of the beginning of address within the 4 byte aligned
    // space that contains it.
    size_t address_offset = reinterpret_cast<size_t>(address) & 3;

    // Address of the 4 byte aligned space that contains address.
    unsigned int *aligned =
      reinterpret_cast<unsigned int *>(reinterpret_cast<unsigned char *>(address) - address_offset);

    // Constants which will be used later with __byte_perm. __byte_perm is a cuda
    // function which takes 3 unsigned int's (x, y, selector) as parameters and
    // returns an int. __byte_perm returns an integer by selecting bytes from x
    // and y based on the given selector. The selector 0x3210 would select all
    // four bytes from x, preserving their original order. The position of the
    // "4" in the selector indicates the position in the output where the first
    // byte of y will end up.
    unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};

    // Gets the selector that will select the byte at address from aligned.
    unsigned int selector = selectors[address_offset];

    unsigned int old = *aligned;
    unsigned int assumed = 0;

    do {
      assumed = old;
      // Selects the byte associated with address and puts it as the first byte of
      // this variable, so that the binary op can be applied to the value at address.
      uint8_t old_byte = __byte_perm(old, 0, address_offset);
      T old_value = *(reinterpret_cast<T *>(&old_byte));

      T new_value = Func()(old_value, val);

      unsigned int new_byte = *(reinterpret_cast<uint8_t *>(&new_value));

      // Takes old and replaces the byte corresponding to address with the result.
      unsigned int replacement = __byte_perm(old, new_byte, selector);

      // Try to replace the old value with the new value.
      old = atomicCAS(aligned, assumed, replacement);
    } while (old != assumed);
    // Select the single byte corresponding to address and return it.
    return __byte_perm(old, 0, address_offset);
  }
};
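
// Worked example of the selector logic above: for address_offset == 1 the
// selector is 0x3240. __byte_perm(old, new_byte, 0x3240) keeps bytes 0, 2 and 3
// of `old` and writes byte 4 (the low byte of `new_byte`, i.e. the freshly
// computed value) into position 1, which is exactly the byte that `address`
// refers to. Likewise, __byte_perm(old, 0, 1) moves byte 1 of `old` into the
// low byte of the result, which is all that survives the conversion back to
// the one-byte T.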

template <typename Func, typename T>
struct MsAtomicBinaryOpImpl<Func, T, TwoByte> {
  __device__ __forceinline__ T operator()(T *address, T val) {
    bool is_4_byte_aligned = (reinterpret_cast<size_t>(address) & 2) == 0;
    unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) & ~2);
    unsigned int old = *aligned;
    unsigned int assumed;

    do {
      assumed = old;
      uint16_t old_byte = is_4_byte_aligned ? (old & 0xffff) : (old >> 16);
      T old_value = *(reinterpret_cast<T *>(&old_byte));
      // Do the binary operation.
      T new_value = Func()(old_value, val);

      unsigned int new_byte = *(reinterpret_cast<uint16_t *>(&new_value));
      if (is_4_byte_aligned) {
        new_byte = (old & 0xffff0000) | new_byte;
      } else {
        new_byte = (old & 0xffff) | (new_byte << 16);
      }
      // Try to replace the old value with the new value.
      // If the CAS fails, the current value of *aligned is used to update old.
      old = atomicCAS(aligned, assumed, new_byte);
    } while (assumed != old);

    if (is_4_byte_aligned) {
      return T(old & 0xffff);  // NOLINT
    } else {
      return T(old >> 16);  // NOLINT
    }
  }
};

template <typename Func, typename T>
struct MsAtomicBinaryOpImpl<Func, T, FourByte> {
  __device__ __forceinline__ T operator()(T *address, T val) {
    unsigned int *address_as_uint32 = reinterpret_cast<unsigned int *>(address);
    unsigned int old = *address_as_uint32;
    unsigned int assumed;
    do {
      assumed = old;
      T old_value = *(reinterpret_cast<T *>(&old));
      // Do the binary operation.
      T new_value = Func()(old_value, val);
      unsigned int new_byte = *(reinterpret_cast<unsigned int *>(&new_value));
      // Try to replace the old value with the new value.
      // If the CAS fails, the current value of *address is used to update old.
      old = atomicCAS(address_as_uint32, assumed, new_byte);
    } while (assumed != old);
    return T(old);
  }
};

template <typename Func, typename T>
struct MsAtomicBinaryOpImpl<Func, T, EightByte> {
  __device__ __forceinline__ T operator()(T *address, T val) {
    unsigned long long int *address_as_uint64 = reinterpret_cast<unsigned long long int *>(address);  // NOLINT
    unsigned long long int old = *address_as_uint64;                                                  // NOLINT
    unsigned long long int assumed;                                                                   // NOLINT
    do {
      assumed = old;
      T old_value = *(reinterpret_cast<T *>(&old));
      // Do the binary operation.
      T new_value = Func()(old_value, val);
      unsigned long long int new_byte = *(reinterpret_cast<unsigned long long int *>(&new_value));  // NOLINT
      // Try to replace the old value with the new value.
      // If the CAS fails, the current value of *address is used to update old.
      old = atomicCAS(address_as_uint64, assumed, new_byte);
    } while (assumed != old);
    return T(old);
  }
};

struct Add {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs + rhs;
  }
};

struct Sub {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs - rhs;
  }
};

struct Mul {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs * rhs;
  }
};

struct Div {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs / rhs;
  }
};

struct Min {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs < rhs ? lhs : rhs;
  }
};

struct Max {
  template <typename T>
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return lhs > rhs ? lhs : rhs;
  }
};
}  // namespace atomic
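
// Dispatch note: MsAtomicBinaryOpImpl is selected purely by sizeof(T), so for
// example <atomic::Add, int8_t> takes the one-byte CAS path, <atomic::Mul, half>
// the two-byte path, and <atomic::Max, double> the eight-byte path. The wrappers
// below instantiate this template and, where the hardware provides a native
// atomic, specialize to the corresponding CUDA intrinsic instead.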

// atomic add
template <typename T>
__device__ __forceinline__ T MsAtomicAdd(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Add, T>()(address, val);
}

// For the following types, call the CUDA API directly.
template <>
__device__ __forceinline__ int MsAtomicAdd(int *address, int val) {
  return atomicAdd(address, val);
}

template <>
__device__ __forceinline__ unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
  return atomicAdd(address, val);
}

template <>
__device__ __forceinline__ unsigned long long int MsAtomicAdd(unsigned long long int *address,  // NOLINT
                                                              unsigned long long int val) {     // NOLINT
  return atomicAdd(address, val);
}

template <>
__device__ __forceinline__ float MsAtomicAdd(float *address, float val) {
  return atomicAdd(address, val);
}

template <>
__device__ __forceinline__ bool MsAtomicAdd(bool *address, bool val) {
  // Boolean addition saturates to a logical OR of the stored value and val.
  *address = *address || val;
  return address[0];
}

template <>
__device__ __forceinline__ Complex<float> MsAtomicAdd(Complex<float> *address, Complex<float> val) {
  float *realAddr = reinterpret_cast<float *>(address);
  return Complex<float>(MsAtomicAdd(realAddr, val.real()), MsAtomicAdd(realAddr + 1, val.imag()));
}

template <>
__device__ __forceinline__ Complex<double> MsAtomicAdd(Complex<double> *address, Complex<double> val) {
  double *realAddr = reinterpret_cast<double *>(address);
  return Complex<double>(MsAtomicAdd(realAddr, val.real()), MsAtomicAdd(realAddr + 1, val.imag()));
}
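
// Usage sketch, assuming a hypothetical histogram kernel (the names HistKernel,
// data, bins and n are illustrative, not part of this header):
//   __global__ void HistKernel(const int *data, int *bins, size_t n) {
//     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       MsAtomicAdd(&bins[data[i]], 1);  // resolves to the native int atomicAdd specialization
//     }
//   }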

// atomic sub
template <typename T>
__device__ __forceinline__ T MsAtomicSub(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Sub, T>()(address, val);
}

// For the following types, call the CUDA API directly.
template <>
__device__ __forceinline__ unsigned int MsAtomicSub(unsigned int *address, unsigned int val) {
  return atomicSub(address, val);
}

// atomic min
template <typename T>
__device__ __forceinline__ T MsAtomicMin(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Min, T>()(address, val);
}

// For the following types, call the CUDA API directly.
template <>
__device__ __forceinline__ int MsAtomicMin(int *address, int val) {
  return atomicMin(address, val);
}

template <>
__device__ __forceinline__ unsigned int MsAtomicMin(unsigned int *address, unsigned int val) {
  return atomicMin(address, val);
}

template <>
__device__ __forceinline__ unsigned long long int MsAtomicMin(unsigned long long int *address,  // NOLINT
                                                              unsigned long long int val) {     // NOLINT
  return atomicMin(address, val);
}

template <>
__device__ __forceinline__ long long int MsAtomicMin(long long int *address, long long int val) {  // NOLINT
  return atomicMin(address, val);
}

// atomic max
template <typename T>
__device__ __forceinline__ T MsAtomicMax(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Max, T>()(address, val);
}

// For the following types, call the CUDA API directly.
template <>
__device__ __forceinline__ int MsAtomicMax(int *address, int val) {
  return atomicMax(address, val);
}

template <>
__device__ __forceinline__ unsigned int MsAtomicMax(unsigned int *address, unsigned int val) {
  return atomicMax(address, val);
}

template <>
__device__ __forceinline__ unsigned long long int MsAtomicMax(unsigned long long int *address,  // NOLINT
                                                              unsigned long long int val) {     // NOLINT
  return atomicMax(address, val);
}

template <>
__device__ __forceinline__ long long int MsAtomicMax(long long int *address, long long int val) {  // NOLINT
  return atomicMax(address, val);
}

// atomic mul
template <typename T>
__device__ __forceinline__ T MsAtomicMul(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Mul, T>()(address, val);
}

template <>
__device__ __forceinline__ bool MsAtomicMul(bool *address, bool val) {
  // Boolean multiplication is a logical AND of the stored value and val.
  *address = *address && val;
  return address[0];
}

// atomic div
template <typename T>
__device__ __forceinline__ T MsAtomicDiv(T *address, T val) {
  return atomic::MsAtomicBinaryOpImpl<atomic::Div, T>()(address, val);
}

__device__ __forceinline__ unsigned BallotSync(int predicate, unsigned mask = 0xffffffff) {
  return __ballot_sync(mask, predicate);
}
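
// Usage sketch: BallotSync gathers one predicate bit per lane of the warp
// (bit k is set iff lane k passed a non-zero predicate). With the default full
// mask, counting the positive values held by a warp could look like this
// (`v` is an illustrative per-thread variable):
//   unsigned vote = BallotSync(v > 0.0f);
//   int positives = __popc(vote);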

struct MsAtomicAddFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicAdd(address, val);
  }
};

struct MsAtomicSubFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicSub(address, val);
  }
};

struct MsAtomicMulFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicMul(address, val);
  }
};

struct MsAtomicDivFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicDiv(address, val);
  }
};

struct MsAtomicMinFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicMin(address, val);
  }
};

struct MsAtomicMaxFunctor {
  template <typename T>
  __device__ __forceinline__ T operator()(T *address, T val) {
    return MsAtomicMax(address, val);
  }
};
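
// The functors above allow one kernel template to be parameterized over the
// atomic update. A minimal sketch (the kernel name and parameters are
// hypothetical):
//   template <typename T, typename AtomicFunc>
//   __global__ void ScatterReduce(const T *src, const int *index, T *dst, size_t n, AtomicFunc op) {
//     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       op(dst + index[i], src[i]);  // e.g. op = MsAtomicAddFunctor{} or MsAtomicMaxFunctor{}
//     }
//   }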

enum : unsigned { warp_size = 32, log_wap_size = 5 };
__device__ __forceinline__ unsigned LaneId() { return threadIdx.x & (warp_size - 1); }
__device__ __forceinline__ unsigned WarpId(const unsigned &tid) { return tid >> log_wap_size; }

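// Example: with warp_size == 32, a thread whose threadIdx.x is 70 has
// LaneId() == (70 & 31) == 6 and WarpId(70) == (70 >> 5) == 2, i.e. it is
// lane 6 of warp 2 within its block.
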
template <typename T>
struct Epsilon {
  static constexpr float value = std::numeric_limits<T>::epsilon();
};

template <>
struct Epsilon<half> {
  static constexpr float value = 0.000977;
};

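// 0.000977 approximates the machine epsilon of IEEE half precision
// (2^-10 = 0.0009765625); a hard-coded constant is used here, presumably
// because std::numeric_limits is not specialized for half.
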
// Some bit-related functions.
inline int Log2Floor(uint32_t n) {
  if (n == 0) return -1;
  int log = 0;
  for (int i = 4; i >= 0; --i) {
    int shift = (1 << i);
    uint32_t x = n >> shift;
    if (x) {
      n = x;
      log += shift;
    }
  }
  return log;
}

inline int Log2Ceil(uint32_t n) {
  int floor = Log2Floor(n);
  if (n == (n & ~(n - 1)))
    return floor;
  else
    return floor + 1;
}

inline int Log2Floor64(uint64_t n) {
  // Scan the high 32 bits of n first, then the low 32 bits.
  const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
  if (high_32_bit == 0) {
    return Log2Floor(static_cast<uint32_t>(n));
  } else {
    return 32 + Log2Floor(high_32_bit);
  }
}

inline int Log2Ceil64(uint64_t n) {
  int floor = Log2Floor64(n);
  if (n == (n & ~(n - 1)))
    return floor;
  else
    return floor + 1;
}
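
// Worked examples: Log2Floor(10) == 3 and Log2Ceil(10) == 4, since
// 2^3 <= 10 < 2^4; for exact powers of two both functions agree, e.g.
// Log2Floor(16) == Log2Ceil(16) == 4. The test n == (n & ~(n - 1)) is the
// usual power-of-two check (it also accepts n == 0, for which Log2Floor has
// already returned -1).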

template <typename T>
__device__ __forceinline__ T ZeroImpl() {
  return 0;
}

template <>
__device__ __forceinline__ cuComplex ZeroImpl() {
  return make_cuComplex(0., 0.);
}

template <>
__device__ __forceinline__ cuDoubleComplex ZeroImpl() {
  return make_cuDoubleComplex(0., 0.);
}

template <typename T>
__device__ __forceinline__ T shfl_xor_sync(unsigned mask, T var, int lane_mask) {
  return __shfl_xor_sync(mask, var, lane_mask);
}

template <>
__device__ __forceinline__ Complex<float> shfl_xor_sync(unsigned mask, Complex<float> var, int lane_mask) {
  return Complex<float>(__shfl_xor_sync(mask, var.real(), lane_mask), __shfl_xor_sync(mask, var.imag(), lane_mask));
}

template <>
__device__ __forceinline__ Complex<double> shfl_xor_sync(unsigned mask, Complex<double> var, int lane_mask) {
  return Complex<double>(__shfl_xor_sync(mask, var.real(), lane_mask), __shfl_xor_sync(mask, var.imag(), lane_mask));
}

template <typename T>
__device__ __forceinline__ T shfl_down_sync(unsigned mask, T var, int lane_mask) {
  return __shfl_down_sync(mask, var, lane_mask);
}

template <>
__device__ __forceinline__ Complex<float> shfl_down_sync(unsigned mask, Complex<float> var, int lane_mask) {
  return Complex<float>(__shfl_down_sync(mask, var.real(), lane_mask), __shfl_down_sync(mask, var.imag(), lane_mask));
}

template <>
__device__ __forceinline__ Complex<double> shfl_down_sync(unsigned mask, Complex<double> var, int lane_mask) {
  return Complex<double>(__shfl_down_sync(mask, var.real(), lane_mask), __shfl_down_sync(mask, var.imag(), lane_mask));
}
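
// Usage sketch: the shfl_* wrappers exist so warp-level reductions can be
// written once for both real and Complex element types. A minimal intra-warp
// sum (illustrative; assumes a full 32-lane warp and that T defines operator+):
//   template <typename T>
//   __device__ T WarpReduceSum(T val) {
//     for (int offset = warp_size / 2; offset > 0; offset /= 2) {
//       val = val + shfl_down_sync(0xffffffffu, val, offset);
//     }
//     return val;  // lane 0 ends up holding the warp-wide sum
//   }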

template <typename T>
__device__ __forceinline__ void FastAtomicAdd(T *base, size_t offset, const size_t length, T value) {
  MsAtomicAdd(base + offset, value);
}

template <>
__device__ __forceinline__ void FastAtomicAdd(half *base, size_t offset, const size_t length, half value) {
#if ((defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  MsAtomicAdd(reinterpret_cast<half *>(base + offset), static_cast<half>(value));
#else
  // Accounts for the chance that base + offset falls on an odd 16-bit alignment (i.e., not 32-bit aligned).
  __half *target_addr = reinterpret_cast<__half *>(base + offset);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

  if (low_byte && offset < (length - 1)) {
    __half2 value2;
    value2.x = value;
    value2.y = __float2half_rz(0);
    atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
  } else if (!low_byte && offset > 0) {
    __half2 value2;
    value2.x = __float2half_rz(0);
    value2.y = value;
    atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
  } else {
    MsAtomicAdd(reinterpret_cast<__half *>(base) + offset, static_cast<__half>(value));
  }
#endif
}
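
// Design note: on the configurations taken by the #else branch (CUDA >= 10 and
// sm_70 or newer, per the guard above), packing the half value together with a
// zero into a __half2 lets the update go through a single hardware atomicAdd on
// an aligned 4-byte location instead of the 16-bit CAS loop behind MsAtomicAdd.
// The offset bounds checks keep the zero half of the pair from touching memory
// outside [base, base + length).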

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UTIL_CUH_