/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/control_flow/actor/switch_actor.h"
#include <utility>
#include <algorithm>
#include <set>
#include <vector>
#include "mindrt/include/mindrt.hpp"
#include "src/litert/kernel_exec_util.h"
#include "src/common/tensor_util.h"
#include "src/litert/inner_allocator.h"
#ifdef ENABLE_FP16
#include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
#endif
namespace {
constexpr int kSwitchMaxInputKernelSize = 3;
constexpr int kSwitchMinInputKernelSize = 2;
constexpr int kSwitchTruePartialInputIndex = 1;
constexpr int kSwitchFalsePartialInputIndex = 2;
constexpr int kSwitchCondTensorIndex = 0;
}  // namespace

namespace mindspore::lite {
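// Collect the true/false partial nodes that feed the switch node. They are stored in reversed
// order (false branch first) so that the branch index equals the runtime cond value: false -> 0,
// true -> 1, the same convention as the switch-layer index.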
int LiteSwitchOpActor::SetSwitchPartialNodes() {
  auto switch_op_input_kernel_size = switch_type_node_->in_kernels().size();
  // Special case: the switch cond input is const; this should be removed in the future.
  if (switch_op_input_kernel_size == kSwitchMinInputKernelSize) {
    // Reverse the switch node inputs so that false maps to 0 and true maps to 1, the same as the switch layer index.
    partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchFalsePartialInputIndex - 1));
    partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchTruePartialInputIndex - 1));
    return RET_OK;
  }

  if (switch_op_input_kernel_size == kSwitchMaxInputKernelSize) {
    // Reverse the switch node inputs.
    partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchFalsePartialInputIndex));
    partial_nodes_.push_back(switch_type_node_->in_kernels().at(kSwitchTruePartialInputIndex));
    return RET_OK;
  }
  MS_LOG(ERROR) << "switch op input kernel size: " << switch_op_input_kernel_size << ", which is not supported.";
  return RET_ERROR;
}

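// For a switch-layer node, every input kernel except the first (which produces the branch-index
// tensor) is a partial node; collect them in order so the branch index selects the matching partial.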
int LiteSwitchOpActor::SetSwitchLayerPartialNodes() {
  for (size_t i = 1; i < switch_type_node_->in_kernels().size(); ++i) {
    partial_nodes_.push_back(switch_type_node_->in_kernels()[i]);
  }
  return RET_OK;
}

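// Scan the subgraph for a Call node and the Switch/SwitchLayer node that feeds it, then collect
// the partial nodes of each branch. If the subgraph contains no Call node, nothing is recorded.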
int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) {
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() != schema::PrimitiveType_Call) {
      continue;
    }
    call_node_ = node;
    auto switch_node = kernel::KernelExecUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch);
    auto switch_layer_node = kernel::KernelExecUtil::GetInputsSpecificNode(node, schema::PrimitiveType_SwitchLayer);
    if (switch_node != nullptr) {
      switch_type_node_ = switch_node;
      return SetSwitchPartialNodes();
    }
    if (switch_layer_node != nullptr) {
      switch_type_node_ = switch_layer_node;
      return SetSwitchLayerPartialNodes();
    }
  }
  return RET_OK;
}

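// Append the input tensors of every branch's partial node to the actor kernel's output tensors
// (deduplicated), so arrows to the receiving branch actors can be built from them.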
void LiteSwitchOpActor::AppendOutputTensors() {
  auto output_tensors = kernel_->out_tensors();
  for (auto &partial_node : partial_nodes_) {
    for (auto &tensor : partial_node->in_tensors()) {
      if (std::find(output_tensors.begin(), output_tensors.end(), tensor) == output_tensors.end()) {
        output_tensors.push_back(tensor);
      }
    }
  }
  kernel_->set_out_tensors(output_tensors);
}

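// Detach the control-flow bookkeeping nodes (call, switch/switch-layer, partials) from the
// subgraph kernel; the actor takes over their routing work at runtime.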
int LiteSwitchOpActor::ModifySubgraphKernel() {
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not a subgraph kernel, no partial call.";
    return RET_OK;
  }

  int ret = GetSwitchAndCallNode(subgraph_kernel);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetSwitchAndCallNode failed.";
    return ret;
  }

  subgraph_kernel->DropNode(call_node_);
  subgraph_kernel->DropNode(switch_type_node_);
  for (auto &partial_node : partial_nodes_) {
    subgraph_kernel->DropNode(partial_node);
  }
  return ret;
}

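// Remove the call node's output tensors from the actor's outputs: they are produced by the
// branch subgraphs, not by this actor.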
int LiteSwitchOpActor::UpdateActorOutput() {
  if (call_node_ == nullptr) {
    MS_LOG(ERROR) << "failed to get the call node.";
    return RET_ERROR;
  }
  auto call_output_tensors = call_node_->out_tensors();
  auto output_tensors = kernel_->out_tensors();
  for (auto iter = output_tensors.begin(); iter != output_tensors.end();) {
    if (IsContain(call_output_tensors, *iter)) {
      iter = output_tensors.erase(iter);
    } else {
      ++iter;
    }
  }
  kernel_->set_out_tensors(output_tensors);
  return RET_OK;
}

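// Build this actor's data arrows: first for the actor's own output tensors, then, after the
// partial inputs are appended as outputs, for the tensors consumed by each branch.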
int LiteSwitchOpActor::CompileArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) {
  int ret = ModifySubgraphKernel();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ModifySubgraphKernel failed.";
    return ret;
  }

  ret = UpdateActorOutput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UpdateActorOutput failed.";
    return ret;
  }

  if (!kernel_->out_tensors().empty()) {
    ret = CompileArrowThroughOutputTensors(receivers_map);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CompileArrowThroughOutputTensors failed.";
      return ret;
    }
  }

  AppendOutputTensors();

  ret = CompileArrowThroughSwitchCall(receivers_map);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileArrowThroughSwitchCall failed.";
    return ret;
  }

  return ret;
}

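// For one partial input tensor, create a data arrow to every registered receiver, using the
// tensor's position in the actor's output tensors as the arrow's source index.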
int LiteSwitchOpActor::CreateSwitchTypeArrow(
  const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map,
  const std::set<void *> &receiver_tensors, const Tensor *partial_in_tensor,
  std::vector<DataArrowPtr> *branch_output_data_arrows) {
  for (auto receiver_tensor : receiver_tensors) {
    MS_CHECK_TRUE_MSG(receivers_map.find(receiver_tensor) != receivers_map.end(), RET_ERROR,
                      "receiver_tensor not found in receivers_map");
    auto receiver_set = receivers_map.at(receiver_tensor);
    for (auto item : receiver_set) {
      for (size_t j = 0; j < kernel_->out_tensors().size(); ++j) {
        if (partial_in_tensor != kernel_->out_tensors()[j]) {
          continue;
        }
        auto arrow = std::make_shared<DataArrow>(j, item.first, item.second);
        MS_CHECK_TRUE_MSG(arrow != nullptr, RET_ERROR, "create data arrow failed.");
        branch_output_data_arrows->push_back(arrow);
        break;
      }
    }
  }
  return RET_OK;
}

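// For every branch, build the per-branch output data arrows that deliver the partial node's
// input tensors to the downstream actors of that branch.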
int LiteSwitchOpActor::CompileArrowThroughSwitchCall(
  const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) {
  for (auto &partial_node : partial_nodes_) {
    if (partial_node == nullptr) {
      MS_LOG(ERROR) << "partial_node is nullptr.";
      return RET_NULL_PTR;
    }
    std::vector<DataArrowPtr> branch_output_data_arrows;
    auto partial_in_tensors = partial_node->in_tensors();
    for (size_t i = 0; i < partial_in_tensors.size(); ++i) {
      auto receiver_tensors = ctx_->GetLinkInfo(partial_in_tensors[i]);
      MS_CHECK_TRUE_MSG(!receiver_tensors.empty(), RET_ERROR, "no receiver for this actor");
      auto ret =
        CreateSwitchTypeArrow(receivers_map, receiver_tensors, partial_in_tensors[i], &branch_output_data_arrows);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "create switch type arrow failed, partial in tensor name: "
                      << partial_in_tensors[i]->tensor_name();
        return ret;
      }
    }
    all_branch_output_data_arrows_.push_back(branch_output_data_arrows);
  }
  return RET_OK;
}

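// Pre-build the OpData for every branch arrow and set each output tensor's initial reference
// count to the number of arrows that read it.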
int LiteSwitchOpActor::PrepareOutputData() {
  if (LiteOpActor::PrepareOutputData() != RET_OK) {
    MS_LOG(ERROR) << "lite actor prepare output data failed.";
    return RET_ERROR;
  }
  std::vector<int> arrow_num_of_each_tensor(kernel_->out_tensors().size(), 0);
  std::set<int> arrow_indexes;
  for (auto &branch_output_data_arrows : all_branch_output_data_arrows_) {
    std::vector<OpDataPtr<Tensor>> branch_outputs_data{};
    branch_outputs_data.resize(branch_output_data_arrows.size());
    for (size_t i = 0; i < branch_output_data_arrows.size(); i++) {
      auto &arrow = branch_output_data_arrows[i];
      arrow_indexes.insert(arrow->from_output_index_);
      ++arrow_num_of_each_tensor[arrow->from_output_index_];
      auto data =
        std::make_shared<OpData<Tensor>>(this->GetAID(), (kernel_->out_tensors()).at(arrow->from_output_index_),
                                         static_cast<int>(arrow->to_input_index_));
      if (data == nullptr) {
        MS_LOG(ERROR) << "new branch output data failed.";
        return RET_NULL_PTR;
      }
      branch_outputs_data.at(i) = data;
    }
    all_branchs_output_data_.push_back(branch_outputs_data);
  }
  for (auto index : arrow_indexes) {
    kernel_->out_tensors().at(index)->set_init_ref_count(arrow_num_of_each_tensor[index]);
  }
  return RET_OK;
}

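// Release the cond tensor and the input tensors of every branch other than the chosen one,
// since only the chosen branch will consume its inputs.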
void LiteSwitchOpActor::DecreaseOtherBranchInputTensor(const size_t &index) {
  switch_type_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
  for (size_t i = 0; i < partial_nodes_.size(); ++i) {
    if (i == index) {
      continue;
    }
    for (auto input : partial_nodes_[i]->in_tensors()) {
      input->DecRefCount();
    }
  }
}

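// Send the pre-built output data of the chosen branch to its downstream actors asynchronously.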
STATUS LiteSwitchOpActor::AsyncBranchOutput(const size_t &index, OpContext<Tensor> *context) {
  if (index >= all_branch_output_data_arrows_.size()) {
    MS_LOG(ERROR) << "index " << index
                  << " exceeds all_branch_output_data_arrows_.size(): " << all_branch_output_data_arrows_.size();
    context->SetFailed(RET_ERROR);
    return RET_ERROR;
  }
  if (index >= all_branchs_output_data_.size()) {
    MS_LOG(ERROR) << "index " << index
                  << " exceeds all_branchs_output_data_.size(): " << all_branchs_output_data_.size();
    context->SetFailed(RET_ERROR);
    return RET_ERROR;
  }
  auto branch_output_data_arrows = all_branch_output_data_arrows_.at(index);
  auto branch_outputs_data = all_branchs_output_data_.at(index);
  if (branch_output_data_arrows.size() != branch_outputs_data.size()) {
    MS_LOG(ERROR) << "branch_output_data_arrows.size(): " << branch_output_data_arrows.size()
                  << " is not equal to branch_outputs_data.size(): " << branch_outputs_data.size();
    context->SetFailed(RET_ERROR);
    return RET_ERROR;
  }
  for (size_t i = 0; i < branch_output_data_arrows.size(); ++i) {
    auto &data = branch_outputs_data.at(i);
    Async(branch_output_data_arrows[i]->to_op_id_, get_actor_mgr(), &mindspore::OpActor<Tensor>::RunOpData, data.get(),
          context);
  }
  return RET_OK;
}

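// Actor entry point: gather the inputs of one sequential run, execute the kernel, read the cond
// tensor to pick a branch, release the other branches' inputs, and forward outputs downstream.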
void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  auto ret = InitInputData();
  if (ret != RET_OK) {
    (void)input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    (void)input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  (void)input_op_datas_.erase(op_uuid);

  auto cond_ptr = reinterpret_cast<bool *>(switch_type_node_->in_tensors()[kSwitchCondTensorIndex]->data());
  if (cond_ptr == nullptr) {
    MS_LOG(ERROR) << "switch cond input data is nullptr.";
    context->SetFailed(RET_NULL_PTR);
    return;
  }
  size_t index = static_cast<size_t>(*cond_ptr);
  DecreaseOtherBranchInputTensor(index);
  ret = AsyncBranchOutput(index, context);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "AsyncBranchOutput failed.";
    return;
  }
  if (!output_data_arrows_.empty()) {
    AsyncOutput(context);
    SetOutputData(context);
  }
}
}  // namespace mindspore::lite