1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
18 #define MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
19
#include <cstdlib>
#include <ctime>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "actor/actor.h"
#include "async/uuid_base.h"
#include "async/future.h"
#include "async/async.h"
#include "mindrt/include/async/collect.h"
30
31 namespace mindspore {
32 // OpActor data route.
33 struct DataArrow {
DataArrowDataArrow34 DataArrow(int from_output_index, const AID &to_op_id, int to_input_index)
35 : from_output_index_(from_output_index), to_op_id_(to_op_id), to_input_index_(to_input_index) {}
36 int from_output_index_;
37 AID to_op_id_;
38 int to_input_index_;
39 };
40 using DataArrowPtr = std::shared_ptr<DataArrow>;
41
42 // OpActor data.
43 template <typename T>
44 struct OpData {
OpDataOpData45 OpData(const AID &op_id, T *data, int index) : op_id_(op_id), data_(data), index_(index) {}
46 AID op_id_;
47 T *data_;
48 int index_;
49 };
50
// Process-wide pseudo-random integer source: a function-local-static singleton that
// seeds std::rand() from the wall clock when Instance() is first called.
// MindrtRun uses it to tag each run's OpContext::sequential_num_.
class RandInt {
 public:
  // Returns the next std::rand() value, in [0, RAND_MAX].
  // NOTE(review): std::rand() is not required to be thread-safe; confirm that
  // concurrent MindrtRun callers are acceptable.
  int Get() { return std::rand(); }

  // Returns the single shared instance (initialization is thread-safe since C++11).
  static RandInt &Instance() {
    static RandInt instance;
    return instance;
  }

  // Non-copyable: the only instance must be the one returned by Instance().
  RandInt(const RandInt &) = delete;
  RandInt &operator=(const RandInt &) = delete;

 private:
  // Note: std::srand reseeds the generator shared by every std::rand() user in the process.
  RandInt() { std::srand(static_cast<unsigned int>(std::time(nullptr))); }
};
62
63 template <typename T>
64 using OpDataPtr = std::shared_ptr<OpData<T>>;
65
66 template <typename T>
67 using OpDataUniquePtr = std::unique_ptr<OpData<T>>;
68
69 // The context of opActor running.
70 template <typename T>
71 struct OpContext {
72 int sequential_num_;
73 std::vector<OpDataPtr<T>> *output_data_;
74 std::vector<Promise<int>> *results_;
75 const void *kernel_call_back_before_;
76 const void *kernel_call_back_after_;
77
SetFailedOpContext78 void SetFailed(int32_t code) {
79 if (code == MindrtStatus::KINIT) {
80 code = MindrtStatus::KERROR;
81 }
82 for (auto promise : *results_) {
83 promise.SetFailed(code);
84 }
85 }
86
SetSuccessOpContext87 void SetSuccess(int32_t code) {
88 for (auto promise : *results_) {
89 promise.SetValue(code);
90 }
91 }
92
SetResultOpContext93 void SetResult(size_t index, int value) { results_->at(index).SetValue(value); }
94 };
95
96 template <typename T>
97 class OpActor : public ActorBase {
98 public:
OpActor(std::string op_name)99 explicit OpActor(std::string op_name) : ActorBase(op_name) {}
100 virtual ~OpActor() = default;
101
102 // The op actor run when receive the input data.
103 virtual void RunOpData(OpData<T> *input_data, OpContext<T> *context = nullptr) {}
104
105 // The op actor run when receive the input control.
106 virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {}
107
output_data_arrows()108 std::vector<DataArrowPtr> output_data_arrows() const { return output_data_arrows_; }
output_control_arrows()109 std::vector<AID> output_control_arrows() const { return output_control_arrows_; }
110
111 protected:
112 // The op data.
113 std::unordered_map<int, std::vector<OpData<T> *>> input_op_datas_;
114 std::vector<DataArrowPtr> output_data_arrows_;
115
116 // The op controls.
117 std::unordered_map<int, std::vector<AID *>> input_op_controls_;
118 std::vector<AID> output_control_arrows_;
119 };
120
121 template <typename T>
MindrtAsyncRun(const std::vector<OpDataPtr<T>> & input_data,OpContext<T> * context)122 Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &input_data, OpContext<T> *context) {
123 std::list<Future<int>> futures;
124 for (auto promise : *(context->results_)) {
125 futures.push_back(promise.GetFuture());
126 }
127 Future<std::list<int>> collect = mindspore::Collect<int>(futures);
128
129 for (auto data : input_data) {
130 Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data.get(), context);
131 }
132
133 return collect;
134 }
135
136 template <typename T>
MindrtRun(const std::vector<OpDataPtr<T>> & input_data,std::vector<OpDataPtr<T>> * output_data,const void * kernel_call_back_before,const void * kernel_call_back_after)137 int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr<T>> *output_data,
138 const void *kernel_call_back_before, const void *kernel_call_back_after) {
139 OpContext<T> context;
140 std::vector<Promise<int>> promises(output_data->size());
141 context.sequential_num_ = RandInt::Instance().Get();
142 context.results_ = &promises;
143 context.output_data_ = output_data;
144 context.kernel_call_back_before_ = kernel_call_back_before;
145 context.kernel_call_back_after_ = kernel_call_back_after;
146
147 auto collect = MindrtAsyncRun<T>(input_data, &context);
148 collect.Wait();
149 if (!collect.IsOK()) {
150 return -1;
151 }
152
153 return 0;
154 }
155
156 } // namespace mindspore
157
158 #endif // MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
159