• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ASYNC_COLLECT_H
18 #define MINDSPORE_CORE_MINDRT_INCLUDE_ASYNC_COLLECT_H
19 
#include <atomic>
#include <cstddef>
#include <functional>
#include <future>
#include <iostream>
#include <list>
#include <memory>
#include <tuple>
#include "async/common.h"
#include "async/future.h"
#include "async/defer.h"
#include "async/spinlock.h"
#include "actor/actor.h"
#include "mindrt/include/mindrt.hpp"
31 
32 namespace mindspore {
33 
34 template <typename T>
35 class Future;
36 
37 template <typename T>
38 class Promise;
39 
40 template <typename T>
41 class Collected;
42 
43 template <typename T>
44 class Collected {
45  public:
Collected(const std::list<Future<T>> & f,Promise<std::list<T>> * p)46   Collected(const std::list<Future<T>> &f, Promise<std::list<T>> *p) : futures(f), promise(p), ready(0) {}
47 
~Collected()48   virtual ~Collected() {
49     delete promise;
50     promise = nullptr;
51   }
52 
53   Collected(const Collected &) = delete;
54   Collected(Collected &&) = default;
55 
56   Collected &operator=(const Collected &) = delete;
57   Collected &operator=(Collected &&) = default;
58 
59  public:
Discarded()60   void Discarded() {
61     auto iter = futures.begin();
62     for (; iter != futures.end(); ++iter) {
63       iter->SetFailed(MindrtStatus::KERROR);
64     }
65   }
66 
Waited(const Future<T> & future)67   void Waited(const Future<T> &future) {
68     if (future.IsError()) {
69       promise->SetFailed(future.GetErrorCode());
70     } else if (future.IsOK()) {
71       ready.fetch_add(1);
72       if (ready.load() == futures.size()) {
73         std::list<T> values;
74         auto iter = futures.begin();
75         for (; iter != futures.end(); ++iter) {
76           values.push_back(iter->Get());
77         }
78         promise->SetValue(values);
79       }
80     }
81   }
82 
83  private:
84   const std::list<Future<T>> futures;
85   Promise<std::list<T>> *promise;
86   std::atomic_ulong ready;
87 };
88 
89 template <typename T>
Collect(const std::list<Future<T>> & futures)90 inline Future<std::list<T>> Collect(const std::list<Future<T>> &futures) {
91   if (futures.empty()) return Future<std::list<T>>(std::list<T>());
92 
93   Promise<std::list<T>> *promise = new (std::nothrow) Promise<std::list<T>>();
94   MINDRT_OOM_EXIT(promise);
95   std::shared_ptr<Collected<T>> collect = std::make_shared<Collected<T>>(futures, promise);
96 
97   for (auto iter = futures.begin(); iter != futures.end(); ++iter) {
98     iter->OnComplete(Defer(collect, &Collected<T>::Waited, std::placeholders::_1));
99   }
100 
101   Future<std::list<T>> future = promise->GetFuture();
102   future.OnComplete(Defer(collect, &Collected<T>::Discarded));
103 
104   return future;
105 }
106 
107 template <typename... Ts>
Collect(const Future<Ts> &...futures)108 Future<std::tuple<Ts...>> Collect(const Future<Ts> &... futures) {
109   std::list<Future<Nothing>> wrappers = {futures.Then([]() { return Nothing(); })...};
110 
111   auto f = [](const Future<Ts> &... futures) { return std::make_tuple(futures.Get()...); };
112 
113   return Collect(wrappers).Then(std::bind(f, futures...));
114 }
115 
116 };  // namespace mindspore
117 
118 #endif
119