• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #ifndef FFRT_API_CPP_FUTURE_H
16 #define FFRT_API_CPP_FUTURE_H
#include <cassert>
#include <chrono>
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include "cpp/condition_variable.h"
#include "cpp/thread.h"
23 
24 namespace ffrt {
/**
 * @brief Mixin base that removes copy construction and copy assignment from derived types.
 *
 * Construction and destruction are protected, so the mixin can only exist as a base subobject.
 * The deleted copy operations are public so that misuse is reported as "deleted function"
 * rather than as an access error.
 */
struct non_copyable {
    non_copyable(const non_copyable&) = delete;
    non_copyable& operator=(const non_copyable&) = delete;

protected:
    ~non_copyable() = default;
    non_copyable() = default;
};
/** @brief Outcome of a timed wait on a future. */
enum class future_status {
    ready,    ///< A result was stored before the wait ended.
    timeout,  ///< The wait period or deadline elapsed without a stored result.
    deferred, ///< Kept for parity with std::future_status; the timed waits in this file never return it.
};
33 
34 namespace detail {
35 template <typename Derived>
36 struct shared_state_base : private non_copyable {
waitshared_state_base37     void wait() const noexcept
38     {
39         std::unique_lock lk(this->m_mtx);
40         wait_(lk);
41     }
42 
43     template <typename Rep, typename Period>
wait_forshared_state_base44     future_status wait_for(const std::chrono::duration<Rep, Period>& waitTime) const noexcept
45     {
46         std::unique_lock<mutex> lk(m_mtx);
47         return m_cv.wait_for(lk, waitTime, [this] { return get_derived().has_value(); }) ? future_status::ready :
48             future_status::timeout;
49     }
50 
51     template <typename Clock, typename Duration>
wait_untilshared_state_base52     future_status wait_until(const std::chrono::time_point<Clock, Duration>& tp) const noexcept
53     {
54         std::unique_lock<mutex> lk(m_mtx);
55         return m_cv.wait_until(lk, tp, [this] { return get_derived().has_value(); }) ? future_status::ready :
56             future_status::timeout;
57     }
58 
59 protected:
wait_shared_state_base60     void wait_(std::unique_lock<mutex>& lk) const noexcept
61     {
62         assert(lk.owns_lock());
63         m_cv.wait(lk, [this] { return get_derived().has_value(); });
64     }
65 
66     mutable mutex m_mtx;
67     mutable condition_variable m_cv;
68 
69 private:
get_derivedshared_state_base70     const Derived& get_derived() const
71     {
72         return *static_cast<const Derived*>(this);
73     }
74 };
75 
76 template <typename R>
77 struct shared_state : shared_state_base<shared_state<R>> {
set_valueshared_state78     void set_value(const R& value) noexcept
79     {
80         {
81             std::unique_lock<mutex> lk(this->m_mtx);
82             assert(!m_res.has_value());
83             m_res.emplace(value);
84         }
85         this->m_cv.notify_all();
86     }
87 
set_valueshared_state88     void set_value(R&& value) noexcept
89     {
90         {
91             std::unique_lock<mutex> lk(this->m_mtx);
92             assert(!m_res.has_value());
93             m_res.emplace(std::move(value));
94         }
95         this->m_cv.notify_all();
96     }
97 
getshared_state98     R& get() noexcept
99     {
100         std::unique_lock lk(this->m_mtx);
101         this->wait_(lk);
102         assert(m_res.has_value());
103         return m_res.value();
104     }
105 
has_valueshared_state106     bool has_value() const noexcept
107     {
108         return m_res.has_value();
109     }
110 
111 private:
112     std::optional<R> m_res;
113 };
114 
115 template <>
116 struct shared_state<void> : shared_state_base<shared_state<void>> {
117     void set_value() noexcept
118     {
119         {
120             std::unique_lock<mutex> lk(this->m_mtx);
121             assert(!m_hasValue);
122             m_hasValue = true;
123         }
124         this->m_cv.notify_all();
125     }
126 
127     void get() noexcept
128     {
129         std::unique_lock lk(this->m_mtx);
130         this->wait_(lk);
131         assert(m_hasValue);
132     }
133 
134     bool has_value() const noexcept
135     {
136         return m_hasValue;
137     }
138 
139 private:
140     bool m_hasValue {false};
141 };
142 }; // namespace detail
143 
144 template <typename R>
145 class future : private non_copyable {
146     template <typename>
147     friend struct promise;
148 
149     template <typename>
150     friend struct packaged_task;
151 
152 public:
153     explicit future(const std::shared_ptr<detail::shared_state<R>>& state) noexcept : m_state(state)
154     {
155     }
156 
157     future() noexcept = default;
158 
159     future(future&& fut) noexcept
160     {
161         swap(fut);
162     }
163     future& operator=(future&& fut) noexcept
164     {
165         if (this != &fut) {
166             future tmp(std::move(fut));
167             swap(tmp);
168         }
169         return *this;
170     }
171 
172     bool valid() const noexcept
173     {
174         return m_state != nullptr;
175     }
176 
177     R get() noexcept
178     {
179         assert(valid());
180         auto tmp = std::move(m_state);
181         if constexpr(!std::is_void_v<R>) {
182             return std::move(tmp->get());
183         } else {
184             return tmp->get();
185         }
186     }
187 
188     template <typename Rep, typename Period>
189     future_status wait_for(const std::chrono::duration<Rep, Period>& waitTime) const noexcept
190     {
191         assert(valid());
192         return m_state->wait_for(waitTime);
193     }
194 
195     template <typename Clock, typename Duration>
196     future_status wait_until(const std::chrono::time_point<Clock, Duration>& tp) const noexcept
197     {
198         assert(valid());
199         return m_state->wait_until(tp);
200     }
201 
202     void wait() const noexcept
203     {
204         assert(valid());
205         m_state->wait();
206     }
207 
208     void swap(future<R>& rhs) noexcept
209     {
210         std::swap(m_state, rhs.m_state);
211     }
212 
213 private:
214     std::shared_ptr<detail::shared_state<R>> m_state;
215 };
216 
217 template <typename R>
218 struct promise : private non_copyable {
219     promise() noexcept : m_state {std::make_shared<detail::shared_state<R>>()}
220     {
221     }
222     promise(promise&& p) noexcept
223     {
224         swap(p);
225     }
226     promise& operator=(promise&& p) noexcept
227     {
228         if (this != &p) {
229             promise tmp(std::move(p));
230             swap(tmp);
231         }
232         return *this;
233     }
234 
235     void set_value(const R& value) noexcept
236     {
237         m_state->set_value(value);
238     }
239 
240     void set_value(R&& value) noexcept
241     {
242         m_state->set_value(std::move(value));
243     }
244 
245     future<R> get_future() noexcept
246     {
247         assert(m_state.use_count() == 1);
248         return future<R> {m_state};
249     }
250 
251     void swap(promise<R>& rhs) noexcept
252     {
253         std::swap(m_state, rhs.m_state);
254     }
255 
256 private:
257     std::shared_ptr<detail::shared_state<R>> m_state;
258 };
259 
260 template <>
261 struct promise<void> : private non_copyable {
262     promise() noexcept : m_state {std::make_shared<detail::shared_state<void>>()}
263     {
264     }
265     promise(promise&& p) noexcept
266     {
267         swap(p);
268     }
269     promise& operator=(promise&& p) noexcept
270     {
271         if (this != &p) {
272             promise tmp(std::move(p));
273             swap(tmp);
274         }
275         return *this;
276     }
277 
278     void set_value() noexcept
279     {
280         m_state->set_value();
281     }
282 
283     future<void> get_future() noexcept
284     {
285         assert(m_state.use_count() == 1);
286         return future<void> {m_state};
287     }
288 
289     void swap(promise<void>& rhs) noexcept
290     {
291         std::swap(m_state, rhs.m_state);
292     }
293 
294 private:
295     std::shared_ptr<detail::shared_state<void>> m_state;
296 };
297 
298 template <typename F>
299 struct packaged_task;
300 
301 template <typename R, typename... Args>
302 struct packaged_task<R(Args...)> {
303     packaged_task() noexcept = default;
304 
305     packaged_task(const packaged_task& pt) noexcept
306     {
307         m_fn = pt.m_fn;
308         m_state = pt.m_state;
309     }
310 
311     packaged_task(packaged_task&& pt) noexcept
312     {
313         swap(pt);
314     }
315 
316     packaged_task& operator=(packaged_task&& pt) noexcept
317     {
318         if (this != &pt) {
319             packaged_task tmp(std::move(pt));
320             swap(tmp);
321         }
322         return *this;
323     }
324 
325     template <typename F>
326     explicit packaged_task(F&& f) noexcept
327         : m_fn {std::forward<F>(f)}, m_state {std::make_shared<detail::shared_state<R>>()}
328     {
329     }
330 
331     bool valid() const noexcept
332     {
333         return bool(m_fn) && m_state != nullptr;
334     }
335 
336     future<R> get_future() noexcept
337     {
338         assert(m_state.use_count() == 1);
339         return future<R> {m_state};
340     }
341 
342     void operator()(Args... args)
343     {
344         assert(valid());
345         if constexpr(!std::is_void_v<R>) {
346             m_state->set_value(m_fn(std::forward<Args>(args)...));
347         } else {
348             m_fn(std::forward<Args>(args)...);
349             m_state->set_value();
350         }
351     }
352 
353     void swap(packaged_task& pt) noexcept
354     {
355         std::swap(m_fn, pt.m_fn);
356         std::swap(m_state, pt.m_state);
357     }
358 
359 private:
360     std::function<R(Args...)> m_fn;
361     std::shared_ptr<detail::shared_state<R>> m_state;
362 };
363 
364 template <typename F, typename... Args>
365 future<std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>> async(F&& f, Args&& ... args)
366 {
367     using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
368     packaged_task<R(std::decay_t<Args>...)> pt {std::forward<F>(f)};
369     auto fut {pt.get_future()};
370     auto th = ffrt::thread(std::move(pt), std::forward<Args>(args)...);
371     th.detach();
372     return fut;
373 }
374 } // namespace ffrt
375 #endif
376