• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_BALLOT_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_BALLOT_HPP
13 
14 #include <boost/compute/context.hpp>
15 #include <boost/compute/command_queue.hpp>
16 #include <boost/compute/container/vector.hpp>
17 #include <boost/compute/algorithm/reduce.hpp>
18 #include <boost/compute/functional/detail/nvidia_ballot.hpp>
19 #include <boost/compute/functional/detail/nvidia_popcount.hpp>
20 #include <boost/compute/detail/meta_kernel.hpp>
21 
22 namespace boost {
23 namespace compute {
24 namespace detail {
25 
26 template<class InputIterator, class Predicate>
count_if_with_ballot(InputIterator first,InputIterator last,Predicate predicate,command_queue & queue)27 inline size_t count_if_with_ballot(InputIterator first,
28                                    InputIterator last,
29                                    Predicate predicate,
30                                    command_queue &queue)
31 {
32     size_t count = iterator_range_size(first, last);
33     size_t block_size = 32;
34     size_t block_count = count / block_size;
35     if(block_count * block_size != count){
36         block_count++;
37     }
38 
39     const ::boost::compute::context &context = queue.get_context();
40 
41     ::boost::compute::vector<uint_> counts(block_count, context);
42 
43     ::boost::compute::detail::nvidia_popcount<uint_> popc;
44     ::boost::compute::detail::nvidia_ballot<uint_> ballot;
45 
46     meta_kernel k("count_if_with_ballot");
47     k <<
48         "const uint gid = get_global_id(0);\n" <<
49 
50         "bool value = false;\n" <<
51         "if(gid < count)\n" <<
52         "    value = " << predicate(first[k.var<const uint_>("gid")]) << ";\n" <<
53 
54         "uint bits = " << ballot(k.var<const uint_>("value")) << ";\n" <<
55 
56         "if(get_local_id(0) == 0)\n" <<
57             counts.begin()[k.var<uint_>("get_group_id(0)") ]
58                 << " = " << popc(k.var<uint_>("bits")) << ";\n";
59 
60     k.add_set_arg<const uint_>("count", count);
61 
62     k.exec_1d(queue, 0, block_size * block_count, block_size);
63 
64     uint_ result;
65     ::boost::compute::reduce(
66         counts.begin(),
67         counts.end(),
68         &result,
69         queue
70     );
71     return result;
72 }
73 
74 } // end detail namespace
75 } // end compute namespace
76 } // end boost namespace
77 
78 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_BALLOT_HPP
79