//
// Copyright 2016 Google Inc.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
7 
#ifndef HS_CL_MACROS_ONCE
#define HS_CL_MACROS_ONCE

//
// Define the slab key type based on the configured key and val word counts.
//
// NOTE(review): only {1 key word, 0 val words} and {2 key words} are
// handled; any other configuration leaves HS_KEY_TYPE undefined and
// every macro below that uses it will fail to compile.
//

#if   HS_KEY_WORDS == 1
#if   HS_VAL_WORDS == 0
#define HS_KEY_TYPE  uint
#endif
#elif HS_KEY_WORDS == 2
#define HS_KEY_TYPE  ulong
#endif
22 
//
// FYI, restrict shouldn't have any impact on these kernels and
// benchmarks appear to prove that true
//

#define HS_RESTRICT restrict

//
// Pin the Intel subgroup size to the slab width so the
// intel_sub_group_shuffle*() operations used below span exactly one
// slab's worth of lanes.
//

#define HS_REQD_SUBGROUP_SIZE()                 \
  __attribute__((intel_reqd_sub_group_size(HS_SLAB_THREADS)))

//
// KERNEL PROTOS
//
40 
// Block sort: sorts slabs from vin into vout.  The kernel name is
// suffixed with the log2 of the rounded-up slab count per workgroup.
#define HS_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)               \
  __kernel                                                              \
  HS_REQD_SUBGROUP_SIZE()                                               \
  __attribute__((reqd_work_group_size(HS_SLAB_THREADS*slab_count,1,1))) \
  void                                                                  \
  hs_kernel_bs_##slab_count_ru_log2(__global HS_KEY_TYPE const * const HS_RESTRICT vin, \
                                    __global HS_KEY_TYPE       * const HS_RESTRICT vout)

// Block clean: merge step operating on vout in place.
#define HS_BC_KERNEL_PROTO(slab_count,slab_count_log2)                  \
  __kernel                                                              \
  HS_REQD_SUBGROUP_SIZE()                                               \
  __attribute__((reqd_work_group_size(HS_SLAB_THREADS*slab_count,1,1))) \
  void                                                                  \
  hs_kernel_bc_##slab_count_log2(__global HS_KEY_TYPE * const HS_RESTRICT vout)

// Flip merge: only the subgroup size is constrained, not the
// workgroup size.
#define HS_FM_KERNEL_PROTO(s,r)                                         \
  __kernel                                                              \
  HS_REQD_SUBGROUP_SIZE()                                               \
  void                                                                  \
  hs_kernel_fm_##s##_##r(__global HS_KEY_TYPE * const HS_RESTRICT vout)

// Half merge.
#define HS_HM_KERNEL_PROTO(s)                                           \
  __kernel                                                              \
  HS_REQD_SUBGROUP_SIZE()                                               \
  void                                                                  \
  hs_kernel_hm_##s(__global HS_KEY_TYPE * const HS_RESTRICT vout)

// Transpose the slab layout into linear order, in place within vout.
#define HS_TRANSPOSE_KERNEL_PROTO()                                     \
  __kernel                                                              \
  HS_REQD_SUBGROUP_SIZE()                                               \
  void                                                                  \
  hs_kernel_transpose(__global HS_KEY_TYPE * const HS_RESTRICT vout)
73 
//
// BLOCK LOCAL MEMORY DECLARATION
//

// Declares the workgroup-shared scratch array named `shared` that the
// HS_SLAB_LOCAL_* / HS_BX_LOCAL_V accessors below index into.
#define HS_BLOCK_LOCAL_MEM_DECL(width,height)   \
  __local struct {                              \
    HS_KEY_TYPE m[width * height];              \
  } shared

//
//
//

#define HS_SUBGROUP_ID()                        \
  get_sub_group_id()

//
// BLOCK BARRIER
//

#define HS_BLOCK_BARRIER()                      \
  barrier(CLK_LOCAL_MEM_FENCE)
96 
//
// SLAB GLOBAL
//
// gmem_idx addresses this lane's column within its slab:
// (slab base rounded to the slab width) * HS_SLAB_HEIGHT plus the
// lane's position within the slab.
//

