/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <utility>
#include <algorithm>
#include "src/lite_mindrt.h"
#include "mindrt/include/mindrt.hpp"
#include "src/lite_kernel_util.h"
#include "src/common/tensor_util.h"
#include "src/runtime/inner_allocator.h"
#include "src/runtime/kernel/arm/base/partial_fusion.h"
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif

namespace mindspore::lite {
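/* Entry point for data sent by upstream actors: cache the incoming OpData, and once one
 * piece of data has arrived for every input tensor of this actor's kernel, copy/cast the
 * inputs, run the kernel, forward the outputs to downstream actors, and report the result. */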
void LiteOpActor::RunOpData(OpData<lite::Tensor> *inputs, OpContext<lite::Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  auto ret = InitInputData();
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  input_op_datas_.erase(op_uuid);
  AsyncOutput(context);

  SetOutputData(context);

  return;
}

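/* An input tensor is "offline isolated" when it is neither a graph input nor produced as an
 * output by any other kernel in the actor set; such a tensor can be retyped in place instead
 * of being duplicated in IsolateInputData below. */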
bool OfflineIsolated(const std::vector<kernel::LiteKernel *> &kernels, const kernel::LiteKernel &this_kernel,
                     const lite::Tensor &this_input_tensor) {
  if (this_input_tensor.IsGraphInput()) {
    return false;
  }
  for (auto &kernel : kernels) {
    if (kernel == &this_kernel) {
      continue;
    }
    if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(),
                    [&this_input_tensor](lite::Tensor *tensor) { return tensor == &this_input_tensor; })) {
      return false;
    }
  }
  return true;
}

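/* Rewire every in-node of the subgraph that consumed old_tensor so it consumes new_tensor
 * instead, and record the number of consumers found as the new tensor's initial reference
 * count. Delegate kernels expose no in-nodes, so they count as a single consumer. */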
void LiteOpActor::ReplaceNodeInTensor(kernel::LiteKernel *kernel, Tensor *old_tensor, Tensor *new_tensor) {
  int ref_count = 0;
#ifndef DELEGATE_CLIP
  /* set op input for calculate */
  if (kernel->desc().arch == kernel::kDelegate) {
    ref_count++;
  } else {
#endif
    for (auto in_node : reinterpret_cast<kernel::SubGraphKernel *>(kernel)->in_nodes()) {
      for (size_t node_in_index = 0; node_in_index < in_node->in_tensors().size(); node_in_index++) {
        if (old_tensor == in_node->in_tensors()[node_in_index]) {
          in_node->set_in_tensor(new_tensor, node_in_index);
          ref_count++;
        }
      }
    }
#ifndef DELEGATE_CLIP
  }
#endif
  new_tensor->set_init_ref_count(ref_count);
}

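/* Give this actor private copies of its kernel's input tensors so that cross-actor data
 * movement and fp16/fp32 casting never touch tensors shared with other subgraphs. Offline
 * isolated inputs are simply retyped in place; every other input is duplicated, the copy is
 * wired into the subgraph, and the (copy -> original) pair is remembered in
 * isolate_input_map_ so graph-input resizes can be propagated later. */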
int LiteOpActor::IsolateInputData(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
  std::vector<kernel::LiteKernel *> kernels{};
  std::transform(actors->begin(), actors->end(), std::back_inserter(kernels),
                 [](std::shared_ptr<LiteOpActor> actor) { return actor->kernel_; });
  size_t in_tensor_size = kernel_->in_tensors().size();
  for (size_t i = 0; i < in_tensor_size; i++) {
    Tensor *old_tensor = kernel_->in_tensors()[i];

    if (OfflineIsolated(kernels, *kernel_, *old_tensor)) {
      if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
        old_tensor->set_data_type(kernel_->desc().data_type);
      }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
      if (old_tensor->data_type() == kObjectTypeTensorType) {
        auto old_tensorlist = reinterpret_cast<TensorList *>(old_tensor);
        if (old_tensorlist->tensors_data_type() == kNumberTypeFloat16 ||
            old_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
          old_tensorlist->set_tensors_data_type(kernel_->desc().data_type);
        }
      }
#endif
      old_tensor->set_allocator(kernel_->Context()->allocator);
      continue;
    }

    TypeId new_data_type = old_tensor->data_type();
    if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
      new_data_type = kernel_->desc().data_type;
    }

    Tensor *new_tensor = new Tensor(new_data_type, old_tensor->shape(), old_tensor->format(), old_tensor->category());
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return RET_NULL_PTR;
    }
    new_tensor->set_allocator(old_tensor->allocator());
    if (new_tensor->allocator() == nullptr && kernel_->Context() != nullptr &&
        kernel_->desc().arch != kernel::kDelegate) {
      new_tensor->set_allocator(kernel_->Context()->allocator);
    }

    new_tensor->set_tensor_name(kernel_->name() + "_duplicate_" + old_tensor->tensor_name());
    for (LiteQuantParam quant : old_tensor->quant_params()) {
      new_tensor->AddQuantParam(quant);
    }
    isolate_input_map_.insert(std::make_pair(new_tensor, old_tensor));
    ReplaceNodeInTensor(kernel_, old_tensor, new_tensor);
    /* set subgraph input for copy data */
    kernel_->set_in_tensor(new_tensor, i);
  }
  return RET_OK;
}

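/* One-time initialization of an actor after construction, done in three steps: output data
 * arrows are compiled, the OpData objects carried by those arrows are pre-allocated, and the
 * kernel's input tensors are isolated. An illustrative driver-side call sequence (variable
 * names are assumptions, not part of this file):
 *
 *   auto actors = CreateOpActor(subgraph_kernels, inner_context);
 *   for (auto &actor : actors) {
 *     if (actor->LiteActorInit(&actors) != RET_OK) { ... handle error ... }
 *   }
 */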
int LiteOpActor::LiteActorInit(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
  /* Init output arrow */
  auto ret = CompileArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "compile arrow failed.";
    return ret;
  }

  /* Init Actor output data */
  ret = PrepareOutputData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "prepare output data failed.";
    return ret;
  }

  /* subgraph transaction isolation */
  ret = IsolateInputData(actors);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "isolate input data failed.";
    return ret;
  }
  return RET_OK;
}

int LiteOpActor::ResizeGraphInput(const std::vector<mindspore::tensor::MSTensor *> &inputs,
                                  const std::vector<std::vector<int>> &dims) {
  for (auto map : isolate_input_map_) {
    auto isolate_tensor = map.first;
    auto src_tensor = map.second;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (src_tensor == inputs[i]) {
        isolate_tensor->set_shape(dims[i]);
      }
    }
  }
  return RET_OK;
}

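/* Build one DataArrow per (output tensor, consuming kernel) pair: if this kernel's output
 * tensor i is input j of out-kernel K, the arrow (i, AID-of-K, j) tells the runtime to send
 * that tensor to K's actor as its j-th input. For example, if output tensor 0 feeds the
 * second input of a downstream kernel, the arrow produced is DataArrow(0, to_actor, 1). */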
int LiteOpActor::CompileArrowThroughOutputKernels() {
  output_data_arrows_.clear();
  int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
  for (int i = 0; i < out_tensor_size; i++) {
    for (auto out : kernel_->out_kernels()) {
      int in_tensor_size = static_cast<int>(out->in_tensors().size());
      int to_input_index = -1;
      for (int j = 0; j < in_tensor_size; j++) {
        if (kernel_->out_tensors()[i] == out->in_tensors()[j]) {
          to_input_index = j;
          break;
        }
      }
      if (to_input_index == -1) {
        continue;
      }
      auto id = out->name() + this->GetAID().Url();
      auto arrow = std::make_shared<DataArrow>(i, AID(id), to_input_index);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed, out kernel: " << out->name();
        return RET_ERROR;
      }
      output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
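/* Control-flow variant of arrow compilation: when the subgraph ends in a Call node fed by a
 * PartialFusion node, the real successor is the subgraph captured by the partial. The
 * partial's input tensors become this actor's output tensors, one arrow is built per tensor
 * towards the captured subgraph's actor, and the partial/call nodes are dropped from the
 * subgraph afterwards. */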
int LiteOpActor::CompileArrowThroughPartialCall() {
#ifndef DELEGATE_CLIP
  if (kernel_->desc().arch == kernel::kDelegate) {
    MS_LOG(INFO) << "kernel is delegate subgraph kernel.";
    return RET_OK;
  }
#endif
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
    return RET_OK;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() != schema::PrimitiveType_Call) {
      continue;
    }
    call_node_ = node;
    auto partial_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_PartialFusion);
    if (!partial_node) {
      continue;
    }
    partial_node_ = partial_node;
    auto subgraph = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel())->subgraph_kernel();
    auto out_actor_id = subgraph_to_actor_.at(subgraph);

    kernel_->set_out_tensors(partial_node->in_tensors());
    for (size_t i = 0; i < partial_node->in_tensors().size(); ++i) {
      auto arrow = std::make_shared<DataArrow>(i, out_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      output_data_arrows_.emplace_back(std::move(arrow));
    }
  }

  subgraph_kernel->DropNode(partial_node_);
  subgraph_kernel->DropNode(call_node_);
  return RET_OK;
}
#endif

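/* Prefer arrows derived from a partial call; only when none were produced fall back to the
 * arrows derived from this subgraph's explicit output kernels. */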
int LiteOpActor::CompileArrow() {
  int ret;
  output_data_arrows_.clear();
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  ret = CompileArrowThroughPartialCall();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughPartialCall failed.";
    return ret;
  }
  if (!output_data_arrows_.empty()) {
    MS_LOG(INFO) << "CompileArrowThroughPartialCall done.";
    return RET_OK;
  }
#endif
  ret = CompileArrowThroughOutputKernels();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
    return ret;
  }
  return ret;
}

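/* Transfer tensor data between actors without copying: the destination adopts the source's
 * allocator and data pointer, the allocator's reference count is bumped by the destination's
 * ref count, and the source finally drops its own reference. MutableData() is used when
 * attaching the pointer so device (GPU) data is synchronized first. */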
void LiteOpActor::MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  MS_ASSERT(src_tensor != dst_tensor);
  dst_tensor->FreeData();
  dst_tensor->ResetRefCount();
  dst_tensor->set_allocator(src_tensor->allocator());

  src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());

  if (src_tensor->data() != nullptr) {
    dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
  }

  dst_tensor->set_own_data(src_tensor->own_data());
  src_tensor->DecRefCount();
}

void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  if (src_tensor == dst_tensor) {
    MS_LOG(INFO) << "no need to move.";
    return;
  }
  MS_ASSERT(src_tensor->allocator() != nullptr);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (src_tensor->data_type() == kObjectTypeTensorType) {
    MoveTensorListInputData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
  } else {
    MoveTensorInputData(dst_tensor, src_tensor);
  }
#else
  MoveTensorInputData(dst_tensor, src_tensor);
#endif
  return;
}

void LiteOpActor::SetInputData(Tensor *dst_tensor, Tensor *src_tensor) {
  dst_tensor->set_data(src_tensor->data());
  dst_tensor->set_own_data(false);
}

int LiteOpActor::CastInputData(Tensor *dst, Tensor *src) {
  int ret = RET_OK;
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (src->data_type() != kObjectTypeTensorType) {
    ret = CastTensorInputData(dst, src);
  } else {
    ret = CastTensorListInputData(reinterpret_cast<TensorList *>(dst), reinterpret_cast<TensorList *>(src));
  }
#else
  ret = CastTensorInputData(dst, src);
#endif
  src->DecRefCount();
  return ret;
}

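/* A cast is required only when producer and consumer disagree on float precision: either two
 * plain tensors with different data types, or two tensor lists whose element data types
 * differ. */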
bool LiteOpActor::NeedCastData(Tensor *dst_tensor, Tensor *src_tensor) {
  if (dst_tensor->data_type() != kObjectTypeTensorType && src_tensor->data_type() != kObjectTypeTensorType &&
      dst_tensor->data_type() != src_tensor->data_type()) {
    return true;
  }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
  if (dst_tensor->data_type() == kObjectTypeTensorType && src_tensor->data_type() == kObjectTypeTensorType &&
      reinterpret_cast<TensorList *>(dst_tensor)->tensors_data_type() !=
        reinterpret_cast<TensorList *>(src_tensor)->tensors_data_type()) {
    return true;
  }
#endif
  return false;
}

int LiteOpActor::CastTensorInputData(Tensor *dst, Tensor *src) {
  dst->MallocData();
  dst->ResetRefCount();
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  if (dst->shape() != src->shape()) {
    MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs "
                  << "src tensor: " << src->tensor_name() << " shape: " << src->shape();
    return RET_PARAM_INVALID;
  }
  auto dst_data = dst->MutableData(); /* using MutableData to sync GPU data */
  auto src_data = src->MutableData();
  auto src_nums_size = src->ElementsNum();
  auto dst_data_type = static_cast<int>(dst->data_type());
  auto src_data_type = static_cast<int>(src->data_type());
  if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) {
    Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
  } else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) {
    Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
  } else {
    MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type;
    return RET_NOT_SUPPORT;
  }
  return RET_OK;
#endif
  return RET_ERROR;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
void LiteOpActor::MoveTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
  MS_ASSERT(src_tensorlist != nullptr);
  MS_ASSERT(dst_tensorlist != nullptr);
  dst_tensorlist->FreeData();
  dst_tensorlist->ResetRefCount();
  dst_tensorlist->set_allocator(src_tensorlist->allocator());

  auto src_tensorlist_tensors_size = src_tensorlist->tensors().size();
  auto dst_tensorlist_tensors_size = dst_tensorlist->tensors().size();
  if (src_tensorlist_tensors_size != dst_tensorlist_tensors_size) {
    MS_LOG(ERROR) << "src tensorlist: " << src_tensorlist->tensor_name()
                  << " tensors size: " << src_tensorlist_tensors_size
                  << " vs dst tensorlist: " << dst_tensorlist->tensor_name()
                  << " tensors size: " << dst_tensorlist_tensors_size;
    return;
  }

  dst_tensorlist->set_own_data(src_tensorlist->own_data());
  for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
    auto &src_tensor = src_tensorlist->tensors()[i];
    auto &dst_tensor = dst_tensorlist->tensors()[i];

    if (src_tensor->allocator() != nullptr) {
      src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
    }
    dst_tensor->set_own_data(src_tensor->own_data());
    if (src_tensor->data() != nullptr) {
      dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
    }
    dst_tensor->set_shape(src_tensor->shape());
  }

  if (src_tensorlist->IsConst() || src_tensorlist->IsGraphInput()) {
    dst_tensorlist->set_own_data(false);
  } else {
    src_tensorlist->DecRefCount();
  }
}

int LiteOpActor::CastTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
  MS_ASSERT(src_tensorlist != nullptr);
  MS_ASSERT(dst_tensorlist != nullptr);
  dst_tensorlist->set_shape(src_tensorlist->shape());
  std::vector<std::vector<int>> tensors_shapes{};
  tensors_shapes.resize(src_tensorlist->tensors().size());
  for (size_t i = 0; i < tensors_shapes.size(); ++i) {
    tensors_shapes[i] = src_tensorlist->tensors()[i]->shape();
  }
  if (src_tensorlist->tensors_data_type() == kNumberTypeFloat16) {
    dst_tensorlist->MallocTensorListData(kNumberTypeFloat32, tensors_shapes);
  }
  if (src_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
    dst_tensorlist->MallocTensorListData(kNumberTypeFloat16, tensors_shapes);
  }
  dst_tensorlist->set_allocator(src_tensorlist->allocator());
  dst_tensorlist->ResetRefCount();

  for (size_t i = 0; i < src_tensorlist->tensors().size(); ++i) {
    auto &src_tensor = src_tensorlist->tensors()[i];
    auto &dst_tensor = dst_tensorlist->tensors()[i];
    CastTensorInputData(dst_tensor, src_tensor);
  }
  return RET_OK;
}

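/* LiteSwitchOpActor compiles a separate arrow set per branch: each input tensor of the
 * true-branch partial node that is also an output tensor of this kernel gets an arrow to the
 * true-branch subgraph's actor, and likewise for the false branch. At run time only one of
 * the two arrow sets is fired. */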
int LiteSwitchOpActor::CompileTrueBranchArrow() {
  if (true_partial_node_ == nullptr) {
    MS_LOG(ERROR) << "true_partial_node_ is nullptr.";
    return RET_NULL_PTR;
  }
  auto subgraph = static_cast<kernel::PartialFusionKernel *>(true_partial_node_->kernel())->subgraph_kernel();
  auto true_branch_actor_id = subgraph_to_actor_.at(subgraph);

  for (size_t i = 0; i < true_partial_node_->in_tensors().size(); ++i) {
    int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
    for (int j = 0; j < out_tensor_size; ++j) {
      if (true_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
        continue;
      }
      auto arrow = std::make_shared<DataArrow>(j, true_branch_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      true_branch_output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

int LiteSwitchOpActor::CompileFalseBranchArrow() {
  if (false_partial_node_ == nullptr) {
    MS_LOG(ERROR) << "false_partial_node_ is nullptr.";
    return RET_NULL_PTR;
  }
  auto subgraph = static_cast<kernel::PartialFusionKernel *>(false_partial_node_->kernel())->subgraph_kernel();
  auto false_branch_actor_id = subgraph_to_actor_.at(subgraph);

  for (size_t i = 0; i < false_partial_node_->in_tensors().size(); ++i) {
    int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
    for (int j = 0; j < out_tensor_size; ++j) {
      if (false_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
        continue;
      }
      auto arrow = std::make_shared<DataArrow>(j, false_branch_actor_id, i);
      if (arrow == nullptr) {
        MS_LOG(ERROR) << "create DataArrow failed";
        return RET_ERROR;
      }
      false_branch_output_data_arrows_.emplace_back(std::move(arrow));
    }
  }
  return RET_OK;
}

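/* Locate the Call node of the subgraph and the Switch node feeding it, then pick the true and
 * false partial nodes from the switch's input kernels; the partial-node positions in
 * in_kernels() depend on how many input kernels the switch has, so the indices shift down by
 * one when the switch has the minimum number of input kernels. */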
int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) {
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() != schema::PrimitiveType_Call) {
      continue;
    }
    call_node_ = node;
    auto switch_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch);
    if (!switch_node) {
      continue;
    }

    if (switch_node->in_tensors().size() < kSwitchMinInputTensorSize) {
      MS_LOG(ERROR) << "actor name: " << this->GetAID() << "'s switch node " << switch_node->name()
                    << " input tensor size: " << switch_node->in_tensors().size() << " is less than 3.";
      return RET_ERROR;
    }

    switch_node_ = switch_node;
    if (switch_node->in_kernels().size() == kSwitchMaxInputKernelSize) {
      true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex);
      false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex);
    }

    if (switch_node->in_kernels().size() == kSwitchMinInputKernelSize) {
      true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex - 1);
      false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex - 1);
    }
    break;
  }
  return RET_OK;
}

void LiteSwitchOpActor::AppendOutputTensors() {
  for (auto &tensor : true_partial_node_->in_tensors()) {
    if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
      output_tensors_.push_back(tensor);
    }
  }
  for (auto &tensor : false_partial_node_->in_tensors()) {
    if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
      output_tensors_.push_back(tensor);
    }
  }
  kernel_->set_out_tensors(output_tensors_);
}

int LiteSwitchOpActor::CompileArrowThroughSwitchCall() {
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
    return RET_OK;
  }

  int ret = GetSwitchAndCallNode(subgraph_kernel);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetSwitchAndCallNode failed.";
    return ret;
  }

  AppendOutputTensors();

  ret = CompileTrueBranchArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileTrueBranchArrow failed.";
    true_branch_output_data_arrows_.clear();
    return ret;
  }

  ret = CompileFalseBranchArrow();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CompileFalseBranchArrow failed.";
    false_branch_output_data_arrows_.clear();
    true_branch_output_data_arrows_.clear();
    return ret;
  }

  subgraph_kernel->DropNode(call_node_);
  subgraph_kernel->DropNode(switch_node_);
  subgraph_kernel->DropNode(true_partial_node_);
  subgraph_kernel->DropNode(false_partial_node_);

  return ret;
}

int LiteSwitchOpActor::CompileArrow() {
  int ret = CompileArrowThroughSwitchCall();
  if (ret != RET_OK) {
    true_branch_output_data_arrows_.clear();
    false_branch_output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughSwitchCall failed.";
    return ret;
  }
  if (!true_branch_output_data_arrows_.empty() && !false_branch_output_data_arrows_.empty()) {
    MS_LOG(INFO) << "CompileArrowThroughSwitchCall done.";
    return RET_OK;
  }
  ret = CompileArrowThroughOutputKernels();
  if (ret != RET_OK) {
    output_data_arrows_.clear();
    MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
    return ret;
  }
  return ret;
}

int LiteSwitchOpActor::PrepareOutputData() {
  true_branch_outputs_data_.resize(true_branch_output_data_arrows_.size());
  for (size_t i = 0; i < true_branch_output_data_arrows_.size(); i++) {
    auto &arrow = true_branch_output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new true_branch_output_data failed.";
      return RET_NULL_PTR;
    }
    true_branch_outputs_data_.at(i) = data;
  }

  false_branch_outputs_data_.resize(false_branch_output_data_arrows_.size());
  for (size_t i = 0; i < false_branch_output_data_arrows_.size(); i++) {
    auto &arrow = false_branch_output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new false_branch_output_data failed.";
      return RET_NULL_PTR;
    }
    false_branch_outputs_data_.at(i) = data;
  }
  return RET_OK;
}

void LiteSwitchOpActor::DecreaseTrueBranchInputTensor() {
  switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
  for (auto input : true_partial_node_->in_tensors()) {
    input->DecRefCount();
  }
}

void LiteSwitchOpActor::DecreaseFalseBranchInputTensor() {
  switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
  for (auto input : false_partial_node_->in_tensors()) {
    input->DecRefCount();
  }
}

void LiteSwitchOpActor::AsyncTrueBranchOutput(OpContext<Tensor> *context) {
  MS_ASSERT(true_branch_output_data_arrows_.size() == true_branch_outputs_data_.size());
  for (size_t i = 0; i < true_branch_output_data_arrows_.size(); ++i) {
    auto &data = true_branch_outputs_data_.at(i);
    Async(true_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

void LiteSwitchOpActor::AsyncFalseBranchOutput(OpContext<Tensor> *context) {
  MS_ASSERT(false_branch_output_data_arrows_.size() == false_branch_outputs_data_.size());
  for (size_t i = 0; i < false_branch_output_data_arrows_.size(); ++i) {
    auto &data = false_branch_outputs_data_.at(i);
    Async(false_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

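/* Same gather-inputs-then-run protocol as LiteOpActor::RunOpData, except that after the
 * kernel finishes the switch condition tensor is read and only the selected branch's outputs
 * are forwarded; the inputs of the branch that is not taken are released. */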
void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
  auto op_uuid = context->sequential_num_;
  input_op_datas_[op_uuid].push_back(inputs);
  inputs_data_[inputs->index_] = inputs->data_;
  if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
    return;
  }

  int ret = InitInputData();
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }

  ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
                  *(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
  if (ret != RET_OK) {
    input_op_datas_.erase(op_uuid);
    context->SetFailed(ret);
    return;
  }
  input_op_datas_.erase(op_uuid);

  auto cond_ptr = reinterpret_cast<bool *>(switch_node_->in_tensors()[kSwitchCondTensorIndex]->data());
  if (cond_ptr == nullptr) {
    MS_LOG(ERROR) << "switch cond input data is nullptr.";
    context->SetFailed(RET_NULL_PTR);
    return;
  }
  if (*cond_ptr) {
    DecreaseFalseBranchInputTensor();
    AsyncTrueBranchOutput(context);
  } else {
    DecreaseTrueBranchInputTensor();
    AsyncFalseBranchOutput(context);
  }
}

#endif

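/* Propagate the shapes (and, for tensor lists, element shapes and per-tensor buffers) of the
 * incoming data tensors onto the kernel's input tensors before any data is copied, so the
 * kernel sees the resized dimensions. */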
void LiteOpActor::SetInputShape() {
  for (size_t i = 0; i < inputs_data_.size(); ++i) {
    auto &input_tensor = kernel_->in_tensors()[i];
    if (input_tensor->shape() == inputs_data_[i]->shape()) {
      continue;
    }
    MS_LOG(DEBUG) << "inputs_data_[" << i << "].shape: " << inputs_data_[i]->shape() << " vs kernel_->in_tensors()["
                  << i << "].shape: " << kernel_->in_tensors()[i]->shape() << " are not equal.";
    MS_LOG(DEBUG) << "this->kernel_->name(): " << this->kernel_->name();

    if (input_tensor->data_type() == kObjectTypeTensorType) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
      auto input_tensorlist = reinterpret_cast<TensorList *>(input_tensor);
      auto input_data_tensorlist = reinterpret_cast<TensorList *>(inputs_data_[i]);
      input_tensorlist->FreeTensorListData();
      input_tensorlist->set_element_shape(input_data_tensorlist->element_shape());
      input_tensorlist->set_shape(input_data_tensorlist->shape());
      std::vector<std::vector<int>> tensor_shape{};
      std::transform(input_data_tensorlist->tensors().begin(), input_data_tensorlist->tensors().end(),
                     std::back_inserter(tensor_shape), [](Tensor *tensor_item) { return tensor_item->shape(); });
      input_tensorlist->MallocTensorListData(input_data_tensorlist->tensors_data_type(), tensor_shape);
#endif
    } else {
      input_tensor->set_shape(inputs_data_[i]->shape());
      input_tensor->set_format(inputs_data_[i]->format());
    }
  }
}

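/* For every gathered input decide how to hand the data to the kernel: drop it if the isolated
 * tensor has no consumers, cast it when float precision differs, attach the pointer directly
 * for graph inputs and allocator-less (e.g. delegate output) tensors, and otherwise move it
 * with reference counting. */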
int LiteOpActor::InitInputData() {
  SetInputShape();

  for (size_t i = 0; i < inputs_data_.size(); ++i) {
    auto dst_tensor = kernel_->in_tensors()[i];
    auto src_tensor = inputs_data_[i];
    if (dst_tensor->init_ref_count() == 0) {
      src_tensor->DecRefCount();
      continue;
    }

    if (NeedCastData(dst_tensor, src_tensor)) {
      CastInputData(dst_tensor, src_tensor);
      continue;
    }

    /* same data-type */
    if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
      // delegate graph kernel output tensor
      SetInputData(dst_tensor, src_tensor);
    } else {
      MoveInputData(dst_tensor, src_tensor);
    }
  }
  return RET_OK;
}

void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    auto data = outputs_data_.at(i);
    Async(output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
  }
}

void LiteOpActor::AddResultIndex(size_t index) { results_index_.push_back(index); }

void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
  for (auto index : results_index_) {
    context->SetResult(index, RET_OK);
  }
}

int LiteOpActor::PrepareOutputData() {
  outputs_data_.resize(output_data_arrows_.size());
  for (size_t i = 0; i < output_data_arrows_.size(); i++) {
    auto &arrow = output_data_arrows_[i];
    auto data =
      std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
                                       static_cast<int>(arrow->to_input_index_));
    if (data == nullptr) {
      MS_LOG(ERROR) << "new output_data failed.";
      return RET_NULL_PTR;
    }
    outputs_data_.at(i) = data;
  }
  return RET_OK;
}

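/* Factory for the actor set: wrap every subgraph kernel in a LiteOpActor (or a
 * LiteSwitchOpActor when the kernel ends in a switch-call pair), bind the shared thread pool,
 * record the kernel-to-AID map for arrow compilation, and spawn the actors. A minimal
 * driver-side sketch (illustrative only; variable names are assumptions):
 *
 *   if (MindrtInit() != RET_OK) { ... handle error ... }
 *   auto actors = CreateOpActor(subgraph_kernels, inner_context);
 *   // ... run the graph through the actors ...
 *   MindrtTerminate(actors);
 */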
std::vector<std::shared_ptr<LiteOpActor>> CreateOpActor(const std::vector<kernel::LiteKernel *> &kernels,
                                                        const lite::InnerContext *ctx) {
  std::vector<std::shared_ptr<LiteOpActor>> actors;
  std::unordered_map<kernel::LiteKernel *, AID> subgraph_name_AID_map{};
  ActorThreadPool *thread_pool = reinterpret_cast<ActorThreadPool *>(ctx->thread_pool());
  if (thread_pool == nullptr) {
    MS_LOG(ERROR) << "thread pool is nullptr";
    return actors;
  }
  for (auto &kernel : kernels) {
    /* make subgraph name (actor name) unique */
    kernel->set_name(kernel->name() + "_" + to_string(actor_count++));
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    if ((kernel::LiteKernelUtil::IsSwitchCall(kernel))) {
      auto switch_actor = std::make_shared<LiteSwitchOpActor>(kernel);
      if (switch_actor == nullptr) {
        MS_LOG(ERROR) << "create LiteSwitchOpActor failed: " << kernel->name();
        actors.clear();
        return actors;
      }
      switch_actor->set_thread_pool(thread_pool);
      subgraph_name_AID_map[kernel] = switch_actor->GetAID();
      actors.push_back(switch_actor);
    } else {
#endif
      auto actor = std::make_shared<LiteOpActor>(kernel);
      if (actor == nullptr) {
        MS_LOG(ERROR) << "create LiteOpActor failed: " << kernel->name();
        actors.clear();
        return actors;
      }
      actor->set_thread_pool(thread_pool);
      subgraph_name_AID_map[kernel] = actor->GetAID();
      actors.push_back(actor);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    }
#endif
  }

  for (auto &actor : actors) {
    actor->SetSubgraphAIDMap(subgraph_name_AID_map);
    auto aid = mindspore::Spawn(actor);
  }
  return actors;
}

int MindrtInit() { return mindspore::Initialize("", "", "", ""); }

void MindrtTerminate(const std::vector<std::shared_ptr<LiteOpActor>> &actor_list) {
  for (const auto &actor : actor_list) {
    mindspore::Terminate(actor->GetAID());
  }
}
}  // namespace mindspore::lite