• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
12 #define BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
13 
14 #include <boost/type_traits.hpp>
15 #include <boost/static_assert.hpp>
16 #include <boost/assert.hpp>
17 
18 #include <boost/compute/cl.hpp>
19 #include <boost/compute/kernel.hpp>
20 #include <boost/compute/context.hpp>
21 #include <boost/compute/command_queue.hpp>
22 #include <boost/compute/type_traits/is_device_iterator.hpp>
23 
24 namespace boost {
25 namespace compute {
26 
27 // forward declaration for svm_ptr<T>
28 template<class T>
29 class svm_ptr;
30 
31 // svm functions require OpenCL 2.0
32 #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
33 namespace detail {
34 
35 template<class T, class IndexExpr>
36 struct svm_ptr_index_expr
37 {
38     typedef T result_type;
39 
svm_ptr_index_exprboost::compute::detail::svm_ptr_index_expr40     svm_ptr_index_expr(const svm_ptr<T> &svm_ptr,
41                        const IndexExpr &expr)
42         : m_svm_ptr(svm_ptr),
43           m_expr(expr)
44     {
45     }
46 
operator Tboost::compute::detail::svm_ptr_index_expr47     operator T() const
48     {
49         BOOST_STATIC_ASSERT_MSG(boost::is_integral<IndexExpr>::value,
50                                 "Index expression must be integral");
51 
52         BOOST_ASSERT(m_svm_ptr.get());
53 
54         const context &context = m_svm_ptr.get_context();
55         const device &device = context.get_device();
56         command_queue queue(context, device);
57 
58         T value;
59         T* ptr =
60             static_cast<T*>(m_svm_ptr.get()) + static_cast<std::ptrdiff_t>(m_expr);
61         queue.enqueue_svm_map(static_cast<void*>(ptr), sizeof(T), CL_MAP_READ);
62         value = *(ptr);
63         queue.enqueue_svm_unmap(static_cast<void*>(ptr)).wait();
64 
65         return value;
66     }
67 
68     const svm_ptr<T> &m_svm_ptr;
69     IndexExpr m_expr;
70 };
71 
72 } // end detail namespace
73 #endif
74 
75 template<class T>
76 class svm_ptr
77 {
78 public:
79     typedef T value_type;
80     typedef std::ptrdiff_t difference_type;
81     typedef T* pointer;
82     typedef T& reference;
83     typedef std::random_access_iterator_tag iterator_category;
84 
svm_ptr()85     svm_ptr()
86         : m_ptr(0)
87     {
88     }
89 
svm_ptr(void * ptr,const context & context)90     svm_ptr(void *ptr, const context &context)
91         : m_ptr(static_cast<T*>(ptr)),
92           m_context(context)
93     {
94     }
95 
svm_ptr(const svm_ptr<T> & other)96     svm_ptr(const svm_ptr<T> &other)
97         : m_ptr(other.m_ptr),
98           m_context(other.m_context)
99     {
100     }
101 
operator =(const svm_ptr<T> & other)102     svm_ptr<T>& operator=(const svm_ptr<T> &other)
103     {
104         m_ptr = other.m_ptr;
105         m_context = other.m_context;
106         return *this;
107     }
108 
~svm_ptr()109     ~svm_ptr()
110     {
111     }
112 
get() const113     void* get() const
114     {
115         return m_ptr;
116     }
117 
operator +(difference_type n)118     svm_ptr<T> operator+(difference_type n)
119     {
120         return svm_ptr<T>(m_ptr + n, m_context);
121     }
122 
operator -(svm_ptr<T> other)123     difference_type operator-(svm_ptr<T> other)
124     {
125         BOOST_ASSERT(other.m_context == m_context);
126         return m_ptr - other.m_ptr;
127     }
128 
get_context() const129     const context& get_context() const
130     {
131         return m_context;
132     }
133 
operator ==(const svm_ptr<T> & other) const134     bool operator==(const svm_ptr<T>& other) const
135     {
136         return (other.m_context == m_context) && (m_ptr == other.m_ptr);
137     }
138 
operator !=(const svm_ptr<T> & other) const139     bool operator!=(const svm_ptr<T>& other) const
140     {
141         return (other.m_context != m_context) || (m_ptr != other.m_ptr);
142     }
143 
144     // svm functions require OpenCL 2.0
145     #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
146     /// \internal_
147     template<class Expr>
148     detail::svm_ptr_index_expr<T, Expr>
operator [](const Expr & expr) const149     operator[](const Expr &expr) const
150     {
151         BOOST_ASSERT(m_ptr);
152 
153         return detail::svm_ptr_index_expr<T, Expr>(*this,
154                                                    expr);
155     }
156     #endif
157 
158 private:
159     T *m_ptr;
160     context m_context;
161 };
162 
163 namespace detail {
164 
165 /// \internal_
166 template<class T>
167 struct set_kernel_arg<svm_ptr<T> >
168 {
operator ()boost::compute::detail::set_kernel_arg169     void operator()(kernel &kernel_, size_t index, const svm_ptr<T> &ptr)
170     {
171         kernel_.set_arg_svm_ptr(index, ptr.get());
172     }
173 };
174 
175 } // end detail namespace
176 
177 /// \internal_ (is_device_iterator specialization for svm_ptr)
178 template<class T>
179 struct is_device_iterator<svm_ptr<T> > : boost::true_type {};
180 
181 } // end compute namespace
182 } // end boost namespace
183 
184 #endif // BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
185