1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
13
14 #include <iterator>
15
16 #include <boost/compute/device.hpp>
17 #include <boost/compute/kernel.hpp>
18 #include <boost/compute/command_queue.hpp>
19 #include <boost/compute/algorithm/detail/serial_scan.hpp>
20 #include <boost/compute/detail/meta_kernel.hpp>
21 #include <boost/compute/detail/iterator_range_size.hpp>
22 #include <boost/compute/detail/parameter_cache.hpp>
23
24 namespace boost {
25 namespace compute {
26 namespace detail {
27
28 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
scan_on_cpu(InputIterator first,InputIterator last,OutputIterator result,bool exclusive,T init,BinaryOperator op,command_queue & queue)29 inline OutputIterator scan_on_cpu(InputIterator first,
30 InputIterator last,
31 OutputIterator result,
32 bool exclusive,
33 T init,
34 BinaryOperator op,
35 command_queue &queue)
36 {
37 typedef typename
38 std::iterator_traits<InputIterator>::value_type input_type;
39 typedef typename
40 std::iterator_traits<OutputIterator>::value_type output_type;
41
42 const context &context = queue.get_context();
43 const device &device = queue.get_device();
44 const size_t compute_units = queue.get_device().compute_units();
45
46 boost::shared_ptr<parameter_cache> parameters =
47 detail::parameter_cache::get_global_cache(device);
48
49 std::string cache_key =
50 "__boost_scan_cpu_" + boost::lexical_cast<std::string>(sizeof(T));
51
52 // for inputs smaller than serial_scan_threshold
53 // serial_scan algorithm is used
54 uint_ serial_scan_threshold =
55 parameters->get(cache_key, "serial_scan_threshold", 16384 * sizeof(T));
56 serial_scan_threshold =
57 (std::max)(serial_scan_threshold, uint_(compute_units));
58
59 size_t count = detail::iterator_range_size(first, last);
60 if(count == 0){
61 return result;
62 }
63 else if(count < serial_scan_threshold) {
64 return serial_scan(first, last, result, exclusive, init, op, queue);
65 }
66
67 buffer block_partial_sums(context, sizeof(output_type) * compute_units );
68
69 // create scan kernel
70 meta_kernel k("scan_on_cpu_block_scan");
71
72 // Arguments
73 size_t count_arg = k.add_arg<uint_>("count");
74 size_t init_arg = k.add_arg<output_type>("initial_value");
75 size_t block_partial_sums_arg =
76 k.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
77
78 k <<
79 "uint block = (count + get_global_size(0))/(get_global_size(0) + 1);\n" <<
80 "uint index = get_global_id(0) * block;\n" <<
81 "uint end = min(count, index + block);\n" <<
82 "if(index >= end) return;\n";
83
84 if(!exclusive){
85 k <<
86 k.decl<output_type>("sum") << " = " <<
87 first[k.var<uint_>("index")] << ";\n" <<
88 result[k.var<uint_>("index")] << " = sum;\n" <<
89 "index++;\n";
90 }
91 else {
92 k <<
93 k.decl<output_type>("sum") << ";\n" <<
94 "if(index == 0){\n" <<
95 "sum = initial_value;\n" <<
96 "}\n" <<
97 "else {\n" <<
98 "sum = " << first[k.var<uint_>("index")] << ";\n" <<
99 "index++;\n" <<
100 "}\n";
101 }
102
103 k <<
104 "while(index < end){\n" <<
105 // load next value
106 k.decl<const input_type>("value") << " = "
107 << first[k.var<uint_>("index")] << ";\n";
108
109 if(exclusive){
110 k <<
111 "if(get_global_id(0) == 0){\n" <<
112 result[k.var<uint_>("index")] << " = sum;\n" <<
113 "}\n";
114 }
115 k <<
116 "sum = " << op(k.var<output_type>("sum"),
117 k.var<output_type>("value")) << ";\n";
118
119 if(!exclusive){
120 k <<
121 "if(get_global_id(0) == 0){\n" <<
122 result[k.var<uint_>("index")] << " = sum;\n" <<
123 "}\n";
124 }
125
126 k <<
127 "index++;\n" <<
128 "}\n" << // end while
129 "block_partial_sums[get_global_id(0)] = sum;\n";
130
131 // compile scan kernel
132 kernel block_scan_kernel = k.compile(context);
133
134 // setup kernel arguments
135 block_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
136 block_scan_kernel.set_arg(init_arg, static_cast<output_type>(init));
137 block_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
138
139 // execute the kernel
140 size_t global_work_size = compute_units;
141 queue.enqueue_1d_range_kernel(block_scan_kernel, 0, global_work_size, 0);
142
143 // scan is done
144 if(compute_units < 2) {
145 return result + count;
146 }
147
148 // final scan kernel
149 meta_kernel l("scan_on_cpu_final_scan");
150
151 // Arguments
152 count_arg = l.add_arg<uint_>("count");
153 block_partial_sums_arg =
154 l.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
155
156 l <<
157 "uint block = (count + get_global_size(0))/(get_global_size(0) + 1);\n" <<
158 "uint index = block + get_global_id(0) * block;\n" <<
159 "uint end = min(count, index + block);\n" <<
160 k.decl<output_type>("sum") << " = block_partial_sums[0];\n" <<
161 "for(uint i = 0; i < get_global_id(0); i++) {\n" <<
162 "sum = " << op(k.var<output_type>("sum"),
163 k.var<output_type>("block_partial_sums[i + 1]")) << ";\n" <<
164 "}\n" <<
165
166 "while(index < end){\n";
167 if(exclusive){
168 l <<
169 l.decl<output_type>("value") << " = "
170 << first[k.var<uint_>("index")] << ";\n" <<
171 result[k.var<uint_>("index")] << " = sum;\n" <<
172 "sum = " << op(k.var<output_type>("sum"),
173 k.var<output_type>("value")) << ";\n";
174 }
175 else {
176 l <<
177 "sum = " << op(k.var<output_type>("sum"),
178 first[k.var<uint_>("index")]) << ";\n" <<
179 result[k.var<uint_>("index")] << " = sum;\n";
180 }
181 l <<
182 "index++;\n" <<
183 "}\n";
184
185
186 // compile scan kernel
187 kernel final_scan_kernel = l.compile(context);
188
189 // setup kernel arguments
190 final_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
191 final_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
192
193 // execute the kernel
194 global_work_size = compute_units;
195 queue.enqueue_1d_range_kernel(final_scan_kernel, 0, global_work_size, 0);
196
197 // return iterator pointing to the end of the result range
198 return result + count;
199 }
200
201 } // end detail namespace
202 } // end compute namespace
203 } // end boost namespace
204
205 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
206