1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/control_flow/actor/switch_actor.h"
18 #include <utility>
19 #include <algorithm>
20 #include <set>
21 #include <vector>
22 #include "mindrt/include/mindrt.hpp"
23 #include "src/litert/kernel_exec_util.h"
24 #include "src/common/tensor_util.h"
25 #include "src/litert/inner_allocator.h"
26 #ifdef ENABLE_FP16
27 #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
28 #endif
namespace {
// A switch node has either 2 input kernels (cond input is const and folded
// away) or 3 input kernels (cond, true-branch partial, false-branch partial).
// `constexpr` already implies const, so the redundant `const` is dropped.
constexpr int kSwitchMaxInputKernelSize = 3;
constexpr int kSwitchMinInputKernelSize = 2;
constexpr int kSwitchTruePartialInputIndex = 1;
constexpr int kSwitchFalsePartialInputIndex = 2;
constexpr int kSwitchCondTensorIndex = 0;
}  // namespace
36
37 namespace mindspore::lite {
SetSwitchPartialNodes()38 int LiteSwitchOpActor::SetSwitchPartialNodes() {
39 auto switch_op_input_kernel_size = switch_type_node_->in_kernels().size();
40 // special case, switch cond input is const, should be removed in the future.
41 if (switch_op_input_kernel_size == kSwitchMinInputKernelSize) {
42 // reverse switch node input, then false cast to 0, true cast to 1, which is same as switch layer index.
43 partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchFalsePartialInputIndex - 1));
44 partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchTruePartialInputIndex - 1));
45 return RET_OK;
46 }
47
48 if (switch_op_input_kernel_size == kSwitchMaxInputKernelSize) {
49 // reverse switch node input.
50 partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchFalsePartialInputIndex));
51 partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchTruePartialInputIndex));
52 return RET_OK;
53 }
54 MS_LOG(ERROR) << "switch op input kernel size: " << switch_op_input_kernel_size << ", which is not support.";
55 return RET_ERROR;
56 }
57
SetSwitchLayerPartialNodes()58 int LiteSwitchOpActor::SetSwitchLayerPartialNodes() {
59 for (size_t i = 1; i < switch_type_node_->in_kernels().size(); ++i) {
60 partial_nodes_.push_back(switch_type_node_->in_kernels()[i]);
61 }
62 return RET_OK;
63 }
64
GetSwitchAndCallNode(kernel::SubGraphKernel * subgraph_kernel)65 int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) {
66 for (auto &node : subgraph_kernel->nodes()) {
67 if (node->type() != schema::PrimitiveType_Call) {
68 continue;
69 }
70 call_node_ = node;
71 auto switch_node = kernel::KernelExecUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch);
72 auto switch_layer_node = kernel::KernelExecUtil::GetInputsSpecificNode(node, schema::PrimitiveType_SwitchLayer);
73 if (switch_node != nullptr) {
74 switch_type_node_ = switch_node;
75 return SetSwitchPartialNodes();
76 }
77 if (switch_layer_node != nullptr) {
78 switch_type_node_ = switch_layer_node;
79 return SetSwitchLayerPartialNodes();
80 }
81 }
82 return RET_OK;
83 }
84
AppendOutputTensors()85 void LiteSwitchOpActor::AppendOutputTensors() {
86 auto output_tensors = kernel_->out_tensors();
87 for (auto &partial_node : partial_nodes_) {
88 for (auto &tensor : partial_node->in_tensors()) {
89 if (std::find(output_tensors.begin(), output_tensors.end(), tensor) == output_tensors.end()) {
90 output_tensors.push_back(tensor);
91 }
92 }
93 }
94 kernel_->set_out_tensors(output_tensors);
95 }
96
int LiteSwitchOpActor::ModifySubgraphKernel() {
  // Strips the call / switch / partial nodes out of the subgraph kernel so the
  // actor itself drives branch selection at run time instead of the in-graph
  // control-flow nodes. Returns RET_OK, or the status of GetSwitchAndCallNode.
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  // NOTE(review): reinterpret_cast never yields nullptr for a non-null
  // kernel_, so this check only guards kernel_ == nullptr; confirm that the
  // "not a subgraph kernel" case implied by the log message is handled by the
  // caller or cannot occur here.
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
    return RET_OK;
  }

  // Populates call_node_, switch_type_node_ and partial_nodes_.
  int ret = GetSwitchAndCallNode(subgraph_kernel);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetSwitchAndCallCnode failed.";
    return ret;
  }

  // Remove the control-flow nodes from the subgraph; their routing duties are
  // taken over by this actor.
  subgraph_kernel->DropNode(call_node_);
  subgraph_kernel->DropNode(switch_type_node_);
  for (auto &partial_node : partial_nodes_) {
    subgraph_kernel->DropNode(partial_node);
  }
  return ret;
}
117
UpdateActorOutput()118 int LiteSwitchOpActor::UpdateActorOutput() {
119 if (call_node_ == nullptr) {
120 MS_LOG(ERROR) << "not get the call node.";
121 return RET_ERROR;
122 }
123 auto call_output_tensors = call_node_->out_tensors();
124 auto output_tensors = kernel_->out_tensors();
125 for (auto iter = output_tensors.begin(); iter != output_tensors.end();) {
126 if (IsContain(call_output_tensors, *iter)) {
127 iter = output_tensors.erase(iter);
128 } else {
129 ++iter;
130 }
131 }
132 kernel_->set_out_tensors(output_tensors);
133 return RET_OK;
134 }
135
int LiteSwitchOpActor::CompileArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) {
  // Compiles all data arrows for this switch actor. The call order below is
  // load-bearing:
  //   1) strip the control-flow nodes out of the subgraph kernel,
  //   2) drop the call node's outputs from the actor outputs,
  //   3) build arrows for the remaining "real" outputs,
  //   4) append the branch partial inputs as extra output tensors,
  //   5) build the per-branch arrows through the switch call.
  // Step 4 must run after step 3 (so branch tensors don't get ordinary output
  // arrows) and before step 5 (so branch tensors have output indices).
  int ret = ModifySubgraphKernel();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ModifySubgraphKernel failed.";
    return ret;
  }

  ret = UpdateActorOutput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UpdateActorOutput failed.";
    return ret;
  }

  // Only compile ordinary output arrows when some non-branch outputs remain.
  if (!kernel_->out_tensors().empty()) {
    ret = CompileArrowThroughOutputTensors(receivers_map);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CompileArrowThroughOutputTensors failed.";
      return ret;
    }
  }

  AppendOutputTensors();

  ret = CompileArrowThroughSwitchCall(receivers_map);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileArrowThroughSwitchCall failed.";
    return ret;
  }

  return ret;
}
167
CreateSwitchTypeArrow(const std::unordered_map<void *,std::set<std::pair<AID,size_t>>> & receivers_map,const std::set<void * > & receiver_tensors,const Tensor * partial_in_tensor,std::vector<DataArrowPtr> * branch_output_data_arrows)168 int LiteSwitchOpActor::CreateSwitchTypeArrow(
169 const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map,
170 const std::set<void *> &receiver_tensors, const Tensor *partial_in_tensor,
171 std::vector<DataArrowPtr> *branch_output_data_arrows) {
172 for (auto receiver_tensor : receiver_tensors) {
173 MS_CHECK_TRUE_MSG(receivers_map.find(receiver_tensor) != receivers_map.end(), RET_ERROR,
174 "not find receiver_tensor in receivers_map");
175 auto receiver_set = receivers_map.at(receiver_tensor);
176 for (auto item : receiver_set) {
177 for (size_t j = 0; j < kernel_->out_tensors().size(); ++j) {
178 if (partial_in_tensor != kernel_->out_tensors()[j]) {
179 continue;
180 }
181 auto arrow = std::make_shared<DataArrow>(j, item.first, item.second);
182 MS_CHECK_TRUE_MSG(arrow != nullptr, RET_ERROR, "create data arrow failed.");
183 branch_output_data_arrows->push_back(arrow);
184 break;
185 }
186 }
187 }
188 return RET_OK;
189 }
190
CompileArrowThroughSwitchCall(const std::unordered_map<void *,std::set<std::pair<AID,size_t>>> & receivers_map)191 int LiteSwitchOpActor::CompileArrowThroughSwitchCall(
192 const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) {
193 for (auto &partial_node : partial_nodes_) {
194 if (partial_node == nullptr) {
195 MS_LOG(ERROR) << "partial_node_ is nullptr.";
196 return RET_NULL_PTR;
197 }
198 std::vector<DataArrowPtr> branch_output_data_arrows;
199 auto partial_in_tensors = partial_node->in_tensors();
200 for (size_t i = 0; i < partial_in_tensors.size(); ++i) {
201 auto receiver_tensors = ctx_->GetLinkInfo(partial_in_tensors[i]);
202 MS_CHECK_TRUE_MSG(!receiver_tensors.empty(), RET_ERROR, "no reviver for this actor");
203 auto ret =
204 CreateSwitchTypeArrow(receivers_map, receiver_tensors, partial_in_tensors[i], &branch_output_data_arrows);
205 if (ret != RET_OK) {
206 MS_LOG(ERROR) << "create switch type arrow failed, partial in tensor name: "
207 << partial_in_tensors[i]->tensor_name();
208 return ret;
209 }
210 }
211 all_branch_output_data_arrows_.push_back(branch_output_data_arrows);
212 }
213 return RET_OK;
214 }
215
PrepareOutputData()216 int LiteSwitchOpActor::PrepareOutputData() {
217 if (LiteOpActor::PrepareOutputData() != RET_OK) {
218 MS_LOG(ERROR) << "lite actor prepare output data failed.";
219 return RET_ERROR;
220 }
221 std::vector<int> arrow_num_of_each_tensor(kernel_->out_tensors().size(), 0);
222 std::set<int> arrow_indexes;
223 for (auto &branch_output_data_arrows : all_branch_output_data_arrows_) {
224 std::vector<OpDataPtr<Tensor>> branch_outputs_data{};
225 branch_outputs_data.resize(branch_output_data_arrows.size());
226 for (size_t i = 0; i < branch_output_data_arrows.size(); i++) {
227 auto &arrow = branch_output_data_arrows[i];
228 arrow_indexes.insert(arrow->from_output_index_);
229 ++arrow_num_of_each_tensor[arrow->from_output_index_];
230 auto data =
231 std::make_shared<OpData<Tensor>>(this->GetAID(), (kernel_->out_tensors()).at(arrow->from_output_index_),
232 static_cast<int>(arrow->to_input_index_));
233 if (data == nullptr) {
234 MS_LOG(ERROR) << "new branch output data failed.";
235 return RET_NULL_PTR;
236 }
237 branch_outputs_data.at(i) = data;
238 }
239 all_branchs_output_data_.push_back(branch_outputs_data);
240 }
241 for (auto index : arrow_indexes) {
242 kernel_->out_tensors().at(index)->set_init_ref_count(arrow_num_of_each_tensor[index]);
243 }
244 return RET_OK;
245 }
246
DecreaseOtherBranchInputTensor(const size_t & index)247 void LiteSwitchOpActor::DecreaseOtherBranchInputTensor(const size_t &index) {
248 switch_type_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
249 for (size_t i = 0; i < partial_nodes_.size(); ++i) {
250 if (i == index) {
251 continue;
252 }
253 for (auto input : partial_nodes_[i]->in_tensors()) {
254 input->DecRefCount();
255 }
256 }
257 }
258
AsyncBranchOutput(const size_t & index,OpContext<Tensor> * context)259 STATUS LiteSwitchOpActor::AsyncBranchOutput(const size_t &index, OpContext<Tensor> *context) {
260 if (index >= all_branch_output_data_arrows_.size()) {
261 MS_LOG(ERROR) << "index " << index
262 << " extend all_branch_output_data_arrows_.size(): " << all_branch_output_data_arrows_.size();
263 context->SetFailed(RET_ERROR);
264 return RET_ERROR;
265 }
266 if (index >= all_branchs_output_data_.size()) {
267 MS_LOG(ERROR) << "index " << index
268 << " extend all_branchs_output_data_.size(): " << all_branchs_output_data_.size();
269 context->SetFailed(RET_ERROR);
270 return RET_ERROR;
271 }
272 auto branch_output_data_arrows = all_branch_output_data_arrows_.at(index);
273 auto branch_outputs_data = all_branchs_output_data_.at(index);
274 if (branch_output_data_arrows.size() != branch_outputs_data.size()) {
275 MS_LOG(ERROR) << "index " << index
276 << " extend all_branchs_output_data_.size(): " << all_branchs_output_data_.size();
277 context->SetFailed(RET_ERROR);
278 return RET_ERROR;
279 }
280 for (size_t i = 0; i < branch_output_data_arrows.size(); ++i) {
281 auto &data = branch_outputs_data.at(i);
282 Async(branch_output_data_arrows[i]->to_op_id_, get_actor_mgr(), &mindspore::OpActor<Tensor>::RunOpData, data.get(),
283 context);
284 }
285 return RET_OK;
286 }
287
void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
  // Actor entry point: accumulates incoming data for this sequential run,
  // executes the kernel once all inputs have arrived, then routes outputs to
  // whichever branch the switch condition selects.
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    // Not all inputs have arrived yet; wait for further RunOpData messages.
    return;
  }

  auto ret = InitInputData();
  if (ret != RET_OK) {
    // Drop the buffered inputs for this run before reporting failure.
    (void)input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    (void)input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  (void)input_op_datas_.erase(op_uuid);

  // Read the switch condition to pick a branch: false -> 0, true -> 1,
  // matching the order partial_nodes_ was filled in SetSwitchPartialNodes.
  // NOTE(review): dereferencing bool* can only produce 0 or 1 — confirm this
  // is sufficient for switch-layer nodes with more than two branches.
  auto cond_ptr = reinterpret_cast<bool *>(switch_type_node_->in_tensors()[kSwitchCondTensorIndex]->data());
  if (cond_ptr == nullptr) {
    MS_LOG(ERROR) << "switch cond input data is nullptr.";
    context->SetFailed(RET_NULL_PTR);
    return;
  }
  size_t index = static_cast<size_t>(*cond_ptr);
  // Release tensors the not-taken branches would otherwise hold onto.
  DecreaseOtherBranchInputTensor(index);
  ret = AsyncBranchOutput(index, context);
  if (ret != RET_OK) {
    // AsyncBranchOutput already marked the context as failed.
    MS_LOG(ERROR) << "AsyncBranchOutput failed.";
    return;
  }
  // Also emit any ordinary (non-branch) outputs this actor owns.
  if (!output_data_arrows_.empty()) {
    AsyncOutput(context);
    SetOutputData(context);
  }
}
330 } // namespace mindspore::lite
331