• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
18 #define MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
19 
20 #include <list>
21 #include <vector>
22 #include <memory>
23 #include <string>
24 #include "utils/hash_map.h"
25 #include "mindapi/base/macros.h"
26 #include "actor/actor.h"
27 #include "async/uuid_base.h"
28 #include "async/future.h"
29 #include "async/async.h"
30 #include "mindrt/include/async/collect.h"
31 
32 namespace mindspore {
33 // OpActor data route.
34 struct DataArrow {
DataArrowDataArrow35   DataArrow(int from_output_index, const AID &to_op_id, int to_input_index)
36       : from_output_index_(from_output_index), to_op_id_(to_op_id), to_input_index_(to_input_index), flag_{0} {}
37   int from_output_index_;
38   AID to_op_id_;
39   int to_input_index_;
40   // Used to indicate the attribute of data arrow.
41   size_t flag_;
42 };
43 using DataArrowPtr = std::shared_ptr<DataArrow>;
44 
45 // OpActor control route.
46 struct ControlArrow {
ControlArrowControlArrow47   explicit ControlArrow(const AID &to_op_id) : to_op_id_(to_op_id), flag_{0} {}
48   AID to_op_id_;
49   // Used to indicate the attribute of control arrow.
50   size_t flag_;
51 };
52 using ControlArrowPtr = std::shared_ptr<ControlArrow>;
53 
54 // OpActor data.
55 template <typename T>
56 struct OpData {
OpDataOpData57   OpData(const AID &op_id, T *data, int index) : op_id_(op_id), data_(data), index_(index) {}
58   AID op_id_;
59   T *data_;
60   int index_;
61 };
62 
63 class MS_CORE_API RandInt {
64  public:
Get()65   int Get() const { return rand(); }
66   static RandInt &Instance();
67 
68  private:
RandInt()69   RandInt() { srand(static_cast<unsigned int>(time(nullptr))); }
70 };
71 
72 template <typename T>
73 using OpDataPtr = std::shared_ptr<OpData<T>>;
74 
75 template <typename T>
76 using OpDataUniquePtr = std::unique_ptr<OpData<T>>;
77 
78 // The context of opActor running.
79 template <typename T>
80 struct OpContext {
81   int sequential_num_;
82   std::vector<OpDataPtr<T>> *output_data_;
83   std::vector<Promise<int>> *results_;
84   // Record the error info for print.
85   std::string error_info_{""};
86   const void *kernel_call_back_before_;
87   const void *kernel_call_back_after_;
88 
SetFailedOpContext89   void SetFailed(int32_t code) const {
90     if (code == MindrtStatus::KINIT) {
91       code = MindrtStatus::KERROR;
92     }
93     results_->front().SetFailed(code);
94   }
95 
SetSuccessOpContext96   void SetSuccess(int32_t code) const {
97     for (auto promise : *results_) {
98       promise.SetValue(code);
99     }
100   }
101 
SetResultOpContext102   void SetResult(size_t index, int value) const { results_->at(index).SetValue(value); }
103 };
104 
105 template <typename T>
106 class OpActor : public ActorBase {
107  public:
OpActor(const std::string & op_name)108   explicit OpActor(const std::string &op_name) : ActorBase(op_name) {}
109   ~OpActor() override = default;
110 
111   // The op actor run when receive the input data.
112   virtual void RunOpData(OpData<T> *input_data, OpContext<T> *context = nullptr) {}
113 
114   // The op actor run when receive the input control.
115   virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {}
116 
output_data_arrows()117   const std::vector<DataArrowPtr> &output_data_arrows() const { return output_data_arrows_; }
output_control_arrows()118   const std::vector<ControlArrowPtr> &output_control_arrows() const { return output_control_arrows_; }
119 
120  protected:
121   // The op data.
122   mindspore::HashMap<int, std::vector<OpData<T> *>> input_op_datas_;
123   std::vector<DataArrowPtr> output_data_arrows_;
124 
125   // The op controls.
126   mindspore::HashMap<int, std::vector<AID *>> input_op_controls_;
127   std::vector<ControlArrowPtr> output_control_arrows_;
128 };
129 }  // namespace mindspore
130 
131 #endif  // MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
132