• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
18 #define MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
19 
#include <cstdlib>
#include <ctime>
#include <list>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "actor/actor.h"
#include "async/uuid_base.h"
#include "async/future.h"
#include "async/async.h"
#include "mindrt/include/async/collect.h"
30 
31 namespace mindspore {
32 // OpActor data route.
33 struct DataArrow {
DataArrowDataArrow34   DataArrow(int from_output_index, const AID &to_op_id, int to_input_index)
35       : from_output_index_(from_output_index), to_op_id_(to_op_id), to_input_index_(to_input_index) {}
36   int from_output_index_;
37   AID to_op_id_;
38   int to_input_index_;
39 };
40 using DataArrowPtr = std::shared_ptr<DataArrow>;
41 
42 // OpActor data.
43 template <typename T>
44 struct OpData {
OpDataOpData45   OpData(const AID &op_id, T *data, int index) : op_id_(op_id), data_(data), index_(index) {}
46   AID op_id_;
47   T *data_;
48   int index_;
49 };
50 
// Process-wide source of pseudo-random ints (used for OpContext::sequential_num_).
//
// The class owns a private std::mt19937 engine instead of calling std::srand()/std::rand().
// That way, creating the singleton no longer reseeds the C library's global generator,
// which other code in the process may rely on. The engine is not thread-safe, and
// std::rand() was never guaranteed race-free either, so Get() serializes access with a
// mutex. The singleton is non-copyable.
class RandInt {
 public:
  RandInt(const RandInt &) = delete;
  RandInt &operator=(const RandInt &) = delete;

  // Returns a pseudo-random value in [0, RAND_MAX], the same range std::rand() yields.
  int Get() {
    std::lock_guard<std::mutex> lock(mutex_);
    return distribution_(engine_);
  }

  // Returns the single lazily-constructed instance. Function-local static initialization
  // is thread-safe since C++11.
  static RandInt &Instance() {
    static RandInt instance;
    return instance;
  }

 private:
  // Seeded from wall-clock time, as the previous srand(time(NULL)) was.
  RandInt() : engine_(static_cast<std::mt19937::result_type>(std::time(nullptr))), distribution_(0, RAND_MAX) {}

  std::mutex mutex_;
  std::mt19937 engine_;
  std::uniform_int_distribution<int> distribution_;
};
62 
63 template <typename T>
64 using OpDataPtr = std::shared_ptr<OpData<T>>;
65 
66 template <typename T>
67 using OpDataUniquePtr = std::unique_ptr<OpData<T>>;
68 
69 // The context of opActor running.
70 template <typename T>
71 struct OpContext {
72   int sequential_num_;
73   std::vector<OpDataPtr<T>> *output_data_;
74   std::vector<Promise<int>> *results_;
75   const void *kernel_call_back_before_;
76   const void *kernel_call_back_after_;
77 
SetFailedOpContext78   void SetFailed(int32_t code) {
79     if (code == MindrtStatus::KINIT) {
80       code = MindrtStatus::KERROR;
81     }
82     for (auto promise : *results_) {
83       promise.SetFailed(code);
84     }
85   }
86 
SetSuccessOpContext87   void SetSuccess(int32_t code) {
88     for (auto promise : *results_) {
89       promise.SetValue(code);
90     }
91   }
92 
SetResultOpContext93   void SetResult(size_t index, int value) { results_->at(index).SetValue(value); }
94 };
95 
96 template <typename T>
97 class OpActor : public ActorBase {
98  public:
OpActor(std::string op_name)99   explicit OpActor(std::string op_name) : ActorBase(op_name) {}
100   virtual ~OpActor() = default;
101 
102   // The op actor run when receive the input data.
103   virtual void RunOpData(OpData<T> *input_data, OpContext<T> *context = nullptr) {}
104 
105   // The op actor run when receive the input control.
106   virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {}
107 
output_data_arrows()108   std::vector<DataArrowPtr> output_data_arrows() const { return output_data_arrows_; }
output_control_arrows()109   std::vector<AID> output_control_arrows() const { return output_control_arrows_; }
110 
111  protected:
112   // The op data.
113   std::unordered_map<int, std::vector<OpData<T> *>> input_op_datas_;
114   std::vector<DataArrowPtr> output_data_arrows_;
115 
116   // The op controls.
117   std::unordered_map<int, std::vector<AID *>> input_op_controls_;
118   std::vector<AID> output_control_arrows_;
119 };
120 
121 template <typename T>
MindrtAsyncRun(const std::vector<OpDataPtr<T>> & input_data,OpContext<T> * context)122 Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &input_data, OpContext<T> *context) {
123   std::list<Future<int>> futures;
124   for (auto promise : *(context->results_)) {
125     futures.push_back(promise.GetFuture());
126   }
127   Future<std::list<int>> collect = mindspore::Collect<int>(futures);
128 
129   for (auto data : input_data) {
130     Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data.get(), context);
131   }
132 
133   return collect;
134 }
135 
136 template <typename T>
MindrtRun(const std::vector<OpDataPtr<T>> & input_data,std::vector<OpDataPtr<T>> * output_data,const void * kernel_call_back_before,const void * kernel_call_back_after)137 int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr<T>> *output_data,
138               const void *kernel_call_back_before, const void *kernel_call_back_after) {
139   OpContext<T> context;
140   std::vector<Promise<int>> promises(output_data->size());
141   context.sequential_num_ = RandInt::Instance().Get();
142   context.results_ = &promises;
143   context.output_data_ = output_data;
144   context.kernel_call_back_before_ = kernel_call_back_before;
145   context.kernel_call_back_after_ = kernel_call_back_after;
146 
147   auto collect = MindrtAsyncRun<T>(input_data, &context);
148   collect.Wait();
149   if (!collect.IsOK()) {
150     return -1;
151   }
152 
153   return 0;
154 }
155 
156 }  // namespace mindspore
157 
158 #endif  // MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
159