/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <utility>
#include <algorithm>
#include "src/lite_mindrt.h"
#include "mindrt/include/mindrt.hpp"
#include "src/lite_kernel_util.h"
#include "src/common/tensor_util.h"
#include "src/runtime/inner_allocator.h"
#include "src/runtime/kernel/arm/base/partial_fusion.h"
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif

namespace mindspore::lite {
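// Entry point for data arriving at this actor. Each upstream actor posts one
// OpData message per input tensor; messages are grouped by the context's
// sequential_num_ so that concurrent runs do not interleave. The kernel only
// fires once every input tensor for this run has arrived.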
void LiteOpActor::RunOpData(OpData<lite::Tensor> *inputs, OpContext<lite::Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  auto ret = InitInputData();
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  input_op_datas_.erase(op_uuid);
  AsyncOutput(context);

  SetOutputData(context);

  return;
}

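// Returns true when `this_input_tensor` is neither a graph input nor produced
// by any kernel in `kernels` other than `this_kernel`, i.e. the tensor is
// "offline" (constant/initialized data) and can be isolated safely.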
bool OfflineIsolated(const std::vector<kernel::LiteKernel *> &kernels, const kernel::LiteKernel &this_kernel,
                     const lite::Tensor &this_input_tensor) {
  if (this_input_tensor.IsGraphInput()) {
    return false;
  }
  for (auto &kernel : kernels) {
    if (kernel == &this_kernel) {
      continue;
    }
    if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(),
                    [&this_input_tensor](lite::Tensor *tensor) { return tensor == &this_input_tensor; })) {
      return false;
    }
  }
  return true;
}

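// Rewires every inner node of the subgraph that consumed `old_tensor` to read
// from `new_tensor` instead, and seeds the new tensor's initial reference
// count with the number of consumers found.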
void LiteOpActor::ReplaceNodeInTensor(kernel::LiteKernel *kernel, Tensor *old_tensor, Tensor *new_tensor) {
  int ref_count = 0;
#ifndef DELEGATE_CLIP
  /* set op input for calculate */
  if (kernel->desc().arch == kernel::kDelegate) {
    ref_count++;
  } else {
#endif
    for (auto in_node : reinterpret_cast<kernel::SubGraphKernel *>(kernel)->in_nodes()) {
      for (size_t node_in_index = 0; node_in_index < in_node->in_tensors().size(); node_in_index++) {
        if (old_tensor == in_node->in_tensors()[node_in_index]) {
          in_node->set_in_tensor(new_tensor, node_in_index);
          ref_count++;
        }
      }
    }
#ifndef DELEGATE_CLIP
  }
#endif
  new_tensor->set_init_ref_count(ref_count);
}

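// Duplicates each eligible input tensor of this actor's subgraph so the
// subgraph owns a private copy ("isolation"). Float inputs are retyped to the
// kernel's preferred data type (fp16 vs fp32); the new->old mapping is kept in
// isolate_input_map_ so ResizeGraphInput can still reach the originals.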
int LiteOpActor::IsolateInputData(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
  std::vector<kernel::LiteKernel *> kernels{};
  std::transform(actors->begin(), actors->end(), std::back_inserter(kernels),
                 [](std::shared_ptr<LiteOpActor> actor) { return actor->kernel_; });
  size_t in_tensor_size = kernel_->in_tensors().size();
  for (size_t i = 0; i < in_tensor_size; i++) {
    Tensor *old_tensor = kernel_->in_tensors()[i];

    if (OfflineIsolated(kernels, *kernel_, *old_tensor)) {
      if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
        old_tensor->set_data_type(kernel_->desc().data_type);
      }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
      if (old_tensor->data_type() == kObjectTypeTensorType) {
        auto old_tensorlist = reinterpret_cast<TensorList *>(old_tensor);
        if (old_tensorlist->tensors_data_type() == kNumberTypeFloat16 ||
            old_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
          old_tensorlist->set_tensors_data_type(kernel_->desc().data_type);
        }
      }
#endif
      old_tensor->set_allocator(kernel_->Context()->allocator);
      continue;
    }

    TypeId new_data_type = old_tensor->data_type();
    if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
      new_data_type = kernel_->desc().data_type;
    }

    Tensor *new_tensor = new Tensor(new_data_type, old_tensor->shape(), old_tensor->format(), old_tensor->category());
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return RET_NULL_PTR;
    }
    new_tensor->set_allocator(old_tensor->allocator());
    if (new_tensor->allocator() == nullptr && kernel_->Context() != nullptr &&
        kernel_->desc().arch != kernel::kDelegate) {
      new_tensor->set_allocator(kernel_->Context()->allocator);
    }

    new_tensor->set_tensor_name(kernel_->name() + "_duplicate_" + old_tensor->tensor_name());
    for (LiteQuantParam quant : old_tensor->quant_params()) {
      new_tensor->AddQuantParam(quant);
    }
    isolate_input_map_.insert(std::make_pair(new_tensor, old_tensor));
    ReplaceNodeInTensor(kernel_, old_tensor, new_tensor);
    /* set subgraph input for copy data */
    kernel_->set_in_tensor(new_tensor, i);
  }
  return RET_OK;
}

int LiteOpActor::LiteActorInit(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
  /* Init output arrow */
  auto ret = CompileArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "compile arrow failed.";
    return ret;
  }

  /* Init Actor output data */
  ret = PrepareOutputData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "prepare output data failed.";
    return ret;
  }

  /* subgraph transaction isolation */
  ret = IsolateInputData(actors);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "isolate input data failed.";
    return ret;
  }
  return RET_OK;
}

int LiteOpActor::ResizeGraphInput(const std::vector<mindspore::tensor::MSTensor *> &inputs,
                                  const std::vector<std::vector<int>> &dims) {
  for (auto map : isolate_input_map_) {
    auto isolate_tensor = map.first;
    auto src_tensor = map.second;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (src_tensor == inputs[i]) {
        isolate_tensor->set_shape(dims[i]);
      }
    }
  }
  return RET_OK;
}

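// Builds one DataArrow per (output tensor -> consuming kernel input) pair. An
// arrow records the producer's output index, the consumer actor's AID, and the
// consumer's input index; at run time each arrow becomes one async OpData send.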
int LiteOpActor::CompileArrowThroughOutputKernels() {
  output_data_arrows_.clear();
  int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
  for (int i = 0; i < out_tensor_size; i++) {
    for (auto out : kernel_->out_kernels()) {
      int in_tensor_size = static_cast<int>(out->in_tensors().size());
      int to_input_index = -1;
      for (int j = 0; j < in_tensor_size; j++) {
        if (kernel_->out_tensors()[i] == out->in_tensors()[j]) {
          to_input_index = j;
          break;
        }
      }
      if (to_input_index == -1) {
        continue;
      }
      auto id = out->name() + this->GetAID().Url();
      auto arrow = std::make_shared<DataArrow>(i, AID(id), to_input_index);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed, out kernel: " << out->name();
        return RET_ERROR;
      }
      output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
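// Control-flow path: if the subgraph ends in a Call fed by a PartialFusion
// node, wire this actor's outputs directly to the actor of the partial's
// target subgraph and drop the partial/call nodes from this subgraph.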
int LiteOpActor::CompileArrowThroughPartialCall() {
#ifndef DELEGATE_CLIP
  if (kernel_->desc().arch == kernel::kDelegate) {
    MS_LOG(INFO) << "kernel is delegate subgraph kernel.";
    return RET_OK;
  }
#endif
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
    return RET_OK;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() != schema::PrimitiveType_Call) {
      continue;
    }
    call_node_ = node;
    auto partial_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_PartialFusion);
    if (!partial_node) {
      continue;
    }
    partial_node_ = partial_node;
    auto subgraph = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel())->subgraph_kernel();
    auto out_actor_id = subgraph_to_actor_.at(subgraph);

    kernel_->set_out_tensors(partial_node->in_tensors());
    for (size_t i = 0; i < partial_node->in_tensors().size(); ++i) {
      auto arrow = std::make_shared<DataArrow>(i, out_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      output_data_arrows_.emplace_back(std::move(arrow));
    }
  }

  subgraph_kernel->DropNode(partial_node_);
  subgraph_kernel->DropNode(call_node_);
  return RET_OK;
}
#endif

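// Arrow compilation tries the partial-call (control flow) path first; only
// when that produces no arrows does the plain output-kernel path run.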
int LiteOpActor::CompileArrow() {
  int ret;
  output_data_arrows_.clear();
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  ret = CompileArrowThroughPartialCall();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughPartialCall failed.";
    return ret;
  }
  if (!output_data_arrows_.empty()) {
    MS_LOG(INFO) << "CompileArrowThroughPartialCall done.";
    return RET_OK;
  }
#endif
  ret = CompileArrowThroughOutputKernels();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
    return ret;
  }
  return ret;
}

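// Transfers tensor data from src to dst without a copy: dst adopts src's
// allocator and data pointer, the allocator's reference count on the buffer is
// raised by dst's share, ownership follows src->own_data(), and src finally
// releases its own reference.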
void LiteOpActor::MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  MS_ASSERT(src_tensor != dst_tensor);
  dst_tensor->FreeData();
  dst_tensor->ResetRefCount();
  dst_tensor->set_allocator(src_tensor->allocator());

  src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());

  if (src_tensor->data() != nullptr) {
    dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
  }

  dst_tensor->set_own_data(src_tensor->own_data());
  src_tensor->DecRefCount();
}

void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  if (src_tensor == dst_tensor) {
    MS_LOG(INFO) << "no need to move.";
    return;
  }
  MS_ASSERT(src_tensor->allocator() != nullptr);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (src_tensor->data_type() == kObjectTypeTensorType) {
    MoveTensorListInputData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
  } else {
    MoveTensorInputData(dst_tensor, src_tensor);
  }
#else
  MoveTensorInputData(dst_tensor, src_tensor);
#endif
  return;
}

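// Shares src's buffer with dst without transferring ownership; used for graph
// inputs and delegate kernel outputs whose memory this actor must not free.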
void LiteOpActor::SetInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  dst_tensor->set_data(src_tensor->data());
  dst_tensor->set_own_data(false);
}

int LiteOpActor::CastInputData(Tensor *dst, Tensor *src) {
  int ret = RET_OK;
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (src->data_type() != kObjectTypeTensorType) {
    ret = CastTensorInputData(dst, src);
  } else {
    ret = CastTensorListInputData(reinterpret_cast<TensorList *>(dst), reinterpret_cast<TensorList *>(src));
  }
#else
  ret = CastTensorInputData(dst, src);
#endif
  src->DecRefCount();
  return ret;
}

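// A cast is needed when dst and src are plain tensors of different data types,
// or tensor lists whose element data types differ. Only the fp16 <-> fp32 pair
// is actually supported by CastTensorInputData below.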
bool LiteOpActor::NeedCastData(Tensor *dst_tensor, Tensor *src_tensor) {
  if (dst_tensor->data_type() != kObjectTypeTensorType && src_tensor->data_type() != kObjectTypeTensorType &&
      dst_tensor->data_type() != src_tensor->data_type()) {
    return true;
  }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (dst_tensor->data_type() == kObjectTypeTensorType && src_tensor->data_type() == kObjectTypeTensorType &&
      reinterpret_cast<TensorList *>(dst_tensor)->tensors_data_type() !=
        reinterpret_cast<TensorList *>(src_tensor)->tensors_data_type()) {
    return true;
  }
#endif
  return false;
}

int LiteOpActor::CastTensorInputData(Tensor *dst, Tensor *src) {
  dst->MallocData();
  dst->ResetRefCount();
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  if (dst->shape() != src->shape()) {
    MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs "
                  << "src tensor: " << src->tensor_name() << " shape: " << src->shape();
    return RET_PARAM_INVALID;
  }
  auto dst_data = dst->MutableData(); /* using MutableData to sync GPU data */
  auto src_data = src->MutableData();
  auto src_nums_size = src->ElementsNum();
  auto dst_data_type = static_cast<int>(dst->data_type());
  auto src_data_type = static_cast<int>(src->data_type());
  if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) {
    Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
  } else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) {
    Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
  } else {
    MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type;
    return RET_NOT_SUPPORT;
  }
  return RET_OK;
#endif
  /* fp16 <-> fp32 casting is unavailable in builds without ARM fp16 support */
  return RET_ERROR;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
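// Tensor-list variant of MoveTensorInputData: moves each element tensor's
// buffer and propagates its shape; the two lists must hold the same number of
// tensors.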
void LiteOpActor::MoveTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
  MS_ASSERT(src_tensorlist != nullptr);
  MS_ASSERT(dst_tensorlist != nullptr);
  dst_tensorlist->FreeData();
  dst_tensorlist->ResetRefCount();
  dst_tensorlist->set_allocator(src_tensorlist->allocator());

  auto src_tensorlist_tensors_size = src_tensorlist->tensors().size();
  auto dst_tensorlist_tensors_size = dst_tensorlist->tensors().size();
  if (src_tensorlist_tensors_size != dst_tensorlist_tensors_size) {
    MS_LOG(ERROR) << "src tensorlist: " << src_tensorlist->tensor_name()
                  << " tensors size: " << src_tensorlist_tensors_size
                  << " vs dst tensorlist: " << dst_tensorlist->tensor_name()
                  << " tensors size: " << dst_tensorlist_tensors_size;
    return;
  }

  dst_tensorlist->set_own_data(src_tensorlist->own_data());
  for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
    auto &src_tensor = src_tensorlist->tensors()[i];
    auto &dst_tensor = dst_tensorlist->tensors()[i];

    if (src_tensor->allocator() != nullptr) {
      src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
    }
    dst_tensor->set_own_data(src_tensor->own_data());
    if (src_tensor->data() != nullptr) {
      dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
    }
    dst_tensor->set_shape(src_tensor->shape());
  }

  if (src_tensorlist->IsConst() || src_tensorlist->IsGraphInput()) {
    dst_tensorlist->set_own_data(false);
  } else {
    src_tensorlist->DecRefCount();
  }
}

int LiteOpActor::CastTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
  MS_ASSERT(src_tensorlist != nullptr);
  MS_ASSERT(dst_tensorlist != nullptr);
  dst_tensorlist->set_shape(src_tensorlist->shape());
  std::vector<std::vector<int>> tensors_shapes{};
  tensors_shapes.resize(src_tensorlist->tensors().size());
  for (size_t i = 0; i < tensors_shapes.size(); ++i) {
    tensors_shapes[i] = src_tensorlist->tensors()[i]->shape();
  }
  if (src_tensorlist->tensors_data_type() == kNumberTypeFloat16) {
    dst_tensorlist->MallocTensorListData(kNumberTypeFloat32, tensors_shapes);
  }
  if (src_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
    dst_tensorlist->MallocTensorListData(kNumberTypeFloat16, tensors_shapes);
  }
  dst_tensorlist->set_allocator(src_tensorlist->allocator());
  dst_tensorlist->ResetRefCount();

  for (size_t i = 0; i < src_tensorlist->tensors().size(); ++i) {
    auto &src_tensor = src_tensorlist->tensors()[i];
    auto &dst_tensor = dst_tensorlist->tensors()[i];
    CastTensorInputData(dst_tensor, src_tensor);
  }
  return RET_OK;
}

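// LiteSwitchOpActor handles subgraphs that end in a switch + call pattern. The
// switch's condition tensor decides at run time which partial's target
// subgraph receives the outputs, so arrows and output data are compiled
// separately per branch.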
int LiteSwitchOpActor::CompileTrueBranchArrow() {
  if (true_partial_node_ == nullptr) {
    MS_LOG(ERROR) << "true_partial_node_ is nullptr.";
    return RET_NULL_PTR;
  }
  auto subgraph = static_cast<kernel::PartialFusionKernel *>(true_partial_node_->kernel())->subgraph_kernel();
  auto true_branch_actor_id = subgraph_to_actor_.at(subgraph);

  for (size_t i = 0; i < true_partial_node_->in_tensors().size(); ++i) {
    int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
    for (int j = 0; j < out_tensor_size; ++j) {
      if (true_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
        continue;
      }
      auto arrow = std::make_shared<DataArrow>(j, true_branch_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      true_branch_output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

int LiteSwitchOpActor::CompileFalseBranchArrow() {
  if (false_partial_node_ == nullptr) {
    MS_LOG(ERROR) << "false_partial_node_ is nullptr.";
    return RET_NULL_PTR;
  }
  auto subgraph = static_cast<kernel::PartialFusionKernel *>(false_partial_node_->kernel())->subgraph_kernel();
  auto false_branch_actor_id = subgraph_to_actor_.at(subgraph);

  for (size_t i = 0; i < false_partial_node_->in_tensors().size(); ++i) {
    int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
    for (int j = 0; j < out_tensor_size; ++j) {
      if (false_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
        continue;
      }
      auto arrow = std::make_shared<DataArrow>(j, false_branch_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      false_branch_output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

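// Locates the Call node and its feeding Switch node inside the subgraph, then
// records the true/false PartialFusion inputs. The partial kernels sit at
// fixed input positions, shifted down by one when the switch has fewer input
// kernels (presumably when the condition is not produced by a kernel).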
int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) {
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() != schema::PrimitiveType_Call) {
      continue;
    }
    call_node_ = node;
    auto switch_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch);
    if (!switch_node) {
      continue;
    }

    if (switch_node->in_tensors().size() < kSwitchMinInputTensorSize) {
      MS_LOG(ERROR) << "actor name: " << this->GetAID() << "'s switch node " << switch_node->name()
                    << " input tensor size: " << switch_node->in_tensors().size() << " is less than "
                    << kSwitchMinInputTensorSize << ".";
      return RET_ERROR;
    }

    switch_node_ = switch_node;
    if (switch_node->in_kernels().size() == kSwitchMaxInputKernelSize) {
      true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex);
      false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex);
    }

    if (switch_node->in_kernels().size() == kSwitchMinInputKernelSize) {
      true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex - 1);
      false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex - 1);
    }
    break;
  }
  return RET_OK;
}

void LiteSwitchOpActor::AppendOutputTensors() {
  for (auto &tensor : true_partial_node_->in_tensors()) {
    if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
      output_tensors_.push_back(tensor);
    }
  }
  for (auto &tensor : false_partial_node_->in_tensors()) {
    if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
      output_tensors_.push_back(tensor);
    }
  }
  kernel_->set_out_tensors(output_tensors_);
}

int LiteSwitchOpActor::CompileArrowThroughSwitchCall() {
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
    return RET_OK;
  }

  int ret = GetSwitchAndCallNode(subgraph_kernel);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetSwitchAndCallCnode failed.";
    return ret;
  }

  AppendOutputTensors();

  ret = CompileTrueBranchArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileTrueBranchArrow failed.";
    true_branch_output_data_arrows_.clear();
    return ret;
  }

  ret = CompileFalseBranchArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileFalseBranchArrow failed.";
    false_branch_output_data_arrows_.clear();
    true_branch_output_data_arrows_.clear();
    return ret;
  }

  subgraph_kernel->DropNode(call_node_);
  subgraph_kernel->DropNode(switch_node_);
  subgraph_kernel->DropNode(true_partial_node_);
  subgraph_kernel->DropNode(false_partial_node_);

  return ret;
}

int LiteSwitchOpActor::CompileArrow() {
  int ret = CompileArrowThroughSwitchCall();
  if (ret != RET_OK) {
    true_branch_output_data_arrows_.clear();
    false_branch_output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughSwitchCall failed.";
    return ret;
  }
  if (!true_branch_output_data_arrows_.empty() && !false_branch_output_data_arrows_.empty()) {
    MS_LOG(INFO) << "CompileArrowThroughSwitchCall done.";
    return RET_OK;
  }
  ret = CompileArrowThroughOutputKernels();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
    return ret;
  }
  return ret;
}

int LiteSwitchOpActor::PrepareOutputData() {
  true_branch_outputs_data_.resize(true_branch_output_data_arrows_.size());
  for (size_t i = 0; i < true_branch_output_data_arrows_.size(); i++) {
    auto &arrow = true_branch_output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new true_branch_output_data failed.";
      return RET_NULL_PTR;
    }
    true_branch_outputs_data_.at(i) = data;
  }

  false_branch_outputs_data_.resize(false_branch_output_data_arrows_.size());
  for (size_t i = 0; i < false_branch_output_data_arrows_.size(); i++) {
    auto &arrow = false_branch_output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new false_branch_output_data failed.";
      return RET_NULL_PTR;
    }
    false_branch_outputs_data_.at(i) = data;
  }
  return RET_OK;
}

void LiteSwitchOpActor::DecreaseTrueBranchInputTensor() {
  switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
  for (auto input : true_partial_node_->in_tensors()) {
    input->DecRefCount();
  }
}

void LiteSwitchOpActor::DecreaseFalseBranchInputTensor() {
  switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
  for (auto input : false_partial_node_->in_tensors()) {
    input->DecRefCount();
  }
}

void LiteSwitchOpActor::AsyncTrueBranchOutput(OpContext<Tensor> *context) {
  MS_ASSERT(true_branch_output_data_arrows_.size() == true_branch_outputs_data_.size());
  for (size_t i = 0; i < true_branch_output_data_arrows_.size(); ++i) {
    auto &data = true_branch_outputs_data_.at(i);
    Async(true_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

void LiteSwitchOpActor::AsyncFalseBranchOutput(OpContext<Tensor> *context) {
  MS_ASSERT(false_branch_output_data_arrows_.size() == false_branch_outputs_data_.size());
  for (size_t i = 0; i < false_branch_output_data_arrows_.size(); ++i) {
    auto &data = false_branch_outputs_data_.at(i);
    Async(false_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

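// Switch variant of RunOpData: after the kernel runs, the condition tensor is
// read and only the taken branch's outputs are forwarded; the untaken branch's
// inputs just have their reference counts released.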
void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  int ret = InitInputData();
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  input_op_datas_.erase(op_uuid);

  auto cond_ptr = reinterpret_cast<bool *>(switch_node_->in_tensors()[kSwitchCondTensorIndex]->data());
  if (cond_ptr == nullptr) {
    MS_LOG(ERROR) << "switch cond input data is nullptr.";
    context->SetFailed(RET_NULL_PTR);
    return;
  }
  if (*cond_ptr) {
    DecreaseFalseBranchInputTensor();
    AsyncTrueBranchOutput(context);
  } else {
    DecreaseTrueBranchInputTensor();
    AsyncFalseBranchOutput(context);
  }
}

#endif

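// Propagates the shapes (and formats) of the incoming data tensors onto the
// kernel's input tensors before data is bound, re-allocating tensor-list
// payloads when the element shapes changed.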
void LiteOpActor::SetInputShape() {
  for (size_t i = 0; i < inputs_data_.size(); ++i) {
    auto &input_tensor = kernel_->in_tensors()[i];
    if (input_tensor->shape() == inputs_data_[i]->shape()) {
      continue;
    }
    MS_LOG(DEBUG) << "inputs_data_[" << i << "].shape: " << inputs_data_[i]->shape() << " vs kernel_->in_tensors()["
                  << i << "].shape: " << kernel_->in_tensors()[i]->shape() << " are not equal.";
    MS_LOG(DEBUG) << "this->kernel_->name(): " << this->kernel_->name();

    if (input_tensor->data_type() == kObjectTypeTensorType) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
      auto input_tensorlist = reinterpret_cast<TensorList *>(input_tensor);
      auto input_data_tensorlist = reinterpret_cast<TensorList *>(inputs_data_[i]);
      input_tensorlist->FreeTensorListData();
      input_tensorlist->set_element_shape(input_data_tensorlist->element_shape());
      input_tensorlist->set_shape(input_data_tensorlist->shape());
      std::vector<std::vector<int>> tensor_shape{};
      std::transform(input_data_tensorlist->tensors().begin(), input_data_tensorlist->tensors().end(),
                     std::back_inserter(tensor_shape), [](Tensor *tensor_item) { return tensor_item->shape(); });
      input_tensorlist->MallocTensorListData(input_data_tensorlist->tensors_data_type(), tensor_shape);
#endif
    } else {
      input_tensor->set_shape(inputs_data_[i]->shape());
      input_tensor->set_format(inputs_data_[i]->format());
    }
  }
}

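// Binds incoming data to the kernel's input tensors, choosing per tensor
// between dropping it (unused input), casting (data-type mismatch), sharing
// (graph inputs / allocator-less data), or moving (the normal case).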
int LiteOpActor::InitInputData() {
  SetInputShape();

  for (size_t i = 0; i < inputs_data_.size(); ++i) {
    auto dst_tensor = kernel_->in_tensors()[i];
    auto src_tensor = inputs_data_[i];
    if (dst_tensor->init_ref_count() == 0) {
      src_tensor->DecRefCount();
      continue;
    }

    if (NeedCastData(dst_tensor, src_tensor)) {
      CastInputData(dst_tensor, src_tensor);
      continue;
    }

    /* same data-type  */
    if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
      // delegate graph kernel output tensor
      SetInputData(dst_tensor, src_tensor);
    } else {
      MoveInputData(dst_tensor, src_tensor);
    }
  }
  return RET_OK;
}

void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    auto data = outputs_data_.at(i);
    Async(output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

void LiteOpActor::AddResultIndex(size_t index) { results_index_.push_back(index); }

void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
  for (auto index : results_index_) {
    context->SetResult(index, RET_OK);
  }
}

int LiteOpActor::PrepareOutputData() {
  outputs_data_.resize(output_data_arrows_.size());
  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    auto &arrow = output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new output_data failed.";
      return RET_NULL_PTR;
    }
    outputs_data_.at(i) = data;
  }
  return RET_OK;
}

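// Factory: wraps every subgraph kernel in a LiteOpActor (or LiteSwitchOpActor
// for switch-call subgraphs) and spawns the actors on the context's thread
// pool. A sketch of the expected call sequence from the session side (the real
// call sites are outside this file, so this is illustrative only):
//
//   MindrtInit();                                  // bring up the actor runtime
//   auto actors = CreateOpActor(kernels, context); // one actor per subgraph
//   for (auto &actor : actors) actor->LiteActorInit(&actors);
//   /* ... post graph inputs, run, collect outputs ... */
//   MindrtTerminate(actors);                       // tear the actors down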
std::vector<std::shared_ptr<LiteOpActor>> CreateOpActor(const std::vector<kernel::LiteKernel *> &kernels,
                                                        const lite::InnerContext *ctx) {
  std::vector<std::shared_ptr<LiteOpActor>> actors;
  std::unordered_map<kernel::LiteKernel *, AID> subgraph_name_AID_map{};
  ActorThreadPool *thread_pool = reinterpret_cast<ActorThreadPool *>(ctx->thread_pool());
  if (thread_pool == nullptr) {
    MS_LOG(ERROR) << "thread pool is nullptr";
    return actors;
  }
  for (auto &kernel : kernels) {
    /* make subgraph name (actor name) unique */
    kernel->set_name(kernel->name() + "_" + to_string(actor_count++));
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    if ((kernel::LiteKernelUtil::IsSwitchCall(kernel))) {
      auto switch_actor = std::make_shared<LiteSwitchOpActor>(kernel);
      if (switch_actor == nullptr) {
        MS_LOG(ERROR) << "create LiteSwitchOpActor failed: " << kernel->name();
        actors.clear();
        return actors;
      }
      switch_actor->set_thread_pool(thread_pool);
      subgraph_name_AID_map[kernel] = switch_actor->GetAID();
      actors.push_back(switch_actor);
    } else {
#endif
      auto actor = std::make_shared<LiteOpActor>(kernel);
      if (actor == nullptr) {
        MS_LOG(ERROR) << "create LiteOpActor failed: " << kernel->name();
        actors.clear();
        return actors;
      }
      actor->set_thread_pool(thread_pool);
      subgraph_name_AID_map[kernel] = actor->GetAID();
      actors.push_back(actor);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    }
#endif
  }

  for (auto &actor : actors) {
    actor->SetSubgraphAIDMap(subgraph_name_AID_map);
    auto aid = mindspore::Spawn(actor);
  }
  return actors;
}

int MindrtInit() { return mindspore::Initialize("", "", "", ""); }

void MindrtTerminate(const std::vector<std::shared_ptr<LiteOpActor>> &actor_list) {
  for (const auto &actor : actor_list) {
    mindspore::Terminate(actor->GetAID());
  }
}
}  // namespace mindspore::lite