• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP
13 
14 #include <boost/compute/functional.hpp>
15 #include <boost/compute/algorithm/find_if.hpp>
16 #include <boost/compute/algorithm/transform.hpp>
17 #include <boost/compute/command_queue.hpp>
18 #include <boost/compute/detail/parameter_cache.hpp>
19 
20 namespace boost {
21 namespace compute {
22 namespace detail{
23 
24 ///
25 /// \brief Binary find kernel class
26 ///
27 /// Subclass of meta_kernel to perform single step in binary find.
28 ///
29 template<class InputIterator, class UnaryPredicate>
30 class binary_find_kernel : public meta_kernel
31 {
32 public:
binary_find_kernel(InputIterator first,InputIterator last,UnaryPredicate predicate)33     binary_find_kernel(InputIterator first,
34                        InputIterator last,
35                        UnaryPredicate predicate)
36         : meta_kernel("binary_find")
37     {
38         typedef typename std::iterator_traits<InputIterator>::value_type value_type;
39 
40         m_index_arg = add_arg<uint_ *>(memory_object::global_memory, "index");
41         m_block_arg = add_arg<uint_>("block");
42 
43         atomic_min<uint_> atomic_min_uint;
44 
45         *this <<
46             "uint i = get_global_id(0) * block;\n" <<
47             decl<value_type>("value") << "=" << first[var<uint_>("i")] << ";\n" <<
48             "if(" << predicate(var<value_type>("value")) << ") {\n" <<
49                 atomic_min_uint(var<uint_ *>("index"), var<uint_>("i")) << ";\n" <<
50             "}\n";
51     }
52 
53     size_t m_index_arg;
54     size_t m_block_arg;
55 };
56 
57 ///
58 /// \brief Binary find algorithm
59 ///
60 /// Finds the end of true values in the partitioned range [first, last).
61 /// \return Iterator pointing to end of true values
62 ///
63 /// \param first Iterator pointing to start of range
64 /// \param last Iterator pointing to end of range
65 /// \param predicate Predicate according to which the range is partitioned
66 /// \param queue Queue on which to execute
67 ///
68 template<class InputIterator, class UnaryPredicate>
binary_find(InputIterator first,InputIterator last,UnaryPredicate predicate,command_queue & queue=system::default_queue ())69 inline InputIterator binary_find(InputIterator first,
70                                  InputIterator last,
71                                  UnaryPredicate predicate,
72                                  command_queue &queue = system::default_queue())
73 {
74     const device &device = queue.get_device();
75 
76     boost::shared_ptr<parameter_cache> parameters =
77         detail::parameter_cache::get_global_cache(device);
78 
79     const std::string cache_key = "__boost_binary_find";
80 
81     size_t find_if_limit = 128;
82     size_t threads = parameters->get(cache_key, "tpb", 128);
83     size_t count = iterator_range_size(first, last);
84 
85     InputIterator search_first = first;
86     InputIterator search_last = last;
87 
88     scalar<uint_> index(queue.get_context());
89 
90     // construct and compile binary_find kernel
91     binary_find_kernel<InputIterator, UnaryPredicate>
92         binary_find_kernel(search_first, search_last, predicate);
93     ::boost::compute::kernel kernel = binary_find_kernel.compile(queue.get_context());
94 
95     // set buffer for index
96     kernel.set_arg(binary_find_kernel.m_index_arg, index.get_buffer());
97 
98     while(count > find_if_limit) {
99         index.write(static_cast<uint_>(count), queue);
100 
101         // set block and run binary_find kernel
102         uint_ block = static_cast<uint_>((count - 1)/(threads - 1));
103         kernel.set_arg(binary_find_kernel.m_block_arg, block);
104         queue.enqueue_1d_range_kernel(kernel, 0, threads, 0);
105 
106         size_t i = index.read(queue);
107 
108         if(i == count) {
109             search_first = search_last - ((count - 1)%(threads - 1));
110             break;
111         } else {
112             search_last = search_first + i;
113             search_first = search_last - ((count - 1)/(threads - 1));
114         }
115 
116         // Make sure that first and last stay within the input range
117         search_last = (std::min)(search_last, last);
118         search_last = (std::max)(search_last, first);
119 
120         search_first = (std::max)(search_first, first);
121         search_first = (std::min)(search_first, last);
122 
123         count = iterator_range_size(search_first, search_last);
124     }
125 
126     return find_if(search_first, search_last, predicate, queue);
127 }
128 
129 } // end detail namespace
130 } // end compute namespace
131 } // end boost namespace
132 
133 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP
134