• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
13 
#include <cstring>
#include <iterator>

#include <boost/compute/kernel.hpp>
#include <boost/compute/detail/meta_kernel.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/memory/local_buffer.hpp>
#include <boost/compute/iterator/buffer_iterator.hpp>
21 
22 namespace boost {
23 namespace compute {
24 namespace detail {
25 
26 template<class InputIterator, class OutputIterator, class BinaryOperator>
27 class local_scan_kernel : public meta_kernel
28 {
29 public:
local_scan_kernel(InputIterator first,InputIterator last,OutputIterator result,bool exclusive,BinaryOperator op)30     local_scan_kernel(InputIterator first,
31                       InputIterator last,
32                       OutputIterator result,
33                       bool exclusive,
34                       BinaryOperator op)
35         : meta_kernel("local_scan")
36     {
37         typedef typename std::iterator_traits<InputIterator>::value_type T;
38 
39         (void) last;
40 
41         bool checked = true;
42 
43         m_block_sums_arg = add_arg<T *>(memory_object::global_memory, "block_sums");
44         m_scratch_arg = add_arg<T *>(memory_object::local_memory, "scratch");
45         m_block_size_arg = add_arg<const cl_uint>("block_size");
46         m_count_arg = add_arg<const cl_uint>("count");
47         m_init_value_arg = add_arg<const T>("init");
48 
49         // work-item parameters
50         *this <<
51             "const uint gid = get_global_id(0);\n" <<
52             "const uint lid = get_local_id(0);\n";
53 
54         // check against data size
55         if(checked){
56             *this <<
57                 "if(gid < count){\n";
58         }
59 
60         // copy values from input to local memory
61         if(exclusive){
62             *this <<
63                 decl<const T>("local_init") << "= (gid == 0) ? init : 0;\n" <<
64                 "if(lid == 0){ scratch[lid] = local_init; }\n" <<
65                 "else { scratch[lid] = " << first[expr<cl_uint>("gid-1")] << "; }\n";
66         }
67         else{
68             *this <<
69                 "scratch[lid] = " << first[expr<cl_uint>("gid")] << ";\n";
70         }
71 
72         if(checked){
73             *this <<
74                 "}\n"
75                 "else {\n" <<
76                 "    scratch[lid] = 0;\n" <<
77                 "}\n";
78         }
79 
80         // wait for all threads to read from input
81         *this <<
82             "barrier(CLK_LOCAL_MEM_FENCE);\n";
83 
84         // perform scan
85         *this <<
86             "for(uint i = 1; i < block_size; i <<= 1){\n" <<
87             "    " << decl<const T>("x") << " = lid >= i ? scratch[lid-i] : 0;\n" <<
88             "    barrier(CLK_LOCAL_MEM_FENCE);\n" <<
89             "    if(lid >= i){\n" <<
90             "        scratch[lid] = " << op(var<T>("scratch[lid]"), var<T>("x")) << ";\n" <<
91             "    }\n" <<
92             "    barrier(CLK_LOCAL_MEM_FENCE);\n" <<
93             "}\n";
94 
95         // copy results to output
96         if(checked){
97             *this <<
98                 "if(gid < count){\n";
99         }
100 
101         *this <<
102             result[expr<cl_uint>("gid")] << " = scratch[lid];\n";
103 
104         if(checked){
105             *this << "}\n";
106         }
107 
108         // store sum for the block
109         if(exclusive){
110             *this <<
111                 "if(lid == block_size - 1 && gid < count) {\n" <<
112                 "    block_sums[get_group_id(0)] = " <<
113                        op(first[expr<cl_uint>("gid")], var<T>("scratch[lid]")) <<
114                        ";\n" <<
115                 "}\n";
116         }
117         else {
118             *this <<
119                 "if(lid == block_size - 1){\n" <<
120                 "    block_sums[get_group_id(0)] = scratch[lid];\n" <<
121                 "}\n";
122         }
123     }
124 
125     size_t m_block_sums_arg;
126     size_t m_scratch_arg;
127     size_t m_block_size_arg;
128     size_t m_count_arg;
129     size_t m_init_value_arg;
130 };
131 
132 template<class T, class BinaryOperator>
133 class write_scanned_output_kernel : public meta_kernel
134 {
135 public:
write_scanned_output_kernel(BinaryOperator op)136     write_scanned_output_kernel(BinaryOperator op)
137         : meta_kernel("write_scanned_output")
138     {
139         bool checked = true;
140 
141         m_output_arg = add_arg<T *>(memory_object::global_memory, "output");
142         m_block_sums_arg = add_arg<const T *>(memory_object::global_memory, "block_sums");
143         m_count_arg = add_arg<const cl_uint>("count");
144 
145         // work-item parameters
146         *this <<
147             "const uint gid = get_global_id(0);\n" <<
148             "const uint block_id = get_group_id(0);\n";
149 
150         // check against data size
151         if(checked){
152             *this << "if(gid < count){\n";
153         }
154 
155         // write output
156         *this <<
157             "output[gid] = " <<
158                 op(var<T>("block_sums[block_id]"), var<T>("output[gid] ")) << ";\n";
159 
160         if(checked){
161             *this << "}\n";
162         }
163     }
164 
165     size_t m_output_arg;
166     size_t m_block_sums_arg;
167     size_t m_count_arg;
168 };
169 
// Chooses the work-group size for scanning [first, last): the smallest power
// of two that is >= the element count, capped at 256. An empty range yields 0.
template<class InputIterator>
inline size_t pick_scan_block_size(InputIterator first, InputIterator last)
{
    const size_t count = iterator_range_size(first, last);
    if(count == 0){
        return 0;
    }

    // double until the block covers the whole range or hits the cap
    const size_t max_block_size = 256;
    size_t block_size = 1;
    while(block_size < count && block_size < max_block_size){
        block_size <<= 1;
    }
    return block_size;
}
186 
187 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
scan_impl(InputIterator first,InputIterator last,OutputIterator result,bool exclusive,T init,BinaryOperator op,command_queue & queue)188 inline OutputIterator scan_impl(InputIterator first,
189                                 InputIterator last,
190                                 OutputIterator result,
191                                 bool exclusive,
192                                 T init,
193                                 BinaryOperator op,
194                                 command_queue &queue)
195 {
196     typedef typename
197         std::iterator_traits<InputIterator>::value_type
198         input_type;
199     typedef typename
200         std::iterator_traits<InputIterator>::difference_type
201         difference_type;
202     typedef typename
203         std::iterator_traits<OutputIterator>::value_type
204         output_type;
205 
206     const context &context = queue.get_context();
207     const size_t count = detail::iterator_range_size(first, last);
208 
209     size_t block_size = pick_scan_block_size(first, last);
210     size_t block_count = count / block_size;
211 
212     if(block_count * block_size < count){
213         block_count++;
214     }
215 
216     ::boost::compute::vector<input_type> block_sums(block_count, context);
217 
218     // zero block sums
219     input_type zero;
220     std::memset(&zero, 0, sizeof(input_type));
221     ::boost::compute::fill(block_sums.begin(), block_sums.end(), zero, queue);
222 
223     // local scan
224     local_scan_kernel<InputIterator, OutputIterator, BinaryOperator>
225         local_scan_kernel(first, last, result, exclusive, op);
226 
227     ::boost::compute::kernel kernel = local_scan_kernel.compile(context);
228     kernel.set_arg(local_scan_kernel.m_scratch_arg, local_buffer<input_type>(block_size));
229     kernel.set_arg(local_scan_kernel.m_block_sums_arg, block_sums);
230     kernel.set_arg(local_scan_kernel.m_block_size_arg, static_cast<cl_uint>(block_size));
231     kernel.set_arg(local_scan_kernel.m_count_arg, static_cast<cl_uint>(count));
232     kernel.set_arg(local_scan_kernel.m_init_value_arg, static_cast<output_type>(init));
233 
234     queue.enqueue_1d_range_kernel(kernel,
235                                   0,
236                                   block_count * block_size,
237                                   block_size);
238 
239     // inclusive scan block sums
240     if(block_count > 1){
241         scan_impl(block_sums.begin(),
242                   block_sums.end(),
243                   block_sums.begin(),
244                   false,
245                   init,
246                   op,
247                   queue
248         );
249     }
250 
251     // add block sums to each block
252     if(block_count > 1){
253         write_scanned_output_kernel<input_type, BinaryOperator>
254             write_output_kernel(op);
255         kernel = write_output_kernel.compile(context);
256         kernel.set_arg(write_output_kernel.m_output_arg, result.get_buffer());
257         kernel.set_arg(write_output_kernel.m_block_sums_arg, block_sums);
258         kernel.set_arg(write_output_kernel.m_count_arg, static_cast<cl_uint>(count));
259 
260         queue.enqueue_1d_range_kernel(kernel,
261                                       block_size,
262                                       block_count * block_size,
263                                       block_size);
264     }
265 
266     return result + static_cast<difference_type>(count);
267 }
268 
269 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
dispatch_scan(InputIterator first,InputIterator last,OutputIterator result,bool exclusive,T init,BinaryOperator op,command_queue & queue)270 inline OutputIterator dispatch_scan(InputIterator first,
271                                     InputIterator last,
272                                     OutputIterator result,
273                                     bool exclusive,
274                                     T init,
275                                     BinaryOperator op,
276                                     command_queue &queue)
277 {
278     return scan_impl(first, last, result, exclusive, init, op, queue);
279 }
280 
281 template<class InputIterator, class T, class BinaryOperator>
dispatch_scan(InputIterator first,InputIterator last,InputIterator result,bool exclusive,T init,BinaryOperator op,command_queue & queue)282 inline InputIterator dispatch_scan(InputIterator first,
283                                    InputIterator last,
284                                    InputIterator result,
285                                    bool exclusive,
286                                    T init,
287                                    BinaryOperator op,
288                                    command_queue &queue)
289 {
290     typedef typename std::iterator_traits<InputIterator>::value_type value_type;
291 
292     if(first == result){
293         // scan input in-place
294         const context &context = queue.get_context();
295 
296         // make a temporary copy the input
297         size_t count = iterator_range_size(first, last);
298         vector<value_type> tmp(count, context);
299         copy(first, last, tmp.begin(), queue);
300 
301         // scan from temporary values
302         return scan_impl(tmp.begin(), tmp.end(), first, exclusive, init, op, queue);
303     }
304     else {
305         // scan input to output
306         return scan_impl(first, last, result, exclusive, init, op, queue);
307     }
308 }
309 
310 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
scan_on_gpu(InputIterator first,InputIterator last,OutputIterator result,bool exclusive,T init,BinaryOperator op,command_queue & queue)311 inline OutputIterator scan_on_gpu(InputIterator first,
312                                   InputIterator last,
313                                   OutputIterator result,
314                                   bool exclusive,
315                                   T init,
316                                   BinaryOperator op,
317                                   command_queue &queue)
318 {
319     if(first == last){
320         return result;
321     }
322 
323     return dispatch_scan(first, last, result, exclusive, init, op, queue);
324 }
325 
326 } // end detail namespace
327 } // end compute namespace
328 } // end boost namespace
329 
330 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
331