#define HS_SLAB_GLOBAL_PREAMBLE()                                       \
  uint const gmem_idx =                                                 \
    (get_global_id(0) & ~(HS_SLAB_THREADS-1)) * HS_SLAB_HEIGHT +        \
    (get_local_id(0) & (HS_SLAB_THREADS-1))

// Load row `row_idx` of this lane's slab column from `extent`.
#define HS_SLAB_GLOBAL_LOAD(extent,row_idx)     \
  extent[gmem_idx + HS_SLAB_THREADS * row_idx]

// Store `reg` to row `row_idx` of this lane's slab column in vout.
#define HS_SLAB_GLOBAL_STORE(row_idx,reg)               \
  vout[gmem_idx + HS_SLAB_THREADS * row_idx] = reg

//
// SLAB LOCAL
//
// smem_l_idx / smem_r_idx are declared by the merge preambles below.
//

#define HS_SLAB_LOCAL_L(offset)                 \
  shared.m[smem_l_idx + (offset)]

#define HS_SLAB_LOCAL_R(offset)                 \
  shared.m[smem_r_idx + (offset)]

//
// SLAB LOCAL VERTICAL LOADS
//

#define HS_BX_LOCAL_V(offset)                   \
  shared.m[get_local_id(0) + (offset)]
128 
//
// BLOCK SORT MERGE HORIZONTAL
//
// Pairs of subgroups exchange rows through local memory: the right
// index targets the partner subgroup (id ^ 1) with lane order
// reversed (lane ^ (HS_SLAB_THREADS-1)) to implement the "flip".
//

#define HS_BS_MERGE_H_PREAMBLE(slab_count)                      \
  uint const smem_l_idx =                                       \
    get_sub_group_id() * (HS_SLAB_THREADS * slab_count) +       \
    get_sub_group_local_id();                                   \
  uint const smem_r_idx =                                       \
    (get_sub_group_id() ^ 1) * (HS_SLAB_THREADS * slab_count) + \
    (get_sub_group_local_id() ^ (HS_SLAB_THREADS - 1))

//
// BLOCK CLEAN MERGE HORIZONTAL
//
// gmem_l_idx is the global base of this workgroup's span of slabs;
// smem_l_idx is this lane's row in local memory.
//

#define HS_BC_MERGE_H_PREAMBLE(slab_count)                              \
  uint const gmem_l_idx =                                               \
    (get_global_id(0) & ~(HS_SLAB_THREADS*slab_count-1)) *              \
    HS_SLAB_HEIGHT + get_local_id(0);                                   \
  uint const smem_l_idx =                                               \
    get_sub_group_id() * (HS_SLAB_THREADS * slab_count) +               \
    get_sub_group_local_id()

#define HS_BC_GLOBAL_LOAD_L(slab_idx)                   \
  vout[gmem_l_idx + (HS_SLAB_THREADS * slab_idx)]
155 
//
// SLAB FLIP AND HALF PREAMBLES
//
// flip/half_lane_idx is the partner lane for the exchange; t_lt marks
// the lower lane of each pair (per HS_COND_MIN_MAX below, that lane
// keeps the smaller key).
//

#define HS_SLAB_FLIP_PREAMBLE(mask)                                     \
  uint const flip_lane_idx = get_sub_group_local_id() ^ mask;           \
  int  const t_lt          = get_sub_group_local_id() < flip_lane_idx;

#define HS_SLAB_HALF_PREAMBLE(mask)                                     \
  uint const half_lane_idx = get_sub_group_local_id() ^ mask;           \
  int  const t_lt          = get_sub_group_local_id() < half_lane_idx;
167 
//
// Inter-lane compare exchange: leaves min(a,b) in a and max(a,b) in b.
//
// NOTE: these macros evaluate their arguments more than once -- only
// pass plain lvalues, never expressions with side effects.
//

// default
#define HS_CMP_XCHG_V0(a,b)                     \
  {                                             \
    HS_KEY_TYPE const t = min(a,b);             \
    b = max(a,b);                               \
    a = t;                                      \
  }

// super slow
#define HS_CMP_XCHG_V1(a,b)                     \
  {                                             \
    HS_KEY_TYPE const tmp = a;                  \
    a  = (a < b) ? a : b;                       \
    b ^= a ^ tmp;                               \
  }

