1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H 18 #define MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H 19 20 #include <list> 21 #include <vector> 22 #include <memory> 23 #include <string> 24 #include "utils/hash_map.h" 25 #include "mindapi/base/macros.h" 26 #include "actor/actor.h" 27 #include "async/uuid_base.h" 28 #include "async/future.h" 29 #include "async/async.h" 30 #include "mindrt/include/async/collect.h" 31 32 namespace mindspore { 33 // OpActor data route. 34 struct DataArrow { DataArrowDataArrow35 DataArrow(int from_output_index, const AID &to_op_id, int to_input_index) 36 : from_output_index_(from_output_index), to_op_id_(to_op_id), to_input_index_(to_input_index), flag_{0} {} 37 int from_output_index_; 38 AID to_op_id_; 39 int to_input_index_; 40 // Used to indicate the attribute of data arrow. 41 size_t flag_; 42 }; 43 using DataArrowPtr = std::shared_ptr<DataArrow>; 44 45 // OpActor control route. 46 struct ControlArrow { ControlArrowControlArrow47 explicit ControlArrow(const AID &to_op_id) : to_op_id_(to_op_id), flag_{0} {} 48 AID to_op_id_; 49 // Used to indicate the attribute of control arrow. 50 size_t flag_; 51 }; 52 using ControlArrowPtr = std::shared_ptr<ControlArrow>; 53 54 // OpActor data. 55 template <typename T> 56 struct OpData { OpDataOpData57 OpData(const AID &op_id, T *data, int index) : op_id_(op_id), data_(data), index_(index) {} 58 AID op_id_; 59 T *data_; 60 int index_; 61 }; 62 63 class MS_CORE_API RandInt { 64 public: Get()65 int Get() const { return rand(); } 66 static RandInt &Instance(); 67 68 private: RandInt()69 RandInt() { srand(static_cast<unsigned int>(time(nullptr))); } 70 }; 71 72 template <typename T> 73 using OpDataPtr = std::shared_ptr<OpData<T>>; 74 75 template <typename T> 76 using OpDataUniquePtr = std::unique_ptr<OpData<T>>; 77 78 // The context of opActor running. 79 template <typename T> 80 struct OpContext { 81 int sequential_num_; 82 std::vector<OpDataPtr<T>> *output_data_; 83 std::vector<Promise<int>> *results_; 84 // Record the error info for print. 85 std::string error_info_{""}; 86 const void *kernel_call_back_before_; 87 const void *kernel_call_back_after_; 88 SetFailedOpContext89 void SetFailed(int32_t code) const { 90 if (code == MindrtStatus::KINIT) { 91 code = MindrtStatus::KERROR; 92 } 93 results_->front().SetFailed(code); 94 } 95 SetSuccessOpContext96 void SetSuccess(int32_t code) const { 97 for (auto promise : *results_) { 98 promise.SetValue(code); 99 } 100 } 101 SetResultOpContext102 void SetResult(size_t index, int value) const { results_->at(index).SetValue(value); } 103 }; 104 105 template <typename T> 106 class OpActor : public ActorBase { 107 public: OpActor(const std::string & op_name)108 explicit OpActor(const std::string &op_name) : ActorBase(op_name) {} 109 ~OpActor() override = default; 110 111 // The op actor run when receive the input data. 112 virtual void RunOpData(OpData<T> *input_data, OpContext<T> *context = nullptr) {} 113 114 // The op actor run when receive the input control. 115 virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {} 116 output_data_arrows()117 const std::vector<DataArrowPtr> &output_data_arrows() const { return output_data_arrows_; } output_control_arrows()118 const std::vector<ControlArrowPtr> &output_control_arrows() const { return output_control_arrows_; } 119 120 protected: 121 // The op data. 122 mindspore::HashMap<int, std::vector<OpData<T> *>> input_op_datas_; 123 std::vector<DataArrowPtr> output_data_arrows_; 124 125 // The op controls. 126 mindspore::HashMap<int, std::vector<AID *>> input_op_controls_; 127 std::vector<ControlArrowPtr> output_control_arrows_; 128 }; 129 } // namespace mindspore 130 131 #endif // MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H 132