//
// Copyright 2016 Google Inc.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//

#ifndef HS_CUDA_MACROS_ONCE
#define HS_CUDA_MACROS_ONCE

//
//
//

#ifdef __cplusplus
extern "C" {
#endif

#include <stdint.h>

#ifdef __cplusplus
}
#endif

//
// Define the type based on key and val sizes
//

#if   HS_KEY_WORDS == 1
#if   HS_VAL_WORDS == 0
#define HS_KEY_TYPE  uint32_t
#endif
#elif HS_KEY_WORDS == 2
#define HS_KEY_TYPE  uint64_t
#endif

//
// FYI, restrict shouldn't have any impact on these kernels and
// benchmarks appear to prove that true
//

#define HS_RESTRICT  __restrict__

//
//
//

#define HS_SCOPE()              \
  static

#define HS_KERNEL_QUALIFIER()   \
  __global__ void

//
// The sm_35 arch has a maximum of 16 blocks per multiprocessor. Just
// clamp it to 16 when targeting this arch.
//
// This only arises when compiling the 32-bit sorting kernels.
//
// You can also generate a narrower 16-warp wide 32-bit sorting kernel
// which is sometimes faster and sometimes slower than the 32-block
// configuration.
//

#if ( __CUDA_ARCH__ == 350 )
#define HS_CUDA_MAX_BPM  16
#else
#define HS_CUDA_MAX_BPM  UINT32_MAX // 32
#endif

#define HS_CLAMPED_BPM(min_bpm)                                 \
  ((min_bpm) < HS_CUDA_MAX_BPM ? (min_bpm) : HS_CUDA_MAX_BPM)

//
//
//

#define HS_LAUNCH_BOUNDS(max_tpb,min_bpm)       \
  __launch_bounds__(max_tpb,HS_CLAMPED_BPM(min_bpm))

//
// KERNEL PROTOS
//

#define HS_BS_KERNEL_NAME(slab_count_ru_log2)   \
  hs_kernel_bs_##slab_count_ru_log2

#define HS_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)                       \
  HS_SCOPE()                                                                    \
  HS_KERNEL_QUALIFIER()                                                         \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,1)                                \
  HS_BS_KERNEL_NAME(slab_count_ru_log2)(HS_KEY_TYPE * const HS_RESTRICT vout,   \
                                        HS_KEY_TYPE const * const HS_RESTRICT vin)

//

#define HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)    \
  hs_kernel_bs_##slab_count_ru_log2

#define HS_OFFSET_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)                        \
  HS_SCOPE()                                                                            \
  HS_KERNEL_QUALIFIER()                                                                 \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,HS_BS_SLABS/(1<<slab_count_ru_log2))      \
  HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)(HS_KEY_TYPE * const HS_RESTRICT vout,    \
                                               HS_KEY_TYPE const * const HS_RESTRICT vin, \
                                               uint32_t const slab_offset)

//

#define HS_BC_KERNEL_NAME(slab_count_log2)      \
  hs_kernel_bc_##slab_count_log2

#define HS_BC_KERNEL_PROTO(slab_count,slab_count_log2)                          \
  HS_SCOPE()                                                                    \
  HS_KERNEL_QUALIFIER()                                                         \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,HS_BS_SLABS/(1<<slab_count_log2)) \
  HS_BC_KERNEL_NAME(slab_count_log2)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

#define HS_HM_KERNEL_NAME(s)    \
  hs_kernel_hm_##s

#define HS_HM_KERNEL_PROTO(s)   \
  HS_SCOPE()                    \
  HS_KERNEL_QUALIFIER()         \
  HS_HM_KERNEL_NAME(s)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

#define HS_FM_KERNEL_NAME(s,r)  \
  hs_kernel_fm_##s##_##r

#define HS_FM_KERNEL_PROTO(s,r) \
  HS_SCOPE()                    \
  HS_KERNEL_QUALIFIER()         \
  HS_FM_KERNEL_NAME(s,r)(HS_KEY_TYPE * const HS_RESTRICT vout)
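
// For illustration only (the argument values below are made up; real
// values are supplied wherever these protos are instantiated):
// HS_FM_KERNEL_PROTO(0,1) expands to roughly
//
//   static __global__ void
//   hs_kernel_fm_0_1(HS_KEY_TYPE * const __restrict__ vout)
//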

//

#define HS_OFFSET_FM_KERNEL_NAME(s,r)   \
  hs_kernel_fm_##s##_##r

#define HS_OFFSET_FM_KERNEL_PROTO(s,r)  \
  HS_SCOPE()                            \
  HS_KERNEL_QUALIFIER()                 \
  HS_OFFSET_FM_KERNEL_NAME(s,r)(HS_KEY_TYPE * const HS_RESTRICT vout,  \
                                uint32_t const span_offset)

//

#define HS_TRANSPOSE_KERNEL_NAME()      \
  hs_kernel_transpose

#define HS_TRANSPOSE_KERNEL_PROTO()     \
  HS_SCOPE()                            \
  HS_KERNEL_QUALIFIER()                 \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS,1)   \
  HS_TRANSPOSE_KERNEL_NAME()(HS_KEY_TYPE * const HS_RESTRICT vout)

//
// BLOCK LOCAL MEMORY DECLARATION
//

#define HS_BLOCK_LOCAL_MEM_DECL(width,height)   \
  __shared__ struct {                           \
    HS_KEY_TYPE m[width * height];              \
  } shared

//
// BLOCK BARRIER
//

#define HS_BLOCK_BARRIER()      \
  __syncthreads()

//
// GRID VARIABLES
//

#define HS_GLOBAL_SIZE_X()  (gridDim.x * blockDim.x)
#define HS_GLOBAL_ID_X()    (blockDim.x * blockIdx.x + threadIdx.x)
#define HS_LOCAL_ID_X()     threadIdx.x
#define HS_WARP_ID_X()      (threadIdx.x / 32)
#define HS_LANE_ID()        (threadIdx.x & 31)

//
// SLAB GLOBAL
//

#define HS_SLAB_GLOBAL_PREAMBLE()                       \
  uint32_t const gmem_idx =                             \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS-1)) *         \
    HS_SLAB_HEIGHT + HS_LANE_ID()

#define HS_OFFSET_SLAB_GLOBAL_PREAMBLE()                        \
  uint32_t const gmem_idx =                                     \
    ((slab_offset + HS_GLOBAL_ID_X()) & ~(HS_SLAB_THREADS-1)) * \
    HS_SLAB_HEIGHT + HS_LANE_ID()

#define HS_SLAB_GLOBAL_LOAD(extent,row_idx)     \
  extent[gmem_idx + HS_SLAB_THREADS * row_idx]

#define HS_SLAB_GLOBAL_STORE(row_idx,reg)       \
  vout[gmem_idx + HS_SLAB_THREADS * row_idx] = reg

//
// SLAB LOCAL
//

#define HS_SLAB_LOCAL_L(offset)         \
  shared.m[smem_l_idx + (offset)]

#define HS_SLAB_LOCAL_R(offset)         \
  shared.m[smem_r_idx + (offset)]

//
// SLAB LOCAL VERTICAL LOADS
//

#define HS_BX_LOCAL_V(offset)           \
  shared.m[HS_LOCAL_ID_X() + (offset)]

//
// BLOCK SORT MERGE HORIZONTAL
//

#define HS_BS_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * slab_count) +           \
    HS_LANE_ID();                                               \
  uint32_t const smem_r_idx =                                   \
    (HS_WARP_ID_X() ^ 1) * (HS_SLAB_THREADS * slab_count) +     \
    (HS_LANE_ID() ^ (HS_SLAB_THREADS - 1))

//
// BLOCK CLEAN MERGE HORIZONTAL
//

#define HS_BC_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const gmem_l_idx =                                   \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS*slab_count-1)) *      \
    HS_SLAB_HEIGHT + HS_LOCAL_ID_X();                           \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * slab_count) +           \
    HS_LANE_ID()

#define HS_BC_GLOBAL_LOAD_L(slab_idx)                   \
  vout[gmem_l_idx + (HS_SLAB_THREADS * slab_idx)]

//
// SLAB FLIP AND HALF PREAMBLES
//

#define HS_SLAB_FLIP_PREAMBLE(mask)                             \
  uint32_t const flip_lane_idx = HS_LANE_ID() ^ mask;           \
  int32_t  const t_lt          = HS_LANE_ID() < flip_lane_idx;

// if we want to shfl_xor: uint32_t const flip_lane_mask = mask;

#define HS_SLAB_HALF_PREAMBLE(mask)                             \
  uint32_t const half_lane_idx = HS_LANE_ID() ^ mask;           \
  int32_t  const t_lt          = HS_LANE_ID() < half_lane_idx;

// if we want to shfl_xor: uint32_t const half_lane_mask = mask;

//
// Inter-lane compare exchange
//

// good
#define HS_CMP_XCHG_V0(a,b)             \
  {                                     \
    HS_KEY_TYPE const t = min(a,b);     \
    b = max(a,b);                       \
    a = t;                              \
  }

// surprisingly fast -- #1 on 64-bit keys
#define HS_CMP_XCHG_V1(a,b)             \
  {                                     \
    HS_KEY_TYPE const tmp = a;          \
    a = (a < b) ? a : b;                \
    b ^= a ^ tmp;                       \
  }
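
// The xor in V1 works because after the ternary 'a' holds the smaller
// key: if 'a' kept its original value then 'a ^ tmp' is zero and 'b'
// is left as-is; otherwise 'a' now equals the old 'b', so 'b ^= a ^ tmp'
// reduces to 'b = tmp', i.e. the old (larger) 'a'.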

// good
#define HS_CMP_XCHG_V2(a,b)             \
  if (a >= b) {                         \
    HS_KEY_TYPE const t = a;            \
    a = b;                              \
    b = t;                              \
  }

// good
#define HS_CMP_XCHG_V3(a,b)             \
  {                                     \
    int32_t     const ge = a >= b;      \
    HS_KEY_TYPE const t  = a;           \
    a = ge ? b : a;                     \
    b = ge ? t : b;                     \
  }

//
//
//

#if   (HS_KEY_WORDS == 1)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#endif

//
// The flip/half comparisons rely on a "conditional min/max":
//
//   - if the flag is true,  return min(a,b)
//   - otherwise,            return max(a,b)
//
// What's a little surprising is that V1 is faster than V0 for 32-bit
// keys.
//
// I suspect either a code generation problem or that the sequence
// maps well to the GEN instruction set.
//
// We mostly care about 64-bit keys and unsurprisingly V0 is fastest
// for this wider type.
//

// this is what you would normally use
#define HS_COND_MIN_MAX_V0(lt,a,b)  ((a <= b) ^ lt) ? b : a

// this seems to be faster for 32-bit keys
#define HS_COND_MIN_MAX_V1(lt,a,b)  (lt ? b : a) ^ ((a ^ b) & HS_LTE_TO_MASK(a,b))

//
//
//

#if   (HS_KEY_WORDS == 1)
#define HS_COND_MIN_MAX(lt,a,b)  HS_COND_MIN_MAX_V0(lt,a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_COND_MIN_MAX(lt,a,b)  HS_COND_MIN_MAX_V0(lt,a,b)
#endif

//
// HotSort shuffles are always warp-wide
//

#define HS_SHFL_ALL  0xFFFFFFFF

//
// Conditional inter-subgroup flip/half compare exchange
//

#define HS_CMP_FLIP(i,a,b)                                              \
  {                                                                     \
    HS_KEY_TYPE const ta = __shfl_sync(HS_SHFL_ALL,a,flip_lane_idx);    \
    HS_KEY_TYPE const tb = __shfl_sync(HS_SHFL_ALL,b,flip_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,tb);                                     \
    b = HS_COND_MIN_MAX(t_lt,b,ta);                                     \
  }

#define HS_CMP_HALF(i,a)                                                \
  {                                                                     \
    HS_KEY_TYPE const ta = __shfl_sync(HS_SHFL_ALL,a,half_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,ta);                                     \
  }
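
// Illustration (values picked arbitrarily): HS_COND_MIN_MAX(1,7,3)
// evaluates to 3 and HS_COND_MIN_MAX(0,7,3) to 7 with either variant,
// so in HS_CMP_FLIP the lane with t_lt set keeps the smaller key of
// each exchanged pair and its partner lane keeps the larger one.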

//
// The device's comparison operator might return what we actually
// want. For example, it appears GEN 'cmp' returns {true:-1,false:0}.
//

#define HS_CMP_IS_ZERO_ONE

#ifdef HS_CMP_IS_ZERO_ONE
// OpenCL requires a {true: +1, false: 0} scalar result
// (a < b) -> { +1, 0 } -> NEGATE -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b)  (HS_KEY_TYPE)(-(a <= b))
#define HS_CMP_TO_MASK(a)    (HS_KEY_TYPE)(-a)
#else
// However, OpenCL requires { -1, 0 } for vectors
// (a < b) -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b)  (a <= b) // FIXME for uint64
#define HS_CMP_TO_MASK(a)    (a)
#endif

//
// The "flip-merge" and "half-merge" preambles are very similar
//
// For now, we're only using the .y dimension for the span idx
//

#define HS_OFFSET_HM_PREAMBLE(half_span,span_offset)            \
  uint32_t const span_idx    = span_offset + blockIdx.y;        \
  uint32_t const span_stride = HS_GLOBAL_SIZE_X();              \
  uint32_t const span_size   = span_stride * half_span * 2;     \
  uint32_t const span_base   = span_idx * span_size;            \
  uint32_t const span_off    = HS_GLOBAL_ID_X();                \
  uint32_t const span_l      = span_base + span_off

#define HS_HM_PREAMBLE(half_span)       \
  HS_OFFSET_HM_PREAMBLE(half_span,0)

#define HS_FM_PREAMBLE(half_span)       \
  HS_HM_PREAMBLE(half_span);            \
  uint32_t const span_r = span_base + span_stride * (half_span + 1) - span_off - 1

#define HS_OFFSET_FM_PREAMBLE(half_span)                \
  HS_OFFSET_HM_PREAMBLE(half_span,span_offset);         \
  uint32_t const span_r = span_base + span_stride * (half_span + 1) - span_off - 1

//
//
//

#define HS_XM_GLOBAL_L(stride_idx)              \
  vout[span_l + span_stride * stride_idx]

#define HS_XM_GLOBAL_LOAD_L(stride_idx)         \
  HS_XM_GLOBAL_L(stride_idx)

#define HS_XM_GLOBAL_STORE_L(stride_idx,reg)    \
  HS_XM_GLOBAL_L(stride_idx) = reg

#define HS_FM_GLOBAL_R(stride_idx)              \
  vout[span_r + span_stride * stride_idx]

#define HS_FM_GLOBAL_LOAD_R(stride_idx)         \
  HS_FM_GLOBAL_R(stride_idx)

#define HS_FM_GLOBAL_STORE_R(stride_idx,reg)    \
  HS_FM_GLOBAL_R(stride_idx) = reg

//
// This snarl of macros is for transposing a "slab" of sorted elements
// into linear order.
//
// This can occur as the last step in hs_sort() or via a custom kernel
// that inspects the slab and then transposes and stores it to memory.
//
// The slab format can be inspected more efficiently than a linear
// arrangement.
//
// The prime example is detecting when adjacent keys (in sort order)
// have differing high order bits ("key changes"). The index of each
// change is recorded to an auxiliary array.
//
// A post-processing step like this needs to be able to navigate the
// slab and eventually transpose and store the slab in linear order.
//
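
// A rough sketch of how the macros defined below compose (the register
// prefixes and row numbers here are hypothetical): HS_TRANSPOSE_STAGE(1)
// declares
//
//   bool const is_lo_1 = (HS_LANE_ID() & 1) == 0;
//
// HS_TRANSPOSE_BLEND(r,s,1,1,2) then shuffles one of the two registers
// r1/r2 across each lane pair and declares the blended s1/s2, and
// HS_TRANSPOSE_REMAP(t,1,1) stores register t1 back to vout[gmem_idx].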

#define HS_SUBGROUP_SHUFFLE_XOR(v,m)  __shfl_xor_sync(HS_SHFL_ALL,v,m)

#define HS_TRANSPOSE_REG(prefix,row)   prefix##row
#define HS_TRANSPOSE_DECL(prefix,row)  HS_KEY_TYPE const HS_TRANSPOSE_REG(prefix,row)
#define HS_TRANSPOSE_PRED(level)       is_lo_##level

#define HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)         \
  prefix_curr##row_ll##_##row_ur

#define HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur)        \
  HS_KEY_TYPE const HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)

#define HS_TRANSPOSE_STAGE(level)                       \
  bool const HS_TRANSPOSE_PRED(level) =                 \
    (HS_LANE_ID() & (1 << (level-1))) == 0;

#define HS_TRANSPOSE_BLEND(prefix_prev,prefix_curr,level,row_ll,row_ur) \
  HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur) =                    \
    HS_SUBGROUP_SHUFFLE_XOR(HS_TRANSPOSE_PRED(level) ?                  \
                            HS_TRANSPOSE_REG(prefix_prev,row_ll) :      \
                            HS_TRANSPOSE_REG(prefix_prev,row_ur),       \
                            1<<(level-1));                              \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ll) =                               \
    HS_TRANSPOSE_PRED(level) ?                                          \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur) :                   \
    HS_TRANSPOSE_REG(prefix_prev,row_ll);                               \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ur) =                               \
    HS_TRANSPOSE_PRED(level) ?                                          \
    HS_TRANSPOSE_REG(prefix_prev,row_ur) :                              \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur);

#define HS_TRANSPOSE_REMAP(prefix,row_from,row_to)              \
  vout[gmem_idx + ((row_to-1) << HS_SLAB_WIDTH_LOG2)] =         \
    HS_TRANSPOSE_REG(prefix,row_from);

//
//
//

#endif

//
//
//