// best
//
// NOTE(review): a bare if-statement expansion is vulnerable to the
// dangling-else problem at a call site like `if (c) HS_CMP_XCHG(x,y);
// else ...` -- confirm callers never wrap it in an unbraced if/else.
#define HS_CMP_XCHG_V2(a,b)                     \
  if (a >= b) {                                 \
    HS_KEY_TYPE const t = a;                    \
    a = b;                                      \
    b = t;                                      \
  }

// good
#define HS_CMP_XCHG_V3(a,b)                     \
  {                                             \
    int         const ge = a >= b;              \
    HS_KEY_TYPE const t  = a;                   \
    a = ge ? b : a;                             \
    b = ge ? t : b;                             \
  }

//
// Select the compare-exchange implementation by key width:
// uint keys use min/max (V0), ulong keys use the branching swap (V2).
//

#if   (HS_KEY_WORDS == 1)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V2(a,b)
#endif
214 
//
// The flip/half comparisons rely on a "conditional min/max":
//
//  - if the flag (lt) is true, return min(a,b)
//  - otherwise, return max(a,b)
//
// (The original comment stated this the other way around; tracing V0
// and its use in HS_CMP_FLIP/HS_CMP_HALF shows the lane with t_lt set
// keeps the smaller key.)
//
// What's a little surprising is that sequence (1) is faster than (2)
// for 32-bit keys.
//
// I suspect either a code generation problem or that the sequence
// maps well to the GEN instruction set.
//
// We mostly care about 64-bit keys and unsurprisingly sequence (2) is
// fastest for this wider type.
//
// Arguments and the full expansions are parenthesized so the macros
// stay well-formed inside larger expressions and with compound args.
//

// this is what you would normally use
#define HS_COND_MIN_MAX_V0(lt,a,b) ((((a) <= (b)) ^ (lt)) ? (b) : (a))

// this may be faster for 32-bit keys on Intel GEN
#define HS_COND_MIN_MAX_V1(lt,a,b) (((lt) ? (b) : (a)) ^ (((a) ^ (b)) & HS_LTE_TO_MASK(a,b)))

//
// FIXME -- EVENTUALLY HANDLE KEY+VAL
//

// NOTE(review): both key widths currently select V0; the comment
// above suggests V1 may win for 32-bit keys on GEN -- confirm intent.
#if   (HS_KEY_WORDS == 1)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#endif
246 
//
// Conditional inter-subgroup flip/half compare exchange
//
// Each lane swaps its register(s) with the mirror lane computed by the
// flip/half preamble, then keeps min or max according to t_lt.  The
// parameter i is unused in this backend -- presumably kept for
// signature parity with other HotSort targets (TODO confirm).
//

#define HS_CMP_FLIP(i,a,b)                                              \
  {                                                                     \
    HS_KEY_TYPE const ta = intel_sub_group_shuffle(a,flip_lane_idx);    \
    HS_KEY_TYPE const tb = intel_sub_group_shuffle(b,flip_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,tb);                                     \
    b = HS_COND_MIN_MAX(t_lt,b,ta);                                     \
  }

#define HS_CMP_HALF(i,a)                                                \
  {                                                                     \
    HS_KEY_TYPE const ta = intel_sub_group_shuffle(a,half_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,ta);                                     \
  }
264 
//
// The device's comparison operator might return what we actually
// want.  For example, it appears GEN 'cmp' returns {true:-1,false:0}.
//

#define HS_CMP_IS_ZERO_ONE

#ifdef HS_CMP_IS_ZERO_ONE
// OpenCL requires a {true: +1, false: 0} scalar result
// (a <= b) -> { +1, 0 } -> NEGATE -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b) (HS_KEY_TYPE)(-(a <= b))
#define HS_CMP_TO_MASK(a)   (HS_KEY_TYPE)(-a)
#else
// However, OpenCL requires { -1, 0 } for vectors
// (a <= b) -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b) (a <= b) // FIXME for uint64
#define HS_CMP_TO_MASK(a)   (a)
#endif
283 
//
// The "flip-merge" and "half-merge" preambles are very similar
//
// For now, we're only using the .y dimension for the span idx
//

#define HS_HM_PREAMBLE(half_span)                       \
  uint const span_idx    = get_global_id(1);            \
  uint const span_stride = get_global_size(0);          \
  uint const span_size   = span_stride * half_span * 2; \
  uint const span_base   = span_idx * span_size;        \
  uint const span_off    = get_global_id(0);            \
  uint const span_l      = span_base + span_off

