//---------------------------------------------------------------------------//
// Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
13 
14 #include <algorithm>
15 
16 #include <boost/compute/kernel.hpp>
17 #include <boost/compute/program.hpp>
18 #include <boost/compute/command_queue.hpp>
19 #include <boost/compute/container/vector.hpp>
20 #include <boost/compute/memory/local_buffer.hpp>
21 #include <boost/compute/detail/meta_kernel.hpp>
22 #include <boost/compute/detail/iterator_range_size.hpp>
23 
24 namespace boost {
25 namespace compute {
26 namespace detail {
27 
/// Picks the block size (= work-group size) used by the bitonic block sort.
///
/// Starting from \p proposed_wg the candidate size is halved until the local
/// memory needed by one block fits into \p lmem_size. While the candidate is
/// larger than 64 elements it is additionally halved until four blocks would
/// fit at once (better occupancy). The result is rounded down to a power of
/// two and is never smaller than 1 nor larger than 256.
///
/// \tparam KeyType type of a key stored in local memory
/// \tparam ValueType type stored in local memory per element next to the key
///         when \p sort_by_key is true
/// \param proposed_wg upper limit for the block size
/// \param lmem_size size of the available local memory in bytes
/// \param sort_by_key if true, local memory for both the key and the
///        per-element value is accounted for; otherwise only for the key
/// \return block size: a power of two in range [1; 256]
template<class KeyType, class ValueType>
inline size_t pick_bitonic_block_sort_block_size(size_t proposed_wg,
                                                 size_t lmem_size,
                                                 bool sort_by_key)
{
    // Local memory consumed by a single element of a block. This must be
    // used every time the requirement is recomputed; counting only the key
    // after halving n would underestimate the need when sorting by key.
    const size_t bytes_per_element =
        sizeof(KeyType) + (sort_by_key ? sizeof(ValueType) : size_t(0));

    size_t n = proposed_wg;
    size_t lmem_required = n * bytes_per_element;

    // try to force at least 4 work-groups of >64 elements
    // for better occupancy
    while(lmem_size < (lmem_required * 4) && (n > 64)) {
        n /= 2;
        lmem_required = n * bytes_per_element;
    }
    // a single block has to fit in local memory
    while(lmem_size < lmem_required && (n > 1)) {
        n /= 2;
        lmem_required = n * bytes_per_element;
    }

    // round down to a power of two, clamped to [1; 256]
    size_t block_size = 1;
    while(block_size < 256 && (block_size * 2) <= n) {
        block_size *= 2;
    }
    return block_size;
}
62 
63 
64 /// Performs bitonic block sort according to \p compare.
65 ///
66 /// Since bitonic sort can be only performed when input size is equal to 2^n,
67 /// in this case input size is block size (\p work_group_size), we would have
68 /// to require \p count be a exact multiple of block size. That would not be
69 /// great.
70 /// Instead, bitonic sort kernel is merged with odd-even merge sort so if the
71 /// last block is not equal to 2^n (where n is some natural number) the odd-even
72 /// sort is performed for that block. That way bitonic_block_sort() works for
73 /// input of any size. Block size (\p work_group_size) still have to be equal
74 /// to 2^n.
75 ///
76 /// This is NOT stable.
77 ///
78 /// \param keys_first first key element in the range to sort
79 /// \param values_first first value element in the range to sort
80 /// \param compare comparison function for keys
81 /// \param count number of elements in the range; count > 0
82 /// \param work_group_size size of the work group, also the block size; must be
83 ///        equal to n^2 where n is natural number
84 /// \param queue command queue to perform the operation
85 template<class KeyIterator, class ValueIterator, class Compare>
bitonic_block_sort(KeyIterator keys_first,ValueIterator values_first,Compare compare,const size_t count,const bool sort_by_key,command_queue & queue)86 inline size_t bitonic_block_sort(KeyIterator keys_first,
87                                  ValueIterator values_first,
88                                  Compare compare,
89                                  const size_t count,
90                                  const bool sort_by_key,
91                                  command_queue &queue)
92 {
93     typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
94     typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
95 
96     meta_kernel k("bitonic_block_sort");
97     size_t count_arg = k.add_arg<const uint_>("count");
98 
99     size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "lkeys");
100     size_t local_vals_arg = 0;
101     if(sort_by_key) {
102         local_vals_arg = k.add_arg<uchar_ *>(memory_object::local_memory, "lidx");
103     }
104 
105     k <<
106         // Work item global and local ids
107         k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
108         k.decl<const uint_>("lid") << " = get_local_id(0);\n";
109 
110     // declare my_key and my_value
111     k <<
112         k.decl<key_type>("my_key") << ";\n";
113     // Instead of copying values (my_value) in local memory with keys
114     // we save local index (uchar) and copy my_value at the end at
115     // final index. This saves local memory.
116     if(sort_by_key)
117     {
118         k <<
119             k.decl<uchar_>("my_index") << " = (uchar)(lid);\n";
120     }
121 
122     // load key
123     k <<
124         "if(gid < count) {\n" <<
125             k.var<key_type>("my_key") <<  " = " <<
126                 keys_first[k.var<const uint_>("gid")] << ";\n" <<
127         "}\n";
128 
129     // load key and index to local memory
130     k <<
131         "lkeys[lid] = my_key;\n";
132     if(sort_by_key)
133     {
134         k <<
135             "lidx[lid] = my_index;\n";
136     }
137     k <<
138         k.decl<const uint_>("offset") << " = get_group_id(0) * get_local_size(0);\n" <<
139         k.decl<const uint_>("n") << " = min((uint)(get_local_size(0)),(count - offset));\n";
140 
141     // When work group size is a power of 2 bitonic sorter can be used;
142     // otherwise, slower odd-even sort is used.
143 
144     k <<
145         // check if n is power of 2
146         "if(((n != 0) && ((n & (~n + 1)) == n))) {\n";
147 
148     // bitonic sort, not stable
149     k <<
150         // wait for keys and vals to be stored in local memory
151         "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
152 
153         "#pragma unroll\n" <<
154         "for(" <<
155             k.decl<uint_>("length") << " = 1; " <<
156             "length < n; " <<
157             "length <<= 1" <<
158         ") {\n" <<
159             // direction of sort: false -> asc, true -> desc
160             k.decl<bool>("direction") << "= ((lid & (length<<1)) != 0);\n" <<
161             "for(" <<
162                 k.decl<uint_>("k") << " = length; " <<
163                 "k > 0; " <<
164                 "k >>= 1" <<
165             ") {\n" <<
166 
167             // sibling to compare with my key
168             k.decl<uint_>("sibling_idx") << " = lid ^ k;\n" <<
169             k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
170             k.decl<bool>("compare") << " = " <<
171                 compare(k.var<key_type>("sibling_key"),
172                         k.var<key_type>("my_key")) << ";\n" <<
173             k.decl<bool>("equal") << " = !(compare || " <<
174                 compare(k.var<key_type>("my_key"),
175                         k.var<key_type>("sibling_key")) << ");\n" <<
176             k.decl<bool>("swap") <<
177                 " = compare ^ (sibling_idx < lid) ^ direction;\n" <<
178             "swap = equal ? false : swap;\n" <<
179             "my_key = swap ? sibling_key : my_key;\n";
180     if(sort_by_key)
181     {
182         k <<
183             "my_index = swap ? lidx[sibling_idx] : my_index;\n";
184     }
185     k <<
186             "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
187             "lkeys[lid] = my_key;\n";
188     if(sort_by_key)
189     {
190         k <<
191             "lidx[lid] = my_index;\n";
192     }
193     k <<
194             "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
195             "}\n" <<
196          "}\n";
197 
198     // end of bitonic sort
199 
200     // odd-even sort, not stable
201     k <<
202         "}\n" <<
203         "else { \n";
204 
205     k <<
206         k.decl<bool>("lid_is_even") << " = (lid%2) == 0;\n" <<
207         k.decl<uint_>("oddsibling_idx") << " = " <<
208             "(lid_is_even) ? max(lid,(uint)(1)) - 1 : min(lid+1,n-1);\n" <<
209         k.decl<uint_>("evensibling_idx") << " = " <<
210             "(lid_is_even) ? min(lid+1,n-1) : max(lid,(uint)(1)) - 1;\n" <<
211 
212         // wait for keys and vals to be stored in local memory
213         "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
214 
215         "#pragma unroll\n" <<
216         "for(" <<
217             k.decl<uint_>("i") << " = 0; " <<
218             "i < n; " <<
219             "i++" <<
220         ") {\n" <<
221             k.decl<uint_>("sibling_idx") <<
222                 " = i%2 == 0 ? evensibling_idx : oddsibling_idx;\n" <<
223             k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
224             k.decl<bool>("compare") << " = " <<
225                 compare(k.var<key_type>("sibling_key"),
226                         k.var<key_type>("my_key")) << ";\n" <<
227             k.decl<bool>("equal") << " = !(compare || " <<
228                 compare(k.var<key_type>("my_key"),
229                         k.var<key_type>("sibling_key")) << ");\n" <<
230             k.decl<bool>("swap") <<
231                 " = compare ^ (sibling_idx < lid);\n" <<
232             "swap = equal ? false : swap;\n" <<
233             "my_key = swap ? sibling_key : my_key;\n";
234     if(sort_by_key)
235     {
236         k <<
237             "my_index = swap ? lidx[sibling_idx] : my_index;\n";
238     }
239     k <<
240             "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
241             "lkeys[lid] = my_key;\n";
242     if(sort_by_key)
243     {
244         k <<
245             "lidx[lid] = my_index;\n";
246     }
247     k <<
248             "barrier(CLK_LOCAL_MEM_FENCE);\n"
249         "}\n" <<  // for
250 
251     "}\n"; // else
252     // end of odd-even sort
253 
254     // save key and value
255     k <<
256         "if(gid < count) {\n" <<
257         keys_first[k.var<const uint_>("gid")] << " = " <<
258             k.var<key_type>("my_key") << ";\n";
259     if(sort_by_key)
260     {
261         k <<
262             k.decl<value_type>("my_value") << " = " <<
263                 values_first[k.var<const uint_>("offset + my_index")] << ";\n" <<
264             "barrier(CLK_GLOBAL_MEM_FENCE);\n" <<
265             values_first[k.var<const uint_>("gid")] << " = my_value;\n";
266     }
267     k <<
268         // end if
269         "}\n";
270 
271     const context &context = queue.get_context();
272     const device &device = queue.get_device();
273     ::boost::compute::kernel kernel = k.compile(context);
274 
275     const size_t work_group_size =
276         pick_bitonic_block_sort_block_size<key_type, uchar_>(
277             kernel.get_work_group_info<size_t>(
278                 device, CL_KERNEL_WORK_GROUP_SIZE
279             ),
280             device.get_info<size_t>(CL_DEVICE_LOCAL_MEM_SIZE),
281             sort_by_key
282         );
283 
284     const size_t global_size =
285         work_group_size * static_cast<size_t>(
286             std::ceil(float(count) / work_group_size)
287         );
288 
289     kernel.set_arg(count_arg, static_cast<uint_>(count));
290     kernel.set_arg(local_keys_arg, local_buffer<key_type>(work_group_size));
291     if(sort_by_key) {
292         kernel.set_arg(local_vals_arg, local_buffer<uchar_>(work_group_size));
293     }
294 
295     queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
296     // return size of the block
297     return work_group_size;
298 }
299 
300 template<class KeyIterator, class ValueIterator, class Compare>
block_sort(KeyIterator keys_first,ValueIterator values_first,Compare compare,const size_t count,const bool sort_by_key,const bool stable,command_queue & queue)301 inline size_t block_sort(KeyIterator keys_first,
302                          ValueIterator values_first,
303                          Compare compare,
304                          const size_t count,
305                          const bool sort_by_key,
306                          const bool stable,
307                          command_queue &queue)
308 {
309     if(stable) {
310         // TODO: Implement stable block sort (stable odd-even merge sort)
311         return size_t(1);
312     }
313     return bitonic_block_sort(
314         keys_first, values_first,
315         compare, count,
316         sort_by_key, queue
317     );
318 }
319 
320 /// space: O(n + m); n - number of keys, m - number of values
321 template<class KeyIterator, class ValueIterator, class Compare>
merge_blocks_on_gpu(KeyIterator keys_first,ValueIterator values_first,KeyIterator out_keys_first,ValueIterator out_values_first,Compare compare,const size_t count,const size_t block_size,const bool sort_by_key,command_queue & queue)322 inline void merge_blocks_on_gpu(KeyIterator keys_first,
323                                 ValueIterator values_first,
324                                 KeyIterator out_keys_first,
325                                 ValueIterator out_values_first,
326                                 Compare compare,
327                                 const size_t count,
328                                 const size_t block_size,
329                                 const bool sort_by_key,
330                                 command_queue &queue)
331 {
332     typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
333     typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
334 
335     meta_kernel k("merge_blocks");
336     size_t count_arg = k.add_arg<const uint_>("count");
337     size_t block_size_arg = k.add_arg<const uint_>("block_size");
338 
339     k <<
340         // get global id
341         k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
342         "if(gid >= count) {\n" <<
343             "return;\n" <<
344         "}\n" <<
345 
346         k.decl<const key_type>("my_key") << " = " <<
347             keys_first[k.var<const uint_>("gid")] << ";\n";
348 
349     if(sort_by_key) {
350         k <<
351             k.decl<const value_type>("my_value") << " = " <<
352                 values_first[k.var<const uint_>("gid")] << ";\n";
353     }
354 
355     k <<
356         // get my block idx
357         k.decl<const uint_>("my_block_idx") << " = gid / block_size;\n" <<
358         k.decl<const bool>("my_block_idx_is_odd") << " = " <<
359             "my_block_idx & 0x1;\n" <<
360 
361         k.decl<const uint_>("other_block_idx") << " = " <<
362             // if(my_block_idx is odd) {} else {}
363             "my_block_idx_is_odd ? my_block_idx - 1 : my_block_idx + 1;\n" <<
364 
365         // get ranges of my block and the other block
366         // [my_block_start; my_block_end)
367         // [other_block_start; other_block_end)
368         k.decl<const uint_>("my_block_start") << " = " <<
369             "min(my_block_idx * block_size, count);\n" << // including
370         k.decl<const uint_>("my_block_end") << " = " <<
371             "min((my_block_idx + 1) * block_size, count);\n" << // excluding
372 
373         k.decl<const uint_>("other_block_start") << " = " <<
374             "min(other_block_idx * block_size, count);\n" << // including
375         k.decl<const uint_>("other_block_end") << " = " <<
376             "min((other_block_idx + 1) * block_size, count);\n" << // excluding
377 
378         // other block is empty, nothing to merge here
379         "if(other_block_start == count){\n" <<
380             out_keys_first[k.var<uint_>("gid")] << " = my_key;\n";
381         if(sort_by_key) {
382             k <<
383                 out_values_first[k.var<uint_>("gid")] << " = my_value;\n";
384         }
385 
386         k <<
387         "return;\n" <<
388         "}\n" <<
389 
390         // lower bound
391         // left_idx - lower bound
392         k.decl<uint_>("left_idx") << " = other_block_start;\n" <<
393         k.decl<uint_>("right_idx") << " = other_block_end;\n" <<
394         "while(left_idx < right_idx) {\n" <<
395             k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
396             k.decl<key_type>("mid_key") << " = " <<
397                     keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
398             k.decl<bool>("smaller") << " = " <<
399                 compare(k.var<key_type>("mid_key"),
400                         k.var<key_type>("my_key")) << ";\n" <<
401             "left_idx = smaller ? mid_idx + 1 : left_idx;\n" <<
402             "right_idx = smaller ? right_idx :  mid_idx;\n" <<
403         "}\n" <<
404         // left_idx is found position in other block
405 
406         // if my_block is odd we need to get the upper bound
407         "right_idx = other_block_end;\n" <<
408         "if(my_block_idx_is_odd && left_idx != right_idx) {\n" <<
409             k.decl<key_type>("upper_key") << " = " <<
410                 keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
411             "while(" <<
412                 "!(" << compare(k.var<key_type>("upper_key"),
413                                 k.var<key_type>("my_key")) <<
414                 ") && " <<
415                 "!(" << compare(k.var<key_type>("my_key"),
416                                 k.var<key_type>("upper_key")) <<
417                 ") && " <<
418                      "left_idx < right_idx" <<
419                 ")" <<
420             "{\n" <<
421                 k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
422                 k.decl<key_type>("mid_key") << " = " <<
423                     keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
424                 k.decl<bool>("equal") << " = " <<
425                     "!(" << compare(k.var<key_type>("mid_key"),
426                                     k.var<key_type>("my_key")) <<
427                     ") && " <<
428                     "!(" << compare(k.var<key_type>("my_key"),
429                                     k.var<key_type>("mid_key")) <<
430                     ");\n" <<
431                 "left_idx = equal ? mid_idx + 1 : left_idx + 1;\n" <<
432                 "right_idx = equal ? right_idx : mid_idx;\n" <<
433                 "upper_key = " <<
434                     keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
435             "}\n" <<
436         "}\n" <<
437 
438         k.decl<uint_>("offset") << " = 0;\n" <<
439         "offset += gid - my_block_start;\n" <<
440         "offset += left_idx - other_block_start;\n" <<
441         "offset += min(my_block_start, other_block_start);\n" <<
442         out_keys_first[k.var<uint_>("offset")] << " = my_key;\n";
443     if(sort_by_key) {
444         k <<
445             out_values_first[k.var<uint_>("offset")] << " = my_value;\n";
446     }
447 
448     const context &context = queue.get_context();
449     ::boost::compute::kernel kernel = k.compile(context);
450 
451     const size_t work_group_size = (std::min)(
452         size_t(256),
453         kernel.get_work_group_info<size_t>(
454             queue.get_device(), CL_KERNEL_WORK_GROUP_SIZE
455         )
456     );
457     const size_t global_size =
458         work_group_size * static_cast<size_t>(
459             std::ceil(float(count) / work_group_size)
460         );
461 
462     kernel.set_arg(count_arg, static_cast<uint_>(count));
463     kernel.set_arg(block_size_arg, static_cast<uint_>(block_size));
464     queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
465 }
466 
467 template<class KeyIterator, class ValueIterator, class Compare>
merge_sort_by_key_on_gpu(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,Compare compare,bool stable,command_queue & queue)468 inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
469                                      KeyIterator keys_last,
470                                      ValueIterator values_first,
471                                      Compare compare,
472                                      bool stable,
473                                      command_queue &queue)
474 {
475     typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
476     typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
477 
478     size_t count = iterator_range_size(keys_first, keys_last);
479     if(count < 2){
480         return;
481     }
482 
483     size_t block_size =
484         block_sort(
485             keys_first, values_first,
486             compare, count,
487             true /* sort_by_key */, stable /* stable */,
488             queue
489         );
490 
491     // for small input size only block sort is performed
492     if(count <= block_size) {
493         return;
494     }
495 
496     const context &context = queue.get_context();
497 
498     bool result_in_temporary_buffer = false;
499     ::boost::compute::vector<key_type> temp_keys(count, context);
500     ::boost::compute::vector<value_type> temp_values(count, context);
501 
502     for(; block_size < count; block_size *= 2) {
503         result_in_temporary_buffer = !result_in_temporary_buffer;
504         if(result_in_temporary_buffer) {
505             merge_blocks_on_gpu(keys_first, values_first,
506                                 temp_keys.begin(), temp_values.begin(),
507                                 compare, count, block_size,
508                                 true /* sort_by_key */, queue);
509         } else {
510             merge_blocks_on_gpu(temp_keys.begin(), temp_values.begin(),
511                                 keys_first, values_first,
512                                 compare, count, block_size,
513                                 true /* sort_by_key */, queue);
514         }
515     }
516 
517     if(result_in_temporary_buffer) {
518         copy_async(temp_keys.begin(), temp_keys.end(), keys_first, queue);
519         copy_async(temp_values.begin(), temp_values.end(), values_first, queue);
520     }
521 }
522 
523 template<class Iterator, class Compare>
merge_sort_on_gpu(Iterator first,Iterator last,Compare compare,bool stable,command_queue & queue)524 inline void merge_sort_on_gpu(Iterator first,
525                               Iterator last,
526                               Compare compare,
527                               bool stable,
528                               command_queue &queue)
529 {
530     typedef typename std::iterator_traits<Iterator>::value_type key_type;
531 
532     size_t count = iterator_range_size(first, last);
533     if(count < 2){
534         return;
535     }
536 
537     Iterator dummy;
538     size_t block_size =
539         block_sort(
540             first, dummy,
541             compare, count,
542             false /* sort_by_key */, stable /* stable */,
543             queue
544         );
545 
546     // for small input size only block sort is performed
547     if(count <= block_size) {
548         return;
549     }
550 
551     const context &context = queue.get_context();
552 
553     bool result_in_temporary_buffer = false;
554     ::boost::compute::vector<key_type> temp_keys(count, context);
555 
556     for(; block_size < count; block_size *= 2) {
557         result_in_temporary_buffer = !result_in_temporary_buffer;
558         if(result_in_temporary_buffer) {
559             merge_blocks_on_gpu(first, dummy, temp_keys.begin(), dummy,
560                                 compare, count, block_size,
561                                 false /* sort_by_key */, queue);
562         } else {
563             merge_blocks_on_gpu(temp_keys.begin(), dummy, first, dummy,
564                                 compare, count, block_size,
565                                 false /* sort_by_key */, queue);
566         }
567     }
568 
569     if(result_in_temporary_buffer) {
570         copy_async(temp_keys.begin(), temp_keys.end(), first, queue);
571     }
572 }
573 
574 template<class KeyIterator, class ValueIterator, class Compare>
merge_sort_by_key_on_gpu(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,Compare compare,command_queue & queue)575 inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
576                                      KeyIterator keys_last,
577                                      ValueIterator values_first,
578                                      Compare compare,
579                                      command_queue &queue)
580 {
581     merge_sort_by_key_on_gpu(
582         keys_first, keys_last, values_first,
583         compare, false /* not stable */, queue
584     );
585 }
586 
587 template<class Iterator, class Compare>
merge_sort_on_gpu(Iterator first,Iterator last,Compare compare,command_queue & queue)588 inline void merge_sort_on_gpu(Iterator first,
589                               Iterator last,
590                               Compare compare,
591                               command_queue &queue)
592 {
593     merge_sort_on_gpu(
594         first, last, compare, false /* not stable */, queue
595     );
596 }
597 
598 } // end detail namespace
599 } // end compute namespace
600 } // end boost namespace
601 
602 #endif /* BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ */
603