• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 //          Copyright Oliver Kowalke 2017.
3 // Distributed under the Boost Software License, Version 1.0.
4 //    (See accompanying file LICENSE_1_0.txt or copy at
5 //          http://www.boost.org/LICENSE_1_0.txt)
6 
7 #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
8 #define BOOST_FIBERS_CUDA_WAITFOR_H
9 
10 #include <initializer_list>
11 #include <mutex>
12 #include <iostream>
13 #include <set>
14 #include <tuple>
15 #include <vector>
16 
17 #include <boost/assert.hpp>
18 #include <boost/config.hpp>
19 
20 #include <cuda.h>
21 
22 #include <boost/fiber/detail/config.hpp>
23 #include <boost/fiber/detail/is_all_same.hpp>
24 #include <boost/fiber/condition_variable.hpp>
25 #include <boost/fiber/mutex.hpp>
26 
27 #ifdef BOOST_HAS_ABI_HEADERS
28 #  include BOOST_ABI_PREFIX
29 #endif
30 
31 namespace boost {
32 namespace fibers {
33 namespace cuda {
34 namespace detail {
35 
36 template< typename Rendezvous >
trampoline(cudaStream_t st,cudaError_t status,void * vp)37 static void trampoline( cudaStream_t st, cudaError_t status, void * vp) {
38     Rendezvous * data = static_cast< Rendezvous * >( vp);
39     data->notify( st, status);
40 }
41 
42 class single_stream_rendezvous {
43 public:
single_stream_rendezvous(cudaStream_t st)44     single_stream_rendezvous( cudaStream_t st) {
45         unsigned int flags = 0;
46         cudaError_t status = ::cudaStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
47         if ( cudaSuccess != status) {
48             st_ = st;
49             status_ = status;
50             done_ = true;
51         }
52     }
53 
notify(cudaStream_t st,cudaError_t status)54     void notify( cudaStream_t st, cudaError_t status) noexcept {
55         std::unique_lock< mutex > lk{ mtx_ };
56         st_ = st;
57         status_ = status;
58         done_ = true;
59         lk.unlock();
60         cv_.notify_one();
61     }
62 
wait()63     std::tuple< cudaStream_t, cudaError_t > wait() {
64         std::unique_lock< mutex > lk{ mtx_ };
65         cv_.wait( lk, [this]{ return done_; });
66         return std::make_tuple( st_, status_);
67     }
68 
69 private:
70     mutex               mtx_{};
71     condition_variable  cv_{};
72     cudaStream_t        st_{};
73     cudaError_t         status_{ cudaErrorUnknown };
74     bool                done_{ false };
75 };
76 
77 class many_streams_rendezvous {
78 public:
many_streams_rendezvous(std::initializer_list<cudaStream_t> l)79     many_streams_rendezvous( std::initializer_list< cudaStream_t > l) :
80             stx_{ l } {
81         results_.reserve( stx_.size() );
82         for ( cudaStream_t st : stx_) {
83             unsigned int flags = 0;
84             cudaError_t status = ::cudaStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
85             if ( cudaSuccess != status) {
86                 std::unique_lock< mutex > lk{ mtx_ };
87                 stx_.erase( st);
88                 results_.push_back( std::make_tuple( st, status) );
89             }
90         }
91     }
92 
notify(cudaStream_t st,cudaError_t status)93     void notify( cudaStream_t st, cudaError_t status) noexcept {
94         std::unique_lock< mutex > lk{ mtx_ };
95         stx_.erase( st);
96         results_.push_back( std::make_tuple( st, status) );
97         if ( stx_.empty() ) {
98             lk.unlock();
99             cv_.notify_one();
100         }
101     }
102 
wait()103     std::vector< std::tuple< cudaStream_t, cudaError_t > > wait() {
104         std::unique_lock< mutex > lk{ mtx_ };
105         cv_.wait( lk, [this]{ return stx_.empty(); });
106         return results_;
107     }
108 
109 private:
110     mutex                                                   mtx_{};
111     condition_variable                                      cv_{};
112     std::set< cudaStream_t >                                stx_;
113     std::vector< std::tuple< cudaStream_t, cudaError_t > >  results_;
114 };
115 
116 }
117 
// Overload taking no streams. Only the declaration is visible here.
// NOTE(review): no definition appears in this header — confirm one exists
// elsewhere before relying on this overload.
void waitfor_all();
119 
120 inline
waitfor_all(cudaStream_t st)121 std::tuple< cudaStream_t, cudaError_t > waitfor_all( cudaStream_t st) {
122     detail::single_stream_rendezvous rendezvous( st);
123     return rendezvous.wait();
124 }
125 
126 template< typename ... STP >
waitfor_all(cudaStream_t st0,STP...stx)127 std::vector< std::tuple< cudaStream_t, cudaError_t > > waitfor_all( cudaStream_t st0, STP ... stx) {
128     static_assert( boost::fibers::detail::is_all_same< cudaStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
129     detail::many_streams_rendezvous rendezvous{ st0, stx ... };
130     return rendezvous.wait();
131 }
132 
133 }}}
134 
135 #ifdef BOOST_HAS_ABI_HEADERS
136 #  include BOOST_ABI_SUFFIX
137 #endif
138 
139 #endif // BOOST_FIBERS_CUDA_WAITFOR_H
140