/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/control_flow/actor/exit_actor.h"
#include <algorithm>
#include "src/control_flow/kernel/exit_subgraph_kernel.h"
#include "src/litert/kernel_exec_util.h"
#include "src/common/tensor_util.h"

namespace {
const constexpr int kEntranceTensorIndex = 0;
}
namespace mindspore::lite {
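// Collects one incoming OpData per input tensor; once all inputs of the current run have
// arrived, copies them into the exit kernel's output tensors and forwards them asynchronously.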
void LiteExitOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  SetEntranceInputAID(inputs);
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  auto ret = InitInputData();
  (void)input_op_datas_.erase(op_uuid);
  if (ret != RET_OK) {
    context->SetFailed(ret);
    return;
  }
  AsyncOutput(context);
  return;
}

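// Copies or moves each gathered input (index 0 is the entrance tensor and is skipped)
// into the corresponding output tensor, after syncing shapes and data types.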
int LiteExitOpActor::InitInputData() {
  auto ret = SetInputShape();

  for (size_t i = 1; i < inputs_data_.size(); ++i) {
    auto dst_tensor = kernel_->out_tensors()[i - 1];
    auto src_tensor = inputs_data_[i];
    dst_tensor->set_data_type(src_tensor->data_type());
    if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
      (void)SetTensorData(dst_tensor, src_tensor);
    } else {
      (void)MoveTensorData(dst_tensor, src_tensor);
    }
  }
  return ret;
}

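// Aligns each output tensor's shape with the shape of the corresponding incoming data.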
int LiteExitOpActor::SetInputShape() {
  auto ret = RET_OK;
  for (size_t i = 1; i < inputs_data_.size(); ++i) {
    auto &output_tensor = kernel_->out_tensors()[i - 1];
    if (output_tensor->shape() == inputs_data_[i]->shape()) {
      continue;
    }
    ret = SetTensorShape(output_tensor, inputs_data_[i]);
    MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
  }
  return RET_OK;
}

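// Remembers which actor sent the entrance tensor (input index 0).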
void LiteExitOpActor::SetEntranceInputAID(const OpData<Tensor> *inputs) {
  if (inputs->index_ == kEntranceTensorIndex) {
    entrance_input_aid_ = inputs->op_id_;
  }
}

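// Pre-builds one OpData per output arrow so AsyncOutput() can send without re-allocating.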
int LiteExitOpActor::PrepareOutputData() {
  // The exit actor performs no computation, so its inputs are sent out directly.
  outputs_data_.resize(output_data_arrows_.size());
  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    auto &arrow = output_data_arrows_[i];
    auto data = std::make_shared<OpData<Tensor>>(this->GetAID(), (kernel_->out_tensors()).at(arrow->from_output_index_),
                                                 static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new output_data failed.";
      return RET_NULL_PTR;
    }
    outputs_data_.at(i) = data;
  }
  return RET_OK;
}

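// Finds the call-output actor paired with the recorded entrance input actor and
// asynchronously forwards the prepared output data to it.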
void LiteExitOpActor::AsyncOutput(OpContext<Tensor> *context) {
  AID to_op_id;
  bool find_to_op_aid = false;
  for (auto info : all_mapping_info_) {
    if (info.partial_input_aid == entrance_input_aid_) {
      find_to_op_aid = true;
      to_op_id = info.call_output_aid;
    }
  }

  if (!find_to_op_aid) {
    MS_LOG(ERROR) << "exit actor cannot find output actor.";
    context->SetFailed(RET_ERROR);
    return;
  }

  if (to_op_id.Name() == "") {
    SetOutputData(context);
  }

  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    if (output_data_arrows_[i]->to_op_id_ != to_op_id && output_data_arrows_[i]->to_op_id_.Name() != "") {
      continue;
    }
    auto data = outputs_data_.at(i);
    Async(to_op_id, get_actor_mgr(), &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

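// First-stage initialization: isolates input data, builds the partial/call mapping info,
// and records the output actor of each call node.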
int LiteExitOpActor::PreInit(std::vector<std::shared_ptr<LiteOpActor>> *actors,
                             std::unordered_map<Tensor *, Tensor *> *input_map) {
  auto ret = IsolateInputData(actors, input_map);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "isolate input data failed.";
    return ret;
  }

  ret = CreateMappingInfo();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "create partial call pairs failed.";
    return ret;
  }

  ret = RecordCallNodeOutputActor(actors);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "record call node output AIDs failed.";
    return ret;
  }
  return RET_OK;
}

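// Returns true when every tensor in sub_set is also present in all_set.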
bool LiteExitOpActor::IsSubSet(const std::vector<lite::Tensor *> &all_set, const std::vector<lite::Tensor *> &sub_set) {
  if (sub_set.size() > all_set.size()) {
    return false;
  }
  for (auto &sub_item : sub_set) {
    if (std::find(all_set.begin(), all_set.end(), sub_item) == all_set.end()) {
      return false;
    }
  }
  return true;
}

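// For every mapping entry, records the AID of the actor whose input tensors contain
// all output tensors of that entry's call node.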
int LiteExitOpActor::RecordCallNodeOutputActor(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
  actors_ = actors;
  for (auto actor : *actors_) {
    auto actor_in_tensors = actor->GetKernel()->in_tensors();
    for (auto &info : all_mapping_info_) {
      auto &call = info.call_node;
      if (IsSubSet(actor_in_tensors, call->out_tensors())) {
        info.call_output_aid = actor->GetAID();
      }
    }
  }
  return RET_OK;
}

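// For each partial kernel registered on the exit subgraph kernel, finds the call node that
// consumes its output and stores the (partial, call) pair for later AID matching.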
int LiteExitOpActor::CreateMappingInfo() {
  auto exit_subgraph_kernel = reinterpret_cast<kernel::ExitSubGraphKernel *>(kernel_);
  if (exit_subgraph_kernel == nullptr) {
    MS_LOG(ERROR) << "cast to exit kernel failed.";
    return RET_ERROR;
  }
  auto partial_set = exit_subgraph_kernel->GetPartials();
  for (auto partial : partial_set) {
    auto call_node = kernel::KernelExecUtil::GetPartialOutputCall(partial);
    if (call_node == nullptr) {
      MS_LOG(ERROR) << "get partial node: " << partial->name() << "'s call output node failed.";
      return RET_ERROR;
    }
    MappingInfo info(partial, call_node);
    (void)all_mapping_info_.emplace_back(info);
  }
  return RET_OK;
}

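// Second-stage initialization: prepares output data and records the actor owning each partial node.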
int LiteExitOpActor::PostInit() {
  auto ret = PrepareOutputData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "prepare output data failed.";
    return ret;
  }

  RecordPartialNodeInputActor();
  return RET_OK;
}

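// Records, for every mapping entry, the AID of the actor that holds the entry's partial node.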
void LiteExitOpActor::RecordPartialNodeInputActor() {
  for (auto actor : *actors_) {
    auto actor_partial_nodes = actor->GetPartialKernels();
    if (actor_partial_nodes.empty()) {
      continue;
    }
    for (auto &info : all_mapping_info_) {
      auto partial = info.partial_node;
      if (actor_partial_nodes.find(partial) == actor_partial_nodes.end()) {
        continue;
      }
      info.partial_input_aid = actor->GetAID();
    }
  }
}
}  // namespace mindspore::lite