• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
13 
14 #include <iterator>
15 
16 #include <boost/assert.hpp>
17 #include <boost/type_traits/is_signed.hpp>
18 #include <boost/type_traits/is_floating_point.hpp>
19 
20 #include <boost/mpl/and.hpp>
21 #include <boost/mpl/not.hpp>
22 
23 #include <boost/compute/kernel.hpp>
24 #include <boost/compute/program.hpp>
25 #include <boost/compute/command_queue.hpp>
26 #include <boost/compute/algorithm/exclusive_scan.hpp>
27 #include <boost/compute/container/vector.hpp>
28 #include <boost/compute/detail/iterator_range_size.hpp>
29 #include <boost/compute/detail/parameter_cache.hpp>
30 #include <boost/compute/type_traits/type_name.hpp>
31 #include <boost/compute/type_traits/is_fundamental.hpp>
32 #include <boost/compute/type_traits/is_vector_type.hpp>
33 #include <boost/compute/utility/program_cache.hpp>
34 
35 namespace boost {
36 namespace compute {
37 namespace detail {
38 
39 // meta-function returning true if type T is radix-sortable
40 template<class T>
41 struct is_radix_sortable :
42     boost::mpl::and_<
43         typename ::boost::compute::is_fundamental<T>::type,
44         typename boost::mpl::not_<typename is_vector_type<T>::type>::type
45     >
46 {
47 };
48 
49 template<size_t N>
50 struct radix_sort_value_type
51 {
52 };
53 
54 template<>
55 struct radix_sort_value_type<1>
56 {
57     typedef uchar_ type;
58 };
59 
60 template<>
61 struct radix_sort_value_type<2>
62 {
63     typedef ushort_ type;
64 };
65 
66 template<>
67 struct radix_sort_value_type<4>
68 {
69     typedef uint_ type;
70 };
71 
72 template<>
73 struct radix_sort_value_type<8>
74 {
75     typedef ulong_ type;
76 };
77 
78 template<typename T>
enable_double()79 inline const char* enable_double()
80 {
81     return " -DT2_double=0";
82 }
83 
84 template<>
enable_double()85 inline const char* enable_double<double>()
86 {
87     return " -DT2_double=1";
88 }
89 
90 const char radix_sort_source[] =
91 "#if T2_double\n"
92 "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
93 "#endif\n"
94 "#define K2_BITS (1 << K_BITS)\n"
95 "#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
96 "#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"
97 
98 "#if defined(ASC)\n" // asc order
99 
100 "inline uint radix(const T x, const uint low_bit)\n"
101 "{\n"
102 "#if defined(IS_FLOATING_POINT)\n"
103 "    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
104 "    return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
105 "#elif defined(IS_SIGNED)\n"
106 "    return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
107 "#else\n"
108 "    return (x >> low_bit) & RADIX_MASK;\n"
109 "#endif\n"
110 "}\n"
111 
112 "#else\n" // desc order
113 
114 // For signed types we just negate the x and for unsigned types we
115 // subtract the x from max value of its type ((T)(-1) is a max value
116 // of type T when T is an unsigned type).
117 "inline uint radix(const T x, const uint low_bit)\n"
118 "{\n"
119 "#if defined(IS_FLOATING_POINT)\n"
120 "    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
121 "    return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
122 "#elif defined(IS_SIGNED)\n"
123 "    return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
124 "#else\n"
125 "    return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
126 "#endif\n"
127 "}\n"
128 
129 "#endif\n" // #if defined(ASC)
130 
131 "__kernel void count(__global const T *input,\n"
132 "                    const uint input_offset,\n"
133 "                    const uint input_size,\n"
134 "                    __global uint *global_counts,\n"
135 "                    __global uint *global_offsets,\n"
136 "                    __local uint *local_counts,\n"
137 "                    const uint low_bit)\n"
138 "{\n"
139      // work-item parameters
140 "    const uint gid = get_global_id(0);\n"
141 "    const uint lid = get_local_id(0);\n"
142 
143      // zero local counts
144 "    if(lid < K2_BITS){\n"
145 "        local_counts[lid] = 0;\n"
146 "    }\n"
147 "    barrier(CLK_LOCAL_MEM_FENCE);\n"
148 
149      // reduce local counts
150 "    if(gid < input_size){\n"
151 "        T value = input[input_offset+gid];\n"
152 "        uint bucket = radix(value, low_bit);\n"
153 "        atomic_inc(local_counts + bucket);\n"
154 "    }\n"
155 "    barrier(CLK_LOCAL_MEM_FENCE);\n"
156 
157      // write block-relative offsets
158 "    if(lid < K2_BITS){\n"
159 "        global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"
160 
161          // write global offsets
162 "        if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
163 "            global_offsets[lid] = local_counts[lid];\n"
164 "        }\n"
165 "    }\n"
166 "}\n"
167 
168 "__kernel void scan(__global const uint *block_offsets,\n"
169 "                   __global uint *global_offsets,\n"
170 "                   const uint block_count)\n"
171 "{\n"
172 "    __global const uint *last_block_offsets =\n"
173 "        block_offsets + K2_BITS * (block_count - 1);\n"
174 
175      // calculate and scan global_offsets
176 "    uint sum = 0;\n"
177 "    for(uint i = 0; i < K2_BITS; i++){\n"
178 "        uint x = global_offsets[i] + last_block_offsets[i];\n"
179 "        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX 500/Vega bug, see #811
180 "        global_offsets[i] = sum;\n"
181 "        sum += x;\n"
182 "        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX Vega bug, see #811
183 "    }\n"
184 "}\n"
185 
186 "__kernel void scatter(__global const T *input,\n"
187 "                      const uint input_offset,\n"
188 "                      const uint input_size,\n"
189 "                      const uint low_bit,\n"
190 "                      __global const uint *counts,\n"
191 "                      __global const uint *global_offsets,\n"
192 "#ifndef SORT_BY_KEY\n"
193 "                      __global T *output,\n"
194 "                      const uint output_offset)\n"
195 "#else\n"
196 "                      __global T *keys_output,\n"
197 "                      const uint keys_output_offset,\n"
198 "                      __global T2 *values_input,\n"
199 "                      const uint values_input_offset,\n"
200 "                      __global T2 *values_output,\n"
201 "                      const uint values_output_offset)\n"
202 "#endif\n"
203 "{\n"
204      // work-item parameters
205 "    const uint gid = get_global_id(0);\n"
206 "    const uint lid = get_local_id(0);\n"
207 
208      // copy input to local memory
209 "    T value;\n"
210 "    uint bucket;\n"
211 "    __local uint local_input[BLOCK_SIZE];\n"
212 "    if(gid < input_size){\n"
213 "        value = input[input_offset+gid];\n"
214 "        bucket = radix(value, low_bit);\n"
215 "        local_input[lid] = bucket;\n"
216 "    }\n"
217 
218      // copy block counts to local memory
219 "    __local uint local_counts[(1 << K_BITS)];\n"
220 "    if(lid < K2_BITS){\n"
221 "        local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
222 "    }\n"
223 
224      // wait until local memory is ready
225 "    barrier(CLK_LOCAL_MEM_FENCE);\n"
226 
227 "    if(gid >= input_size){\n"
228 "        return;\n"
229 "    }\n"
230 
231      // get global offset
232 "    uint offset = global_offsets[bucket] + local_counts[bucket];\n"
233 
234      // calculate local offset
235 "    uint local_offset = 0;\n"
236 "    for(uint i = 0; i < lid; i++){\n"
237 "        if(local_input[i] == bucket)\n"
238 "            local_offset++;\n"
239 "    }\n"
240 
241 "#ifndef SORT_BY_KEY\n"
242      // write value to output
243 "    output[output_offset + offset + local_offset] = value;\n"
244 "#else\n"
245      // write key and value if doing sort_by_key
246 "    keys_output[keys_output_offset+offset + local_offset] = value;\n"
247 "    values_output[values_output_offset+offset + local_offset] =\n"
248 "        values_input[values_input_offset+gid];\n"
249 "#endif\n"
250 "}\n";
251 
252 template<class T, class T2>
radix_sort_impl(const buffer_iterator<T> first,const buffer_iterator<T> last,const buffer_iterator<T2> values_first,const bool ascending,command_queue & queue)253 inline void radix_sort_impl(const buffer_iterator<T> first,
254                             const buffer_iterator<T> last,
255                             const buffer_iterator<T2> values_first,
256                             const bool ascending,
257                             command_queue &queue)
258 {
259 
260     typedef T value_type;
261     typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;
262 
263     const device &device = queue.get_device();
264     const context &context = queue.get_context();
265 
266 
267     // if we have a valid values iterator then we are doing a
268     // sort by key and have to set up the values buffer
269     bool sort_by_key = (values_first.get_buffer().get() != 0);
270 
271     // load (or create) radix sort program
272     std::string cache_key =
273         std::string("__boost_radix_sort_") + type_name<value_type>();
274 
275     if(sort_by_key){
276         cache_key += std::string("_with_") + type_name<T2>();
277     }
278 
279     boost::shared_ptr<program_cache> cache =
280         program_cache::get_global_cache(context);
281     boost::shared_ptr<parameter_cache> parameters =
282         detail::parameter_cache::get_global_cache(device);
283 
284     // sort parameters
285     const uint_ k = parameters->get(cache_key, "k", 4);
286     const uint_ k2 = 1 << k;
287     const uint_ block_size = parameters->get(cache_key, "tpb", 128);
288 
289     // sort program compiler options
290     std::stringstream options;
291     options << "-DK_BITS=" << k;
292     options << " -DT=" << type_name<sort_type>();
293     options << " -DBLOCK_SIZE=" << block_size;
294 
295     if(boost::is_floating_point<value_type>::value){
296         options << " -DIS_FLOATING_POINT";
297     }
298 
299     if(boost::is_signed<value_type>::value){
300         options << " -DIS_SIGNED";
301     }
302 
303     if(sort_by_key){
304         options << " -DSORT_BY_KEY";
305         options << " -DT2=" << type_name<T2>();
306         options << enable_double<T2>();
307     }
308 
309     if(ascending){
310         options << " -DASC";
311     }
312 
313     // get type definition if it is a custom struct
314     std::string custom_type_def = boost::compute::type_definition<T2>() + "\n";
315 
316     // load radix sort program
317     program radix_sort_program = cache->get_or_build(
318        cache_key, options.str(), custom_type_def + radix_sort_source, context
319     );
320 
321     kernel count_kernel(radix_sort_program, "count");
322     kernel scan_kernel(radix_sort_program, "scan");
323     kernel scatter_kernel(radix_sort_program, "scatter");
324 
325     size_t count = detail::iterator_range_size(first, last);
326 
327     uint_ block_count = static_cast<uint_>(count / block_size);
328     if(block_count * block_size != count){
329         block_count++;
330     }
331 
332     // setup temporary buffers
333     vector<value_type> output(count, context);
334     vector<T2> values_output(sort_by_key ? count : 0, context);
335     vector<uint_> offsets(k2, context);
336     vector<uint_> counts(block_count * k2, context);
337 
338     const buffer *input_buffer = &first.get_buffer();
339     uint_ input_offset = static_cast<uint_>(first.get_index());
340     const buffer *output_buffer = &output.get_buffer();
341     uint_ output_offset = 0;
342     const buffer *values_input_buffer = &values_first.get_buffer();
343     uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
344     const buffer *values_output_buffer = &values_output.get_buffer();
345     uint_ values_output_offset = 0;
346 
347     for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
348         // write counts
349         count_kernel.set_arg(0, *input_buffer);
350         count_kernel.set_arg(1, input_offset);
351         count_kernel.set_arg(2, static_cast<uint_>(count));
352         count_kernel.set_arg(3, counts);
353         count_kernel.set_arg(4, offsets);
354         count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
355         count_kernel.set_arg(6, i * k);
356         queue.enqueue_1d_range_kernel(count_kernel,
357                                       0,
358                                       block_count * block_size,
359                                       block_size);
360 
361         // scan counts
362         if(k == 1){
363             typedef uint2_ counter_type;
364             ::boost::compute::exclusive_scan(
365                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
366                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
367                 make_buffer_iterator<counter_type>(counts.get_buffer()),
368                 queue
369             );
370         }
371         else if(k == 2){
372             typedef uint4_ counter_type;
373             ::boost::compute::exclusive_scan(
374                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
375                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
376                 make_buffer_iterator<counter_type>(counts.get_buffer()),
377                 queue
378             );
379         }
380         else if(k == 4){
381             typedef uint16_ counter_type;
382             ::boost::compute::exclusive_scan(
383                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
384                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
385                 make_buffer_iterator<counter_type>(counts.get_buffer()),
386                 queue
387             );
388         }
389         else {
390             BOOST_ASSERT(false && "unknown k");
391             break;
392         }
393 
394         // scan global offsets
395         scan_kernel.set_arg(0, counts);
396         scan_kernel.set_arg(1, offsets);
397         scan_kernel.set_arg(2, block_count);
398         queue.enqueue_task(scan_kernel);
399 
400         // scatter values
401         scatter_kernel.set_arg(0, *input_buffer);
402         scatter_kernel.set_arg(1, input_offset);
403         scatter_kernel.set_arg(2, static_cast<uint_>(count));
404         scatter_kernel.set_arg(3, i * k);
405         scatter_kernel.set_arg(4, counts);
406         scatter_kernel.set_arg(5, offsets);
407         scatter_kernel.set_arg(6, *output_buffer);
408         scatter_kernel.set_arg(7, output_offset);
409         if(sort_by_key){
410             scatter_kernel.set_arg(8, *values_input_buffer);
411             scatter_kernel.set_arg(9, values_input_offset);
412             scatter_kernel.set_arg(10, *values_output_buffer);
413             scatter_kernel.set_arg(11, values_output_offset);
414         }
415         queue.enqueue_1d_range_kernel(scatter_kernel,
416                                       0,
417                                       block_count * block_size,
418                                       block_size);
419 
420         // swap buffers
421         std::swap(input_buffer, output_buffer);
422         std::swap(values_input_buffer, values_output_buffer);
423         std::swap(input_offset, output_offset);
424         std::swap(values_input_offset, values_output_offset);
425     }
426 }
427 
428 template<class Iterator>
radix_sort(Iterator first,Iterator last,command_queue & queue)429 inline void radix_sort(Iterator first,
430                        Iterator last,
431                        command_queue &queue)
432 {
433     radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
434 }
435 
436 template<class KeyIterator, class ValueIterator>
radix_sort_by_key(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,command_queue & queue)437 inline void radix_sort_by_key(KeyIterator keys_first,
438                               KeyIterator keys_last,
439                               ValueIterator values_first,
440                               command_queue &queue)
441 {
442     radix_sort_impl(keys_first, keys_last, values_first, true, queue);
443 }
444 
445 template<class Iterator>
radix_sort(Iterator first,Iterator last,const bool ascending,command_queue & queue)446 inline void radix_sort(Iterator first,
447                        Iterator last,
448                        const bool ascending,
449                        command_queue &queue)
450 {
451     radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
452 }
453 
454 template<class KeyIterator, class ValueIterator>
radix_sort_by_key(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,const bool ascending,command_queue & queue)455 inline void radix_sort_by_key(KeyIterator keys_first,
456                               KeyIterator keys_last,
457                               ValueIterator values_first,
458                               const bool ascending,
459                               command_queue &queue)
460 {
461     radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
462 }
463 
464 
465 } // end detail namespace
466 } // end compute namespace
467 } // end boost namespace
468 
469 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
470