// The flip-merge preamble additionally computes span_r, the mirrored
// right-hand index used by the "flip" comparison pattern.
#define HS_FM_PREAMBLE(half_span)                                       \
  HS_HM_PREAMBLE(half_span);                                            \
  uint const span_r      = span_base + span_stride * (half_span + 1) - span_off - 1
301 
//
// Strided global accessors shared by the half-merge (HM) and
// flip-merge (FM) kernels: XM_* index from span_l (both kernels),
// FM_*_R index from the mirrored span_r (flip-merge only).
//

#define HS_XM_GLOBAL_L(stride_idx)              \
  vout[span_l + span_stride * stride_idx]

#define HS_XM_GLOBAL_LOAD_L(stride_idx)         \
  HS_XM_GLOBAL_L(stride_idx)

#define HS_XM_GLOBAL_STORE_L(stride_idx,reg)    \
  HS_XM_GLOBAL_L(stride_idx) = reg

#define HS_FM_GLOBAL_R(stride_idx)              \
  vout[span_r + span_stride * stride_idx]

#define HS_FM_GLOBAL_LOAD_R(stride_idx)         \
  HS_FM_GLOBAL_R(stride_idx)

#define HS_FM_GLOBAL_STORE_R(stride_idx,reg)    \
  HS_FM_GLOBAL_R(stride_idx) = reg
323 
//
// This snarl of macros is for transposing a "slab" of sorted elements
// into linear order.
//
// This can occur as the last step in hs_sort() or via a custom kernel
// that inspects the slab and then transposes and stores it to memory.
//
// The slab format can be inspected more efficiently than a linear
// arrangement.
//
// The prime example is detecting when adjacent keys (in sort order)
// have differing high order bits ("key changes").  The index of each
// change is recorded to an auxiliary array.
//
// A post-processing step like this needs to be able to navigate the
// slab and eventually transpose and store the slab in linear order.
//

#define HS_SUBGROUP_SHUFFLE_XOR(v,m)   intel_sub_group_shuffle_xor(v,m)

// Names the per-row register, e.g. HS_TRANSPOSE_REG(r,3) -> r3.
#define HS_TRANSPOSE_REG(prefix,row)   prefix##row
#define HS_TRANSPOSE_DECL(prefix,row)  HS_KEY_TYPE const HS_TRANSPOSE_REG(prefix,row)
#define HS_TRANSPOSE_PRED(level)       is_lo_##level

// Temporary register holding the value exchanged between rows
// row_ll ("lower-left") and row_ur ("upper-right") at one stage.
#define HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)       \
  prefix_curr##row_ll##_##row_ur

#define HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur)      \
  HS_KEY_TYPE const HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)

// True for lanes in the "low" half of a lane pair at this level.
#define HS_TRANSPOSE_STAGE(level)                       \
  bool const HS_TRANSPOSE_PRED(level) =                 \
    (get_sub_group_local_id() & (1 << (level-1))) == 0;

// One butterfly step: each lane sends one of its two rows to the
// partner lane (xor distance 1<<(level-1)) and blends the received
// value into two fresh registers for the next level.
#define HS_TRANSPOSE_BLEND(prefix_prev,prefix_curr,level,row_ll,row_ur) \
  HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur) =                    \
    HS_SUBGROUP_SHUFFLE_XOR(HS_TRANSPOSE_PRED(level) ?                  \
                            HS_TRANSPOSE_REG(prefix_prev,row_ll) :      \
                            HS_TRANSPOSE_REG(prefix_prev,row_ur),       \
                            1<<(level-1));                              \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ll) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur) :                   \
    HS_TRANSPOSE_REG(prefix_prev,row_ll);                               \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ur) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_REG(prefix_prev,row_ur)      :                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur);

// Store a fully-transposed register to its final row-major location.
#define HS_TRANSPOSE_REMAP(prefix,row_from,row_to)      \
  vout[gmem_idx + ((row_to-1) << HS_SLAB_WIDTH_LOG2)] = \
    HS_TRANSPOSE_REG(prefix,row_from);

//
//
//

#endif // HS_CL_MACROS_ONCE

//
//
//
388