1
2 // Copyright Oliver Kowalke 2017.
3 // Distributed under the Boost Software License, Version 1.0.
4 // (See accompanying file LICENSE_1_0.txt or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6
7 #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
8 #define BOOST_FIBERS_CUDA_WAITFOR_H
9
10 #include <initializer_list>
11 #include <mutex>
12 #include <iostream>
13 #include <set>
14 #include <tuple>
15 #include <vector>
16
17 #include <boost/assert.hpp>
18 #include <boost/config.hpp>
19
20 #include <cuda.h>
21
22 #include <boost/fiber/detail/config.hpp>
23 #include <boost/fiber/detail/is_all_same.hpp>
24 #include <boost/fiber/condition_variable.hpp>
25 #include <boost/fiber/mutex.hpp>
26
27 #ifdef BOOST_HAS_ABI_HEADERS
28 # include BOOST_ABI_PREFIX
29 #endif
30
31 namespace boost {
32 namespace fibers {
33 namespace cuda {
34 namespace detail {
35
36 template< typename Rendezvous >
trampoline(cudaStream_t st,cudaError_t status,void * vp)37 static void trampoline( cudaStream_t st, cudaError_t status, void * vp) {
38 Rendezvous * data = static_cast< Rendezvous * >( vp);
39 data->notify( st, status);
40 }
41
42 class single_stream_rendezvous {
43 public:
single_stream_rendezvous(cudaStream_t st)44 single_stream_rendezvous( cudaStream_t st) {
45 unsigned int flags = 0;
46 cudaError_t status = ::cudaStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
47 if ( cudaSuccess != status) {
48 st_ = st;
49 status_ = status;
50 done_ = true;
51 }
52 }
53
notify(cudaStream_t st,cudaError_t status)54 void notify( cudaStream_t st, cudaError_t status) noexcept {
55 std::unique_lock< mutex > lk{ mtx_ };
56 st_ = st;
57 status_ = status;
58 done_ = true;
59 lk.unlock();
60 cv_.notify_one();
61 }
62
wait()63 std::tuple< cudaStream_t, cudaError_t > wait() {
64 std::unique_lock< mutex > lk{ mtx_ };
65 cv_.wait( lk, [this]{ return done_; });
66 return std::make_tuple( st_, status_);
67 }
68
69 private:
70 mutex mtx_{};
71 condition_variable cv_{};
72 cudaStream_t st_{};
73 cudaError_t status_{ cudaErrorUnknown };
74 bool done_{ false };
75 };
76
77 class many_streams_rendezvous {
78 public:
many_streams_rendezvous(std::initializer_list<cudaStream_t> l)79 many_streams_rendezvous( std::initializer_list< cudaStream_t > l) :
80 stx_{ l } {
81 results_.reserve( stx_.size() );
82 for ( cudaStream_t st : stx_) {
83 unsigned int flags = 0;
84 cudaError_t status = ::cudaStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
85 if ( cudaSuccess != status) {
86 std::unique_lock< mutex > lk{ mtx_ };
87 stx_.erase( st);
88 results_.push_back( std::make_tuple( st, status) );
89 }
90 }
91 }
92
notify(cudaStream_t st,cudaError_t status)93 void notify( cudaStream_t st, cudaError_t status) noexcept {
94 std::unique_lock< mutex > lk{ mtx_ };
95 stx_.erase( st);
96 results_.push_back( std::make_tuple( st, status) );
97 if ( stx_.empty() ) {
98 lk.unlock();
99 cv_.notify_one();
100 }
101 }
102
wait()103 std::vector< std::tuple< cudaStream_t, cudaError_t > > wait() {
104 std::unique_lock< mutex > lk{ mtx_ };
105 cv_.wait( lk, [this]{ return stx_.empty(); });
106 return results_;
107 }
108
109 private:
110 mutex mtx_{};
111 condition_variable cv_{};
112 std::set< cudaStream_t > stx_;
113 std::vector< std::tuple< cudaStream_t, cudaError_t > > results_;
114 };
115
116 }
117
118 void waitfor_all();
119
120 inline
waitfor_all(cudaStream_t st)121 std::tuple< cudaStream_t, cudaError_t > waitfor_all( cudaStream_t st) {
122 detail::single_stream_rendezvous rendezvous( st);
123 return rendezvous.wait();
124 }
125
126 template< typename ... STP >
waitfor_all(cudaStream_t st0,STP...stx)127 std::vector< std::tuple< cudaStream_t, cudaError_t > > waitfor_all( cudaStream_t st0, STP ... stx) {
128 static_assert( boost::fibers::detail::is_all_same< cudaStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
129 detail::many_streams_rendezvous rendezvous{ st0, stx ... };
130 return rendezvous.wait();
131 }
132
133 }}}
134
135 #ifdef BOOST_HAS_ABI_HEADERS
136 # include BOOST_ABI_SUFFIX
137 #endif
138
139 #endif // BOOST_FIBERS_CUDA_WAITFOR_H
140