1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/scheduler.h"
18 #include <map>
19 #include <unordered_set>
20 #include <queue>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #include "src/tensorlist.h"
25 #include "nnacl/partial_fusion_parameter.h"
26 #include "include/errorcode.h"
27 #include "src/common/graph_util.h"
28 #include "src/common/utils.h"
29 #include "src/litert/kernel_registry.h"
30 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
31 #include "include/registry/register_kernel.h"
32 #endif
33 #include "src/litert/kernel_exec_util.h"
34 #include "src/executor/sub_graph_kernel.h"
35 #include "src/common/ops/populate/populate_register.h"
36 #include "src/common/version_manager.h"
37 #include "src/common/prim_util.h"
38 #include "src/litert/lite_model.h"
39 #include "src/common/tensor_util.h"
40 #include "src/common/context_util.h"
41 #include "src/litert/infer_manager.h"
42 #include "src/litert/runtime_pass.h"
43 #ifndef ENABLE_MULTI_LAYOUT
44 #include "src/litert/pass/format_pass/format_pass.h"
45 #endif
46 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
47 #include "src/litert/sub_graph_split.h"
48 #include "src/litert/pass/online_fusion/online_fusion_pass_registry.h"
49 #endif
50 #include "src/litert/weight_decoder.h"
51 #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
52 #include "nnacl/nnacl_common.h"
53 #if GPU_OPENCL
54 #include "src/litert/kernel/opencl/opencl_subgraph.h"
55 #include "src/litert/kernel/gpu/opencl/opencl_runtime.h"
56 #endif
57 #include "include/registry/register_kernel_interface.h"
58 #include "extendrt/mindir_loader/abstract_base_model.h"
59 #include "src/litert/pack_weight_manager.h"
60 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
61 #include "thread/parallel_thread_pool_manager.h"
62 #endif
63 #ifdef SUPPORT_NNRT
64 #include "src/litert/delegate/nnrt/nnrt_delegate.h"
65 #endif
66
67 using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
68
69 namespace mindspore::lite {
70 namespace {
71 constexpr int kMainSubGraphIndex = 0;
72 } // namespace
73
74 namespace {
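// CastKernelWeight: for kernels scheduled into a built-in CPU fp32/fp16 subgraph, cast their
// fp32/fp16 const input tensors to match the precision of the subgraph they belong to.
// Tensorlists and non-const inputs are left untouched.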
75 // support_fp16: current device and package support float16
76 int CastKernelWeight(const kernel::SubGraphType &belong_subgraph_type, const kernel::KernelExec *kernel,
77 bool support_fp16) {
78 MS_ASSERT(kernel != nullptr);
79 MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
80 if (belong_subgraph_type != kernel::kCpuFP32SubGraph && belong_subgraph_type != kernel::kCpuFP16SubGraph) {
81 return RET_OK;
82 }
83 for (auto *tensor : kernel->in_tensors()) {
84 MS_ASSERT(tensor != nullptr);
85 // only cast const tensors
86 // tensorlist does not support fp16 yet
87 if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
88 continue;
89 }
90 // only support fp32->fp16 or fp16->fp32
91 if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
92 continue;
93 }
94 if (tensor->data_type() == kNumberTypeFloat32 && belong_subgraph_type == kernel::kCpuFP16SubGraph) {
95 auto ret = CastConstTensorData(tensor, kNumberTypeFloat16, support_fp16);
96 if (ret != RET_OK) {
97 MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
98 return ret;
99 }
100 } else if (tensor->data_type() == kNumberTypeFloat16 && belong_subgraph_type == kernel::kCpuFP32SubGraph) {
101 auto ret = CastConstTensorData(tensor, kNumberTypeFloat32, support_fp16);
102 if (ret != RET_OK) {
103 MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
104 return ret;
105 }
106 } else {
107 MS_LOG(DEBUG) << "No need to cast";
108 }
109 }
110 return RET_OK;
111 }
112
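// CopyConstTensorData: give const input tensors their own copy of the data so kernels no longer
// reference the original model buffer; ops whose weights get packed inside the kernel
// (per PackWeightManager, e.g. conv) skip the copy.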
113 int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
114 // packed kernels such as conv don't need a copy because the weight will be packed inside the kernel
115 if (lite::PackWeightManager::GetInstance()->IsCopyTensor(op_type)) {
116 return RET_OK;
117 }
118
119 for (auto *tensor : tensors) {
120 // only copy non-copied const tensor
121 if (!tensor->IsConst() && tensor->data() != nullptr) {
122 MS_LOG(ERROR) << "Illegitimate tensor : " << tensor->tensor_name();
123 return RET_ERROR;
124 }
125 if (!tensor->IsConst() || tensor->own_data()) {
126 continue;
127 }
128 if (tensor->data_type() == kObjectTypeTensorType) {
129 // tensorlist's data is nullptr since ConvertTensors
130 // we never set or malloc data of tensorlist but malloc tensors in tensorlist
131 MS_ASSERT(tensor->data() == nullptr);
132 } else {
133 auto copy_tensor = Tensor::CopyTensor(*tensor, true);
134 if (copy_tensor == nullptr) {
135 MS_LOG(ERROR) << "Copy tensor failed";
136 return RET_ERROR;
137 }
138 tensor->FreeData();
139 tensor->set_data(copy_tensor->data());
140 tensor->set_own_data(true);
141 copy_tensor->set_data(nullptr);
142 delete (copy_tensor);
143 }
144 }
145 return RET_OK;
146 }
147 } // namespace
148
149 // support_fp16: current device and package support float16
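// For built-in (non-custom) CPU kernels outside train sessions: cast const weights to the precision
// of the owning subgraph and, when the model buffer is not kept, copy const tensor data so the
// kernels own their weights.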
150 int Scheduler::HandleBuildinCpuKernelWeight(const kernel::SubGraphType &belong_subgraph_type,
151 const kernel::KernelExec *kernel) {
152 MS_ASSERT(kernel != nullptr);
153 MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
154 if (is_train_session_ || kernel->type() == schema::PrimitiveType_Custom ||
155 kernel->desc().provider != kernel::kBuiltin) {
156 return RET_OK;
157 }
158 auto ret = CastKernelWeight(belong_subgraph_type, kernel, context_->device_and_pkg_support_fp16_);
159 if (ret != RET_OK) {
160 MS_LOG(DEBUG) << "CastKernelWeight failed: " << ret;
161 return RET_NOT_SUPPORT;
162 }
163 if (!(reinterpret_cast<LiteModel *>(src_model_)->keep_model_buf())) {
164 // no need to record tensors for restoring; the const data is simply copied
165 MS_CHECK_TRUE_RET(kernel->op_parameter() != nullptr, RET_ERROR);
166 ret = CopyConstTensorData(kernel->in_tensors(), kernel->op_parameter()->type_);
167 if (ret != RET_OK) {
168 MS_LOG(DEBUG) << "CopyConstTensorData failed: " << ret;
169 return RET_NOT_SUPPORT;
170 }
171 }
172 return RET_OK;
173 }
174
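// InitKernels: post-scheduling initialization per subgraph. Validates that non-ShapeFusion nodes do
// not produce const outputs, fixes up built-in CPU kernel weights, and for OpenCL subgraphs
// optionally enables OpenGL texture sharing and runs the OpenCL subgraph pass.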
175 int Scheduler::InitKernels(std::vector<kernel::KernelExec *> &&dst_kernels) {
176 if (is_train_session_) {
177 return RET_OK;
178 }
179 for (auto kernel : dst_kernels) {
180 // delegate graph kernel
181 if (kernel->desc().arch == kernel::kDelegate) {
182 continue;
183 }
184 auto subgraph_type = kernel->subgraph_type();
185 if (subgraph_type == kernel::kNotSubGraph) {
186 MS_LOG(ERROR) << "construct subgraph failed.";
187 return RET_ERROR;
188 }
189 auto subgraph_nodes = reinterpret_cast<kernel::SubGraphKernel *>(kernel)->nodes();
190 for (auto node : subgraph_nodes) {
191 for (auto *tensor : node->out_tensors()) {
192 if (tensor->IsConst()) {
193 MS_CHECK_TRUE_MSG(node->op_parameter() != nullptr, RET_NULL_PTR, "node's op_parameter is invalid.");
194 if (node->op_parameter()->type_ == ::PrimType::PrimType_Inner_ShapeFusion) {
195 continue;
196 }
197 MS_LOG(ERROR) << "Illegitimate kernel output tensor : " << tensor->tensor_name();
198 continue;
199 }
200 }
201 auto ret = HandleBuildinCpuKernelWeight(subgraph_type, node);
202 if (ret != RET_OK) {
203 return ret;
204 }
205 }
206 #if GPU_OPENCL
207 if (kernel->desc().arch == kernel::kGPU) {
208 if (this->GetEnableGLTexture() == true && (kernel == dst_kernels.front() || kernel == dst_kernels.back() - 1)) {
209 kernel->SetOpenGLTextureEnable(true);
210 MS_LOG(INFO) << "Set OpenGLSharingMem for subgraph success!" << std::endl;
211 }
212 auto ret = reinterpret_cast<kernel::OpenCLSubGraph *>(kernel)->RunPass();
213 if (ret != RET_OK) {
214 MS_LOG(ERROR) << "OpenCLSubGraph RunPass failed.";
215 return ret;
216 }
217 }
218 #endif
219 }
220 return RET_OK;
221 }
222
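// SchedulePreProcess: run the online fusion pass, collect graph output nodes, infer shapes of the
// main subgraph (MSLite models only) and, when parallel execution is enabled and shapes are known,
// split the graph for parallel subgraph execution.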
223 int Scheduler::SchedulePreProcess() {
224 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
225 auto search_sub_graph =
226 SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
227 #endif
228
229 #if !defined(RUNTIME_PASS_CLIP)
230 OnlineFusionRegistry::GetInstance()->DoOnlineFusionPass(&search_sub_graph);
231 #endif
232
233 this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
234
235 if (src_model_->model_type_ != ModelType_MSLite) {
236 // call abstract model infer interface
237 *is_infershape_ = RET_OK;
238 } else {
239 *is_infershape_ = InferSubGraphShape(kMainSubGraphIndex);
240 }
241 if (*is_infershape_ != RET_OK && *is_infershape_ != RET_INFER_INVALID) {
242 MS_LOG(ERROR) << "op infer shape failed.";
243 return *is_infershape_;
244 }
245
246 if (context_->enable_parallel_) {
247 #ifndef AUTO_PARALLEL_CLIP
248 if (*is_infershape_ != RET_INFER_INVALID) {
249 search_sub_graph.SubGraphSplit();
250 }
251 #else
252 MS_LOG(ERROR) << unsupport_auto_parallel_log;
253 return RET_NOT_SUPPORT;
254 #endif
255 }
256 return RET_OK;
257 }
258
259 int Scheduler::CheckCpuValid(const std::vector<kernel::KernelExec *> *dst_kernels) const {
260 if (context_->IsDeviceTypeEnabled(DT_CPU)) {
261 return RET_OK;
262 }
263 for (auto kernel : *dst_kernels) {
264 if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
265 MS_LOG(ERROR) << "kernel: " << kernel->name() << " only support in CPU.";
266 return RET_ERROR;
267 }
268 }
269 return RET_OK;
270 }
271
272 int Scheduler::ConstructSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels) {
273 if (*is_control_flow_) {
274 return ConstructControlFlowMainGraph(dst_kernels);
275 }
276
277 auto src_kernel = *dst_kernels;
278 dst_kernels->clear();
279 std::map<const kernel::KernelExec *, bool> is_kernel_finish;
280 return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
281 }
282
283 int Scheduler::ProcessSubGraphTranspose(std::vector<kernel::KernelExec *> *dst_kernels) {
284 #ifndef ENABLE_MULTI_LAYOUT
285 auto ret = pass::RuntimeFormatPass(dst_kernels, src_tensors_, Format::NHWC);
286 if (ret != RET_OK) {
287 MS_LOG(ERROR) << "Run runtime format pass failed.";
288 return RET_ERROR;
289 }
290 #endif
291 return RET_OK;
292 }
293
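// DelQuantDTypeCastKernel: in float mode, remove QuantDTypeCast kernels by linking their producers
// and consumers directly, force the affected tensors back to fp32, and propagate the graph-output
// flag to the producer when the removed kernel was a model output. Recurses into subgraph kernels.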
294 STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *kernels) {
295 for (auto iter = (*kernels).begin(); iter != (*kernels).end();) {
296 auto cur_kernel = *iter;
297 if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
298 auto sub_inner_graph = reinterpret_cast<kernel::SubGraphKernel *>(cur_kernel);
299 auto &subgraph_nodes = sub_inner_graph->nodes();
300 if (DelQuantDTypeCastKernel(&subgraph_nodes) != RET_OK) {
301 MS_LOG(ERROR) << "DelQuantDTypeCastKernel failed in subgraph.";
302 return RET_ERROR;
303 }
304 }
305 if (cur_kernel->type() != schema::PrimitiveType_QuantDTypeCast) {
306 iter++;
307 continue;
308 }
309 auto &post_kernels = cur_kernel->out_kernels();
310 auto &pre_kernels = cur_kernel->in_kernels();
311 if (cur_kernel->in_tensors().size() != 1) {
312 MS_LOG(ERROR) << cur_kernel->name() << " input size error."
313 << " cur_kernel in tensors size:" << cur_kernel->in_tensors().size();
314 return RET_ERROR;
315 }
316 bool graph_input = pre_kernels.empty();
317 if (!graph_input) {
318 // modify post kernel input to new kernel and new tensor
319 for (auto post_kernel : post_kernels) {
320 auto post_in_kernels = post_kernel->in_kernels();
321 auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
322 *post_input_iter = pre_kernels[0];
323 post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
324 post_kernel->set_in_kernels(post_in_kernels);
325 }
326 auto pre_out_kernels = pre_kernels[0]->out_kernels();
327 auto pre_out_iter = std::find(pre_out_kernels.begin(), pre_out_kernels.end(), cur_kernel);
328 if (pre_out_iter != pre_out_kernels.end()) {
329 pre_out_kernels.erase(pre_out_iter);
330 pre_out_kernels.insert(pre_out_iter, post_kernels.begin(), post_kernels.end());
331 pre_kernels[0]->set_out_kernels(pre_out_kernels);
332 }
333 } else {
334 for (auto post_kernel : post_kernels) {
335 auto post_in_kernels = post_kernel->in_kernels();
336 auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
337 *post_input_iter = {};
338 post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
339 post_kernel->set_in_kernels(post_in_kernels);
340 }
341 }
342
343 // update data type
344 for (auto tensor : cur_kernel->in_tensors()) {
345 tensor->set_data_type(kNumberTypeFloat32);
346 }
347 for (auto tensor : cur_kernel->out_tensors()) {
348 tensor->set_data_type(kNumberTypeFloat32);
349 }
350
351 // update model output kernel & tensor
352 if (cur_kernel->is_model_output()) {
353 pre_kernels[0]->set_is_model_output(true);
354 cur_kernel->in_tensors()[0]->set_category(Category::GRAPH_OUTPUT);
355 pre_kernels[0]->set_out_kernels({});
356 // If the current kernel is the output kernel, use the current output tensor as the output tensor of the previous
357 // node.
358 auto pre_out_tensors = pre_kernels[0]->out_tensors();
359 auto tensor_iter = std::find(pre_out_tensors.begin(), pre_out_tensors.end(), cur_kernel->in_tensors()[0]);
360 if (tensor_iter != pre_kernels[0]->out_tensors().end()) {
361 *tensor_iter = cur_kernel->out_tensors()[0];
362 }
363 }
364
365 // delete cur kernel
366 iter = kernels->erase(iter);
367 MS_LOG(DEBUG) << "Delete kernel: " << cur_kernel->name();
368 delete cur_kernel;
369 }
370 return RET_OK;
371 }
372
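// Schedule: top-level scheduling pipeline -- pre-process (online fusion + infer shape), map graph
// nodes to kernels, optionally strip QuantDTypeCast kernels, replace kernels with delegate kernels,
// group kernels into subgraphs, run format/control-flow/runtime passes, then initialize kernel
// weights.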
373 int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
374 MS_LOG(DEBUG) << "Start schedule.";
375 int check_input_ret = CheckInputParam(dst_kernels);
376 if (check_input_ret != RET_OK) {
377 MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
378 return check_input_ret;
379 }
380
381 shape_fusion_pass_ =
382 std::make_shared<ShapeFusionPass>(context_, reinterpret_cast<LiteModel *>(src_model_), src_tensors_);
383 MS_CHECK_TRUE_RET(shape_fusion_pass_ != nullptr, RET_ERROR);
384 int ret = SchedulePreProcess();
385 if (ret != RET_OK) {
386 return ret;
387 }
388
389 if (*is_control_flow_) {
390 control_flow_scheduler_ = std::make_shared<ControlFlowScheduler>(context_, ms_context_, src_tensors_);
391 MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "new control scheduler failed.");
392 }
393
394 ret = ScheduleGraphToKernels(dst_kernels);
395 FreeOpParameters();
396 op_parameters_.clear();
397 if (ret != RET_OK) {
398 MS_LOG(ERROR) << "Schedule graph to kernels failed.";
399 return ret;
400 }
401 if (context_->float_mode) {
402 kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
403 ret = DelQuantDTypeCastKernel(dst_kernels);
404 if (ret != RET_OK) {
405 MS_LOG(ERROR) << "Delete quant_dtype_cast kernel failed.";
406 return ret;
407 }
408 }
409 shape_fusion_pass_->StoreStateAndReset();
410
411 MS_LOG(DEBUG) << "Start to init delegate kernels.";
412 ret = InitDelegateKernels(dst_kernels);
413 if (ret != RET_OK) {
414 MS_LOG(ERROR) << "Replace delegate kernels failed.";
415 return ret;
416 }
417 MS_LOG(DEBUG) << "Finish to init delegate kernels.";
418
419 ret = CheckCpuValid(dst_kernels);
420 if (ret != RET_OK) {
421 MS_LOG(ERROR) << "kernels invalid in set devices.";
422 return ret;
423 }
424
425 kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
426
427 ret = ConstructSubGraphs(dst_kernels);
428 if (ret != RET_OK) {
429 MS_LOG(ERROR) << "ConstructSubGraphs failed.";
430 return ret;
431 }
432
433 ret = ProcessSubGraphTranspose(dst_kernels);
434 if (ret != RET_OK) {
435 MS_LOG(ERROR) << "Process SubGraph with multi layout failed.";
436 return ret;
437 }
438
439 if (*is_control_flow_) {
440 control_flow_scheduler_->SetSubgraphForPartialNode(&partial_kernel_subgraph_index_map_,
441 &subgraph_index_subgraph_kernel_map_);
442 ret = control_flow_scheduler_->Schedule(dst_kernels);
443 MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "control flow schedule failed.");
444 }
445
446 auto status = RuntimePass(dst_kernels, src_tensors_);
447 if (status != RET_OK) {
448 MS_LOG(ERROR) << "runtime pass failed.";
449 return RET_ERROR;
450 }
451
452 ret = InitKernels(std::move(*dst_kernels));
453 if (ret != RET_OK) {
454 MS_LOG(ERROR) << "InitKernels failed.";
455 return ret;
456 }
457 shape_fusion_pass_->RestoreState();
458 if (IsPrintDebug()) {
459 MS_LOG(DEBUG) << "schedule kernels success.";
460 for (auto subgraph : *dst_kernels) {
461 MS_LOG(DEBUG) << "[subgraph] : " << subgraph->name() << ", type:" << subgraph->subgraph_type();
462 if (subgraph->desc().arch == kernel::KERNEL_ARCH::kDelegate) {
463 continue;
464 }
465 std::vector<kernel::KernelExec *> kernel_list = reinterpret_cast<kernel::SubGraphKernel *>(subgraph)->nodes();
466 for (auto kernel : kernel_list) {
467 MS_LOG(DEBUG) << "kernel: [" << kernel->name() << "] "
468 << "TypeId(" << kernel->desc().data_type << "); "
469 << "OpType(" << PrimitiveCurVersionTypeName(kernel->desc().type) << "); "
470 << "format(" << kernel->desc().format << "); "
471 << "arch(" << kernel->desc().arch << ")";
472 }
473 }
474 }
475 return RET_OK;
476 }
477
478 int Scheduler::CheckInputParam(const std::vector<kernel::KernelExec *> *dst_kernels) const {
479 if (dst_kernels == nullptr) {
480 return RET_ERROR;
481 }
482 if (src_model_ == nullptr) {
483 MS_LOG(ERROR) << "Input model is nullptr";
484 return RET_PARAM_INVALID;
485 }
486 if (src_model_->graph_.sub_graphs_.empty()) {
487 MS_LOG(ERROR) << "Model should have a subgraph at least";
488 return RET_PARAM_INVALID;
489 }
490 return RET_OK;
491 }
492
493 #ifndef DELEGATE_CLIP
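// ReplaceDelegateKernels: hand the scheduled kernels to the delegate through a DelegateModel.
// Kernels the delegate leaves untouched keep their original backend; delegate subgraphs are wrapped
// into new KernelExecs, and the CPU kernels they replace (plus their const inputs) are released.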
494 int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
495 std::vector<kernel::Kernel *> kernels;
496 for (size_t i = 0; i < dst_kernels->size(); i++) {
497 auto litert_kernel = reinterpret_cast<kernel::Kernel *>((*dst_kernels)[i]->kernel());
498 if (MS_UNLIKELY(litert_kernel == nullptr)) {
499 MS_LOG(ERROR) << "nullptr exist in dst_kernels.";
500 return RET_ERROR;
501 }
502 kernels.push_back(litert_kernel);
503 }
504
505 ms_inputs_ = LiteTensorsToMSTensors(*inputs_);
506 ms_outputs_ = LiteTensorsToMSTensors(*outputs_);
507 auto schema_version = static_cast<SchemaVersion>(context_->get_schema_version());
508 DelegateModel<schema::Primitive> *model =
509 new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
510 if (model == nullptr) {
511 MS_LOG(ERROR) << "New delegate model failed.";
512 return RET_NULL_PTR;
513 }
514
515 #ifdef SUPPORT_NNRT
516 if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
517 auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
518 delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
519 void *meta_graph = reinterpret_cast<void *>(
520 const_cast<mindspore::schema::MetaGraph *>(mindspore::schema::GetMetaGraph(this->src_model_->buf)));
521 delegate->SetMetaGraph(meta_graph);
522 delegate->SetDequantTensors(this->src_tensors_);
523 }
524 #endif
525
526 auto ret = delegate_->Build(model);
527 if (ret != mindspore::kSuccess) {
528 delete model;
529 MS_LOG(ERROR) << "Delegate prepare kernels failed.";
530 return RET_ERROR;
531 }
532
533 auto src_kernels = *dst_kernels;
534 dst_kernels->clear();
535 std::map<const kernel::KernelExec *, bool> delegate_support;
536 for (auto kernel : src_kernels) {
537 delegate_support[kernel] = true;
538 }
539 for (auto kernel : kernels) {
540 size_t index = 0;
541 for (; index < src_kernels.size(); index++) {
542 if (kernel == src_kernels[index]->kernel()) {
543 // Kernels that the delegate does not support keep the original backend
544 dst_kernels->push_back(src_kernels[index]);
545 delegate_support[src_kernels[index]] = false;
546 break;
547 }
548 }
549 if (index == src_kernels.size()) {
550 // New liteKernel to save delegate subgraph
551 std::shared_ptr<kernel::Kernel> shared_kernel(kernel);
552 auto kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
553 if (kernel_exec == nullptr) {
554 delete model;
555 MS_LOG(ERROR) << "New KernelExec for delegate subgraph failed.";
556 return RET_NULL_PTR;
557 }
558 auto delegate_type = kNumberTypeFloat32;
559 for (auto &input : kernel->inputs()) {
560 if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
561 delegate_type = kNumberTypeFloat16;
562 break;
563 }
564 }
565 kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
566 kernel_exec->set_desc(delegate_desc);
567 dst_kernels->push_back(kernel_exec);
568 }
569 }
570 // Release the cpu kernels that have been replaced by the delegate subgraph, as well as their tensor data
571 for (auto kernel : src_kernels) {
572 if (delegate_support[kernel] == true) {
573 auto inputs = kernel->in_tensors();
574 for (auto *tensor : inputs) {
575 MS_ASSERT(tensor != nullptr);
576 if (tensor->IsConst()) {
577 tensor->FreeData();
578 }
579 }
580 delete kernel;
581 }
582 }
583 delete model;
584 return RET_OK;
585 }
586
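// InitDelegateKernels: an external delegate (device type -1) is offered all scheduled kernels at
// once; inner delegates (e.g. NPU/TensorRT) are offered consecutive runs of kernels whose device
// type has priority, while the remaining kernels keep their original backend.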
587 int Scheduler::InitDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
588 /* no delegate valid */
589 if (delegate_ == nullptr) {
590 return RET_OK;
591 }
592
593 /* set delegate spin count */
594 context_->thread_pool_->SetSpinCountMinValue();
595
596 /* external delegate */
597 if (delegate_device_type_ == -1) {
598 auto ret = ReplaceDelegateKernels(dst_kernels);
599 if (ret != RET_OK) {
600 MS_LOG(ERROR) << "external delegate init failed.";
601 return ret;
602 }
603 }
604
605 /* Inner delegate : check Priority */
606 std::vector<kernel::KernelExec *> src_kernels = *dst_kernels;
607 dst_kernels->clear();
608
609 while (!src_kernels.empty()) {
610 std::vector<kernel::KernelExec *> tmp_kernels;
611 kernel::KernelExec *remain_kernel = nullptr;
612
613 /* Loop for inner delegate npu and TensorRT subgraph */
614 while (!src_kernels.empty()) {
615 auto kernel = src_kernels.front();
616 VectorErase(&src_kernels, kernel);
617 bool priority_ret =
618 DeviceTypePriority(context_, delegate_device_type_, KernelArchToDeviceType(kernel->desc().arch));
619 if (priority_ret == true) {
620 tmp_kernels.push_back(kernel);
621 } else {
622 remain_kernel = kernel;
623 break;
624 }
625 }
626
627 /* start current NPU-kernels replace */
628 if (tmp_kernels.empty()) {
629 if (remain_kernel != nullptr) {
630 dst_kernels->push_back(remain_kernel);
631 remain_kernel = nullptr;
632 }
633 continue;
634 }
635 auto ret = ReplaceDelegateKernels(&tmp_kernels);
636 if (ret != RET_OK) {
637 dst_kernels->insert(dst_kernels->end(), src_kernels.begin(), src_kernels.end());
638 dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
639 if (remain_kernel != nullptr) {
640 dst_kernels->push_back(remain_kernel);
641 }
642 MS_LOG(ERROR) << "Inner delegate replace delegate kernels failed.";
643 return ret;
644 }
645
646 dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
647 tmp_kernels.clear();
648 if (remain_kernel != nullptr) {
649 dst_kernels->push_back(remain_kernel);
650 remain_kernel = nullptr;
651 }
652 }
653
654 return RET_OK;
655 }
656 #endif
657
658 void Scheduler::FindNodeInoutTensors(const lite::LiteGraph::Node &node, std::vector<Tensor *> *inputs,
659 std::vector<Tensor *> *outputs) {
660 MS_ASSERT(inputs != nullptr);
661 MS_ASSERT(outputs != nullptr);
662 auto in_size = node.input_indices_.size();
663 inputs->reserve(in_size);
664 for (size_t j = 0; j < in_size; ++j) {
665 inputs->emplace_back(src_tensors_->at(node.input_indices_[j]));
666 }
667 auto out_size = node.output_indices_.size();
668 outputs->reserve(out_size);
669 for (size_t j = 0; j < out_size; ++j) {
670 outputs->emplace_back(src_tensors_->at(node.output_indices_[j]));
671 }
672 }
673
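// InferNodeShape: try provider/custom infer-shape first; otherwise populate the OpParameter, cache
// it in op_parameters_ keyed by the node's first output index, and run the built-in KernelInferShape.
// For control-flow graphs outputs are marked as shape {-1} and RET_INFER_INVALID is returned.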
674 int Scheduler::InferNodeShape(const lite::LiteGraph::Node *node) {
675 MS_ASSERT(node != nullptr);
676 auto primitive = node->primitive_;
677 MS_ASSERT(primitive != nullptr);
678 std::vector<Tensor *> inputs;
679 std::vector<Tensor *> outputs;
680 FindNodeInoutTensors(*node, &inputs, &outputs);
681 auto ret =
682 KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders(), context_->get_schema_version());
683 if (ret != RET_NOT_SUPPORT) {
684 *infer_along_running_ = false;
685 return ret;
686 }
687
688 auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(
689 GetPrimitiveType(node->primitive_, context_->get_schema_version()), context_->get_schema_version());
690 if (parame_gen == nullptr) {
691 MS_LOG(ERROR) << "parameter generator is nullptr.";
692 FreeOpParameters();
693 return RET_NULL_PTR;
694 }
695 auto parameter = parame_gen(primitive);
696 if (parameter == nullptr) {
697 MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
698 << GetPrimitiveTypeName(primitive, context_->get_schema_version());
699 FreeOpParameters();
700 return RET_ERROR;
701 }
702
703 parameter->quant_type_ = node->quant_type_;
704 parameter->thread_num_ = context_->thread_num_;
705 if (node->output_indices_.empty()) {
706 MS_LOG(ERROR) << "The output size is invalid";
707 if (parameter->destroy_func_ != nullptr) {
708 parameter->destroy_func_(parameter);
709 }
710 free(parameter);
711 return RET_ERROR;
712 }
713 if (op_parameters_.find(node->output_indices_.at(0)) != op_parameters_.end()) {
714 if (parameter->destroy_func_ != nullptr) {
715 parameter->destroy_func_(parameter);
716 }
717 free(parameter);
718 parameter = op_parameters_[node->output_indices_.at(0)];
719 } else {
720 op_parameters_[node->output_indices_.at(0)] = parameter;
721 }
722
723 if (IsCallNode(primitive, context_->get_schema_version())) {
724 return InferCallShape(node);
725 }
726 ret = KernelInferShape(inputs, outputs, parameter, context_->allocator);
727 if (ret != RET_OK && ret != RET_INFER_INVALID) {
728 FreeOpParameters();
729 return RET_ERROR;
730 }
731 for (auto &output : outputs) {
732 output->set_shape_changed(false);
733 }
734 if (*is_control_flow_) {
735 for (auto &output : outputs) {
736 output->set_shape({-1});
737 }
738 return RET_INFER_INVALID;
739 }
740
741 if (ret == RET_OK) {
742 for (auto &output : outputs) {
743 if (static_cast<size_t>(output->ElementsNum()) >= GetMaxMallocSize() / sizeof(int64_t)) {
744 MS_LOG(ERROR) << "The size of output tensor is too big";
745 FreeOpParameters();
746 return RET_ERROR;
747 }
748 }
749 } else if (ret != RET_INFER_INVALID) {
750 FreeOpParameters();
751 return RET_ERROR;
752 }
753 return ret;
754 }
755
756 void Scheduler::FreeOpParameters() {
757 for (auto ¶m : op_parameters_) {
758 if (param.second != nullptr) {
759 if (param.second->destroy_func_ != nullptr) {
760 param.second->destroy_func_(param.second);
761 }
762 free(param.second);
763 param.second = nullptr;
764 }
765 }
766 }
767
768 int Scheduler::RestoreSubGraphInput(const lite::LiteGraph::Node *partial_node) {
769 auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
770 MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
771 auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
772 for (size_t i = 0; i < subgraph->input_indices_.size(); ++i) {
773 auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
774 subgraph_input->set_data(nullptr);
775 }
776 return RET_OK;
777 }
778
779 void CopyCommonTensor(Tensor *dst_tensor, Tensor *src_tensor) {
780 dst_tensor->set_data_type(src_tensor->data_type());
781 dst_tensor->set_shape(src_tensor->shape());
782 dst_tensor->set_format(src_tensor->format());
783 dst_tensor->set_data(src_tensor->data());
784 }
785
786 int Scheduler::CopyPartialShapeToSubGraph(const lite::LiteGraph::Node *partial_node) {
787 auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
788 MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
789 auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
790 if (subgraph->input_indices_.size() != partial_node->input_indices_.size()) {
791 MS_LOG(ERROR) << "partial node " << partial_node->name_ << " inputs size: " << partial_node->input_indices_.size()
792 << " vs "
793 << " subgraph input size: " << subgraph->input_indices_.size();
794 return RET_PARAM_INVALID;
795 }
796
797 for (size_t i = 0; i < partial_node->input_indices_.size(); ++i) {
798 auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
799 auto &partial_input = src_tensors_->at(partial_node->input_indices_[i]);
800 if (partial_input->data_type() == kObjectTypeTensorType) {
801 return RET_INFER_INVALID;
802 }
803 CopyCommonTensor(subgraph_input, partial_input);
804 }
805
806 return RET_OK;
807 }
808
809 int Scheduler::InferPartialShape(const lite::LiteGraph::Node *node) {
810 MS_ASSERT(src_model_ != nullptr);
811 MS_ASSERT(node != nullptr);
812 if (!IsPartialNode(node->primitive_, context_->get_schema_version())) {
813 MS_LOG(ERROR) << "Node is not a partial";
814 return RET_PARAM_INVALID;
815 }
816 CopyPartialShapeToSubGraph(node);
817 int subgraph_index = GetPartialGraphIndex(node->primitive_, context_->get_schema_version());
818 auto ret = InferSubGraphShape(subgraph_index);
819 if (ret != RET_OK) {
820 MS_LOG(WARNING) << "infer subgraph: " << subgraph_index << " failed, ret:" << ret;
821 }
822 RestoreSubGraphInput(node);
823 return ret;
824 }
825
826 LiteGraph::Node *Scheduler::NodeInputIsPartial(const lite::LiteGraph::Node *node) {
827 MS_ASSERT(src_model_ != nullptr);
828 MS_ASSERT(node != nullptr);
829 for (auto &iter : src_model_->graph_.all_nodes_) {
830 if (iter->output_indices_ == node->input_indices_) {
831 if (IsPartialNode(iter->primitive_, context_->get_schema_version())) {
832 return iter;
833 } else {
834 return nullptr;
835 }
836 }
837 }
838 return nullptr;
839 }
840
841 int Scheduler::InferCallShape(const lite::LiteGraph::Node *node) {
842 MS_ASSERT(src_model_ != nullptr);
843 MS_ASSERT(node != nullptr);
844 if (!IsCallNode(node->primitive_, context_->get_schema_version())) {
845 MS_LOG(ERROR) << "Node is not a call cnode";
846 return RET_PARAM_INVALID;
847 }
848
849 auto partial_input = NodeInputIsPartial(node);
850 if (partial_input) {
851 return InferPartialShape(partial_input);
852 }
853 auto switch_input = NodeInputIsSwitchType(node);
854 if (switch_input) {
855 *is_control_flow_ = true;
856 return InferSwitchShape(switch_input);
857 }
858
859 MS_LOG(ERROR) << "call input is not partial and also not switch.";
860 return RET_ERROR;
861 }
862
863 int Scheduler::InferSubGraphShape(size_t subgraph_index) {
864 MS_ASSERT(src_model_ != nullptr);
865 MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
866 MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
867 if (find(infer_subgraph_index_.begin(), infer_subgraph_index_.end(), subgraph_index) != infer_subgraph_index_.end()) {
868 MS_LOG(ERROR) << "The subgraph has been infer shape, subgraph index: " << subgraph_index;
869 return RET_INFER_INVALID;
870 }
871 infer_subgraph_index_.push_back(subgraph_index);
872 auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
873 int subgraph_infershape_ret = RET_OK;
874 auto node_indexes = subgraph->node_indices_;
875 for (size_t i = 0; i < node_indexes.size(); ++i) {
876 auto node_index = node_indexes[i];
877 auto node = src_model_->graph_.all_nodes_[node_index];
878 MS_ASSERT(node != nullptr);
879 auto *primitive = node->primitive_;
880 if (primitive == nullptr) {
881 MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
882 return RET_ERROR;
883 }
884 if (node->node_type_ == schema::PrimitiveType_Shape) {
885 // convert shape to built-in shape
886 MS_CHECK_TRUE_RET(node->input_indices_.size() == 1, RET_ERROR);
887 shape_fusion_pass_->Run(node, subgraph_index);
888 node_indexes = subgraph->node_indices_;
889 }
890 auto ret = InferNodeShape(node);
891 if (ret == RET_INFER_INVALID) {
892 MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_
893 << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version())
894 << ", set infer flag to false.";
895 subgraph_infershape_ret = RET_INFER_INVALID;
896 } else if (ret != RET_OK) {
897 FreeOpParameters();
898 MS_LOG(ERROR) << "InferShape failed, name: " << node->name_
899 << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
900 return RET_INFER_ERR;
901 }
902 }
903 return subgraph_infershape_ret;
904 }
905
906 namespace {
907 // support_fp16: current device and package support float16
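// Cast a const tensor to dst_data_type in place while recording a shallow copy of the original
// tensor in restored_origin_tensors, so RestoreTensorData() can put the original data back after
// the kernel has been prepared.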
908 int CastAndRestoreConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors,
909 TypeId dst_data_type, bool support_fp16) {
910 MS_ASSERT(tensor != nullptr);
911 MS_ASSERT(tensor->IsConst());
912 MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
913 MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
914 if (tensor->data_type() == dst_data_type) {
915 return RET_OK;
916 }
917 auto origin_data = tensor->data();
918 MS_ASSERT(origin_data != nullptr);
919 auto restore_tensor = Tensor::CopyTensor(*tensor, false);
920 if (restore_tensor == nullptr) {
921 return RET_NULL_PTR;
922 }
923 restore_tensor->set_data(origin_data);
924 restore_tensor->set_own_data(tensor->own_data());
925 tensor->set_data(nullptr);
926 tensor->set_data_type(dst_data_type);
927 auto ret = tensor->MallocData();
928 if (RET_OK != ret) {
929 MS_LOG(ERROR) << "malloc data failed";
930 return ret;
931 }
932 auto new_tensor_data = tensor->data();
933 MS_ASSERT(new_tensor_data != nullptr);
934 if (dst_data_type == kNumberTypeFloat32) {
935 Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
936 } else { // dst_data_type == kNumberTypeFloat16
937 Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
938 }
939 if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
940 MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
941 delete restore_tensor;
942 return RET_ERROR;
943 }
944 (*restored_origin_tensors)[tensor] = restore_tensor;
945 return RET_OK;
946 }
947
948 // support_fp16: current device and package support float16
949 int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
950 TypeId dst_data_type, bool support_fp16) {
951 MS_ASSERT(restored_origin_tensors != nullptr);
952 if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
953 MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";
954 return RET_PARAM_INVALID;
955 }
956 for (auto *tensor : tensors) {
957 MS_ASSERT(tensor != nullptr);
958 // only cast const tensors
959 // tensorlist does not support fp16 yet
960 if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
961 continue;
962 }
963 // only support fp32->fp16 or fp16->fp32
964 if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
965 continue;
966 }
967 if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
968 auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16, support_fp16);
969 if (ret != RET_OK) {
970 MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
971 return ret;
972 }
973 } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
974 auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32, support_fp16);
975 if (ret != RET_OK) {
976 MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
977 return ret;
978 }
979 } else {
980 MS_LOG(DEBUG) << "No need to cast from " << tensor->data_type() << " to " << dst_data_type;
981 }
982 }
983 return RET_OK;
984 }
985
986 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
987 MS_ASSERT(restored_origin_tensors != nullptr);
988 for (auto &restored_origin_tensor : *restored_origin_tensors) {
989 restored_origin_tensor.second->set_data(nullptr);
990 delete (restored_origin_tensor.second);
991 restored_origin_tensor.second = nullptr;
992 }
993 restored_origin_tensors->clear();
994 }
995
996 inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
997 MS_ASSERT(restored_origin_tensors != nullptr);
998 for (auto &restored_origin_tensor : *restored_origin_tensors) {
999 auto *origin_tensor = restored_origin_tensor.first;
1000 auto *restored_tensor = restored_origin_tensor.second;
1001 MS_ASSERT(origin_tensor != nullptr);
1002 MS_ASSERT(restored_tensor != nullptr);
1003 origin_tensor->FreeData();
1004 origin_tensor->set_data_type(restored_tensor->data_type());
1005 origin_tensor->set_data(restored_tensor->data());
1006 origin_tensor->set_own_data(restored_tensor->own_data());
1007 }
1008 FreeRestoreTensors(restored_origin_tensors);
1009 }
1010 } // namespace
1011
1012 void Scheduler::ResetByExecutionPlan(std::string node_name, TypeId *data_type) {
1013 if (execution_plan_ == nullptr) {
1014 return;
1015 }
1016 auto iter = execution_plan_->find(node_name);
1017 if (iter != execution_plan_->end()) {
1018 *data_type = iter->second;
1019 }
1020 return;
1021 }
1022
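// FindCpuKernel: dequantize weights as required, then ask the KernelRegistry for a CPU kernel with
// the requested data type. Train sessions cast const tensors to the kernel data type, Prepare() the
// kernel, and restore the original tensors afterwards.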
1023 int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1024 OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
1025 kernel::KernelExec **kernel) {
1026 MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr.");
1027 auto op_type = op_parameter->type_;
1028 if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1029 MS_LOG(INFO) << "unsupported op_type: " << PrimitiveCurVersionTypeName(op_type)
1030 << ", data_type: " << desc.data_type;
1031 return RET_NOT_SUPPORT;
1032 }
1033 kernel::KernelKey cpu_desc = desc;
1034 if (kernel_data_type == kNumberTypeFloat16) {
1035 if (!context_->IsCpuFloat16Enabled() ||
1036 (cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
1037 return RET_NOT_SUPPORT;
1038 }
1039 cpu_desc.data_type = kNumberTypeFloat16;
1040 }
1041 auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->graph_.version_,
1042 context_->float_mode);
1043 if (ret != RET_OK) {
1044 MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1045 return RET_NOT_SUPPORT;
1046 }
1047 std::map<Tensor *, Tensor *> restored_origin_tensors;
1048
1049 if (is_train_session_) {
1050 ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type,
1051 context_->device_and_pkg_support_fp16_);
1052 if (ret != RET_OK) {
1053 MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
1054 return RET_NOT_SUPPORT;
1055 }
1056 }
1057
1058 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
1059 // reset op task num; the number of operator partition tasks is not necessarily equal to the number of threads
1060 int thread_num_limit = ParallelThreadPoolManager::GetInstance()->GetTaskNum(config_info_);
1061 if (thread_num_limit != -1 && IsSharedThreadPoolOp(op_type)) {
1062 op_parameter->thread_num_ = thread_num_limit;
1063 }
1064 #endif
1065
1066 ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, cpu_desc,
1067 op_parameter, kernel);
1068 if (ret == RET_OK) {
1069 MS_LOG(DEBUG) << "Get TypeId(expect = " << kernel_data_type << ", real = " << cpu_desc.data_type
1070 << ") op success: " << PrimitiveCurVersionTypeName(op_type);
1071 if (is_train_session_) {
1072 ret = (*kernel)->Prepare();
1073 RestoreTensorData(&restored_origin_tensors);
1074 }
1075 }
1076 return ret;
1077 }
1078
1079 #ifdef GPU_OPENCL
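// FindGpuKernel: only used when a GPU device is enabled. Prefers fp16 on GPU if configured (or
// forced by prefer_data_type), dequantizes weights to fp32 and copies const tensor data before
// asking the KernelRegistry for an OpenCL kernel; on failure scheduling falls back to CPU.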
1080 int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1081 OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::KernelExec **kernel,
1082 TypeId prefer_data_type) {
1083 MS_ASSERT(op_parameter != nullptr);
1084 MS_ASSERT(kernel != nullptr);
1085 if (!context_->IsDeviceTypeEnabled(DT_GPU)) {
1086 return RET_NOT_SUPPORT;
1087 }
1088
1089 // support more data type like int32
1090 kernel::KernelKey gpu_desc{kernel::KERNEL_ARCH::kGPU, desc.data_type, NHWC, desc.type};
1091 if (desc.data_type == kNumberTypeFloat32 && context_->IsGpuFloat16Enabled()) {
1092 gpu_desc.data_type = kNumberTypeFloat16;
1093 }
1094 if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kNumberTypeFloat32) {
1095 gpu_desc.data_type = prefer_data_type;
1096 }
1097 // weight dequant
1098 auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32, src_model_->graph_.version_,
1099 context_->float_mode);
1100 if (ret != RET_OK) {
1101 MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1102 return RET_NOT_SUPPORT;
1103 }
1104 // no need to record tensors for restoring; the const data is simply copied
1105 ret = CopyConstTensorData(in_tensors, op_parameter->type_);
1106 if (ret != RET_OK) {
1107 MS_LOG(DEBUG) << "CopyConstTensorData failed: " << ret;
1108 return RET_NOT_SUPPORT;
1109 }
1110 ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, gpu_desc,
1111 op_parameter, kernel);
1112 if (ret == RET_OK) {
1113 MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1114 } else {
1115 MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1116 }
1117 return ret;
1118 }
1119 #endif
1120
1121 int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1122 const LiteGraph::Node *node, TypeId data_type, kernel::KernelExec **kernel) {
1123 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
1124 MS_ASSERT(kernel != nullptr);
1125 int ret = RET_NOT_SUPPORT;
1126 auto prim_type = GetPrimitiveType(node->primitive_, context_->get_schema_version());
1127 if (prim_type == schema::PrimitiveType_Custom) {
1128 for (auto &&device : context_->device_list_) {
1129 if (!device.provider_.empty() && !device.provider_device_.empty()) {
1130 kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, prim_type,
1131 device.provider_device_, device.provider_};
1132 ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc,
1133 nullptr, kernel, node->primitive_);
1134 if (ret == RET_OK && *kernel != nullptr) {
1135 return ret;
1136 }
1137 }
1138 }
1139
1140 kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, prim_type, "", ""};
1141 ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1142 kernel, node->primitive_);
1143 if (ret == RET_OK && *kernel != nullptr) {
1144 return ret;
1145 }
1146 return RET_NOT_SUPPORT;
1147 }
1148 if (!context_->IsProviderEnabled()) {
1149 return ret;
1150 }
1151 if (context_->get_schema_version() == SCHEMA_V0) {
1152 return ret;
1153 }
1154 for (auto &&device : context_->device_list_) {
1155 if (!device.provider_.empty()) {
1156 kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, prim_type,
1157 device.provider_device_, device.provider_};
1158 ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1159 kernel, node->primitive_);
1160 if (ret == RET_OK && *kernel != nullptr) {
1161 return ret;
1162 }
1163 }
1164 }
1165 #endif
1166 return RET_NOT_SUPPORT;
1167 }
1168
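// FindBackendKernel: try provider/custom kernels first, then GPU (if it has priority for this node),
// then a CPU fp16 kernel, and finally fall back to a CPU fp32 kernel. When a backend returns
// RET_ERROR the node shape is re-inferred and the OpParameter refreshed before the next attempt.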
1169 kernel::KernelExec *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
1170 const std::vector<Tensor *> &out_tensors, const LiteGraph::Node *node,
1171 TypeId prefer_data_type) {
1172 MS_ASSERT(node != nullptr);
1173 // determine the data type used to select the kernel below
1174 TypeId data_type;
1175 if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1176 if (in_tensors.front()->data_type() == kNumberTypeBool) {
1177 data_type = kNumberTypeBool;
1178 } else {
1179 data_type = kNumberTypeFloat32;
1180 }
1181 } else {
1182 data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
1183 if (data_type == kTypeUnknown) {
1184 MS_LOG(ERROR) << "GetFirstFp32Fp16OrInt8Type is unknown.";
1185 return nullptr;
1186 }
1187 }
1188 if (context_->float_mode) {
1189 for (auto tensor : out_tensors) {
1190 if (!tensor->quant_params().empty() &&
1191 (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeUInt8)) {
1192 data_type = kNumberTypeFloat32;
1193 tensor->set_data_type(kNumberTypeFloat32);
1194 }
1195 }
1196 }
1197 kernel::KernelExec *kernel = nullptr;
1198 auto status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
1199 if (status == RET_OK && kernel != nullptr) {
1200 return kernel;
1201 }
1202 MS_ASSERT(!node->output_indices_.empty());
1203 OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1204 if (op_parameter == nullptr) {
1205 MS_LOG(ERROR) << "Can not find OpParameter!type: "
1206 << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1207 return nullptr;
1208 }
1209
1210 #if (defined GPU_OPENCL) || (defined ENABLE_FP16)
1211 int kernel_thread_count = op_parameter->thread_num_;
1212 #endif
1213 op_parameter->is_train_session_ = is_train_session_;
1214 kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, op_parameter->type_};
1215
1216 #ifdef GPU_OPENCL
1217 bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
1218 bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
1219 if (gpu_priority && use_gpu_kernel) {
1220 status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel, prefer_data_type);
1221 if (status == RET_OK) {
1222 return kernel;
1223 } else {
1224 MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1225 << node->name_;
1226 if (status == RET_ERROR) {
1227 op_parameters_.erase(node->output_indices_.at(0));
1228 auto ret = InferNodeShape(node);
1229 if (ret == RET_INFER_INVALID || ret == RET_OK) {
1230 op_parameter = op_parameters_[node->output_indices_.at(0)];
1231 op_parameter->thread_num_ = kernel_thread_count;
1232 } else {
1233 MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1234 return nullptr;
1235 }
1236 }
1237 }
1238 }
1239 #endif
1240 #ifdef ENABLE_FP16
1241 if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
1242 ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
1243 status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
1244 if (status == RET_OK) {
1245 return kernel;
1246 } else {
1247 MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1248 << node->name_;
1249 if (status == RET_ERROR) {
1250 op_parameters_.erase(node->output_indices_.at(0));
1251 auto ret = InferNodeShape(node);
1252 if (ret == RET_INFER_INVALID || ret == RET_OK) {
1253 op_parameter = op_parameters_[node->output_indices_.at(0)];
1254 op_parameter->thread_num_ = kernel_thread_count;
1255 } else {
1256 MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1257 return nullptr;
1258 }
1259 }
1260 }
1261 }
1262 #endif
1263 if (data_type == kNumberTypeFloat16) {
1264 MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
1265 desc.data_type = kNumberTypeFloat32;
1266 }
1267 status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
1268 if (status == RET_OK) {
1269 return kernel;
1270 } else if (status == RET_ERROR) {
1271 op_parameters_.erase(node->output_indices_.at(0));
1272 auto ret = InferNodeShape(node);
1273 if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
1274 MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1275 }
1276 }
1277 #ifdef OP_INT8_CLIP
1278 if (desc.data_type == kNumberTypeInt8) {
1279 MS_LOG(ERROR) << unsupport_int8_log;
1280 }
1281 #endif
1282 return nullptr;
1283 }
1284
1285 namespace {
1286 kernel::SubGraphType GetKernelSubGraphType(const kernel::KernelExec *kernel, const InnerContext &context,
1287 bool is_controlflow = false) {
1288 if (kernel == nullptr) {
1289 return kernel::kNotSubGraph;
1290 }
1291
1292 auto desc = kernel->desc();
1293 if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
1294 if (desc.data_type == kNumberTypeFloat16) {
1295 return kernel::kGpuFp16SubGraph;
1296 } else {
1297 return kernel::kGpuFp32SubGraph;
1298 }
1299 } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) {
1300 return kernel::kNpuSubGraph;
1301 } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) {
1302 return kernel::kApuSubGraph;
1303 } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) {
1304 if (desc.data_type == kNumberTypeFloat16) {
1305 return kernel::kCpuFP16SubGraph;
1306 } else {
1307 return kernel::kCpuFP32SubGraph;
1308 }
1309 } else if (desc.arch == kernel::KERNEL_ARCH::kCustom) {
1310 return kernel::kCustomSubGraph;
1311 }
1312 return kernel::kNotSubGraph;
1313 }
1314 } // namespace
1315
1316 kernel::KernelExec *Scheduler::SchedulePartialToKernel(const lite::LiteGraph::Node *src_node) {
1317 MS_ASSERT(src_model_ != nullptr);
1318 MS_ASSERT(src_node != nullptr);
1319 auto *primitive = src_node->primitive_;
1320 MS_ASSERT(primitive != nullptr);
1321 if (!IsPartialNode(primitive, context_->get_schema_version())) {
1322 return nullptr;
1323 }
1324 auto subgraph_index = GetPartialGraphIndex(src_node->primitive_, context_->get_schema_version());
1325 if (SubGraphHasScheduled(subgraph_index)) {
1326 MS_LOG(INFO) << "Subgraph has been scheduled.";
1327 return {};
1328 } else {
1329 SubGraphMarkScheduled(subgraph_index);
1330 }
1331 auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1332 if (subgraph_kernel == nullptr) {
1333 MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1334 return {};
1335 }
1336 subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1337 return subgraph_kernel;
1338 }
1339
1340 #ifdef ENABLE_FP16
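// SubGraphPreferDataType: a subgraph prefers fp16 only when CPU fp16 is enabled, the delegate mode
// is not NNAPI, and every node in it has a registered fp16 kernel, is not weight-quantized and has
// fp32/fp16 inputs; otherwise fp32 is preferred.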
1341 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
1342 if (!context_->IsCpuFloat16Enabled() || context_->GetDelegateMode() == kNNAPI) {
1343 *prefer_data_type = kNumberTypeFloat32;
1344 return RET_OK;
1345 }
1346
1347 auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1348 for (auto node_index : subgraph->node_indices_) {
1349 auto node = src_model_->graph_.all_nodes_[node_index];
1350 MS_ASSERT(node != nullptr);
1351 MS_ASSERT(!node->output_indices_.empty());
1352 OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1353 if (op_parameter == nullptr) {
1354 MS_LOG(ERROR) << "Can not find OpParameter!type: "
1355 << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1356 return RET_ERROR;
1357 }
1358 kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16, NHWC, op_parameter->type_};
1359 if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1360 *prefer_data_type = kNumberTypeFloat32;
1361 return RET_OK;
1362 }
1363
1364 std::vector<Tensor *> inputs;
1365 std::vector<Tensor *> outputs;
1366 FindNodeInoutTensors(*node, &inputs, &outputs);
1367 if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1368 *prefer_data_type = kNumberTypeFloat32;
1369 return RET_OK;
1370 }
1371 TypeId data_type = GetFirstFp32Fp16OrInt8Type(inputs);
1372 if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
1373 *prefer_data_type = kNumberTypeFloat32;
1374 return RET_OK;
1375 }
1376 }
1377 *prefer_data_type = kNumberTypeFloat16;
1378 return RET_OK;
1379 }
1380 #endif
1381
1382 std::vector<kernel::KernelExec *> Scheduler::ScheduleMainSubGraphToKernels() {
1383 std::vector<kernel::KernelExec *> kernels;
1384 std::vector<lite::Tensor *> in_tensors;
1385 std::vector<lite::Tensor *> out_tensors;
1386 TypeId prefer_data_type = context_->GetDelegateMode() == kNNAPI ? kNumberTypeFloat32 : kTypeUnknown;
1387 auto ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1388 if (ret != RET_OK) {
1389 MS_LOG(ERROR) << "Schedule subgraph failed, index: " << kMainSubGraphIndex;
1390 for (auto *kernel : kernels) {
1391 delete kernel;
1392 kernel = nullptr;
1393 }
1394 return {};
1395 }
1396 return kernels;
1397 }
1398
1399 kernel::KernelExec *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
1400 TypeId prefer_data_type = kTypeUnknown;
1401 #ifdef ENABLE_FP16
1402 if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
1403 MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
1404 return nullptr;
1405 }
1406 #endif
1407 std::vector<kernel::KernelExec *> kernels;
1408 std::vector<lite::Tensor *> in_tensors;
1409 std::vector<lite::Tensor *> out_tensors;
1410 auto ret = ScheduleSubGraphToKernels(subgraph_index, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1411 if (ret != RET_OK) {
1412 MS_LOG(ERROR) << "Schedule subgraph failed, index: " << subgraph_index;
1413 return nullptr;
1414 }
1415 kernel::KernelExecUtil::FindAllInoutKernels(kernels);
1416 kernel::SubGraphType cur_sub_graph_type = kernel::kCpuFP32SubGraph;
1417 if (!kernels.empty()) {
1418 cur_sub_graph_type = GetKernelSubGraphType(kernels.front(), *context_, true);
1419 }
1420 MS_LOG(INFO) << "cur_sub_graph_type: " << cur_sub_graph_type;
1421 auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1422 kernels, &in_tensors, &out_tensors, cur_sub_graph_type, *context_, context_->get_schema_version());
1423 if (subgraph_kernel == nullptr) {
1424 MS_LOG(ERROR) << "CreateSubGraphKernel failed, cur_sub_graph_type: " << cur_sub_graph_type;
1425 return nullptr;
1426 }
1427 return subgraph_kernel;
1428 }
1429
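// Dispatch by subgraph index: the main subgraph is scheduled directly; any other subgraph is
// wrapped into a single subgraph kernel named "subgraph_<index>" and recorded in
// subgraph_index_subgraph_kernel_map_.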
1430 std::vector<kernel::KernelExec *> Scheduler::ScheduleSubGraphToSubGraphKernels(const int &subgraph_index) {
1431 if (subgraph_index == kMainSubGraphIndex) {
1432 return ScheduleMainSubGraphToKernels();
1433 }
1434 auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1435 if (subgraph_kernel == nullptr) {
1436 MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1437 return {};
1438 }
1439 subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1440 subgraph_index_subgraph_kernel_map_[subgraph_index] = subgraph_kernel;
1441 return {subgraph_kernel};
1442 }
1443
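// Schedule a single node onto a backend kernel. The preferred data type may first be overridden
// by the execution plan. Non-MSLite models look the kernel up through the AbstractBaseModel
// interface; MSLite models use Scheduler::FindBackendKernel. The node's op_parameters_ entry is
// cleared, the kernel's tensor data types are set, and the kernel inherits the node name and the
// scheduler's config info.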
1444 kernel::KernelExec *Scheduler::ScheduleNodeToKernel(const lite::LiteGraph::Node *src_node, TypeId prefer_data_type) {
1445 std::vector<Tensor *> inputs;
1446 std::vector<Tensor *> outputs;
1447 MS_ASSERT(src_node != nullptr);
1448 FindNodeInoutTensors(*src_node, &inputs, &outputs);
1449
1450 ResetByExecutionPlan(src_node->name_, &prefer_data_type);
1451
1452 mindspore::kernel::KernelExec *kernel = nullptr;
1453 if (src_model_->model_type_ != mindspore::lite::ModelType_MSLite) {
1454 auto abstract_model_ptr = reinterpret_cast<AbstractBaseModel *>(src_model_);
1455 if (abstract_model_ptr == nullptr) {
1456 MS_LOG(ERROR) << "src model is not abstract base model return nullptr.";
1457 return nullptr;
1458 }
1459 kernel = abstract_model_ptr->FindBackendKernel(inputs, outputs, src_node, context_, prefer_data_type);
1460 if (kernel == nullptr) {
1461 MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_;
1462 return nullptr;
1463 }
1464 } else {
1465 kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type);
1466 if (kernel == nullptr) {
1467 MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_
1468 << ", type: " << GetPrimitiveTypeName(src_node->primitive_, context_->get_schema_version());
1469 return nullptr;
1470 }
1471 }
1472 op_parameters_[src_node->output_indices_.at(0)] = nullptr;
1473 auto ret = kernel::KernelExecUtil::SetKernelTensorDataType(kernel);
1474 if (ret != RET_OK) {
1475 MS_LOG(ERROR) << "Set tensor data type for kernel " << kernel->name() << std::endl;
1476 delete kernel;
1477 return nullptr;
1478 }
1479 kernel->set_name(src_node->name_);
1480 if (kernel->kernel() != nullptr) {
1481 kernel->kernel()->SetConfig(config_info_);
1482 }
1483 return kernel;
1484 }
1485
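// A partial node matches the control-flow pattern when one of its outputs feeds a Call, Switch,
// or SwitchLayer node, e.g.  partial -> call / switch / switch_layer.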
1486 bool Scheduler::IsControlFlowPattern(const lite::LiteGraph::Node &partial_node) {
1487 lite::LiteGraph::Node *partial_node_output = nullptr;
1488 for (auto output_index : partial_node.output_indices_) {
1489 for (auto &node : src_model_->graph_.all_nodes_) {
1490 if (IsContain(node->input_indices_, output_index)) {
1491 partial_node_output = node;
1492 break;
1493 }
1494 }
1495 }
1496
1497 return partial_node_output != nullptr &&
1498 (IsCallNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1499 IsSwitchNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1500 IsSwitchLayerNode(partial_node_output->primitive_, context_->get_schema_version()));
1501 }
1502
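// Schedule the whole graph: start with the main subgraph and keep draining subgraphs_to_schedule_;
// partial nodes met during scheduling enqueue their target subgraphs here.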
1503 int Scheduler::ScheduleGraphToKernels(std::vector<kernel::KernelExec *> *dst_kernels, TypeId prefer_data_type) {
1504 subgraphs_to_schedule_.push_back(kMainSubGraphIndex);
1505 while (!subgraphs_to_schedule_.empty()) {
1506 auto cur_subgraph_index = subgraphs_to_schedule_.front();
1507 subgraphs_to_schedule_.pop_front();
1508 auto kernels = ScheduleSubGraphToSubGraphKernels(cur_subgraph_index);
1509 if (kernels.empty()) {
1510 MS_LOG(ERROR) << "ScheduleSubGraphToSubGraphKernel failed";
1511 return RET_ERROR;
1512 }
1513 std::copy(kernels.begin(), kernels.end(), std::back_inserter(*dst_kernels));
1514 }
1515 return RET_OK;
1516 }
1517
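// Schedule every node of the given subgraph into a kernel. For MSLite models, control-flow
// partial nodes are scheduled as caller kernels and their target subgraph is queued for later
// scheduling (at most once); other partial nodes are expanded in place via SchedulePartialToKernel.
// Optionally collects the subgraph's input and output tensors.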
1518 int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::KernelExec *> *dst_kernels,
1519 std::vector<lite::Tensor *> *in_tensors,
1520 std::vector<lite::Tensor *> *out_tensors, TypeId prefer_data_type) {
1521 MS_ASSERT(src_model_ != nullptr);
1522 MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
1523 MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
1524 MS_ASSERT(dst_kernels != nullptr);
1525 MS_ASSERT(dst_kernels->empty());
1526 auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1527 auto ret = RET_OK;
1528 for (auto node_index : subgraph->node_indices_) {
1529 auto node = src_model_->graph_.all_nodes_[node_index];
1530 MS_ASSERT(node != nullptr);
1531 auto *primitive = node->primitive_;
1532 MS_ASSERT(primitive != nullptr);
1533 kernel::KernelExec *kernel = nullptr;
1534
1535 if (src_model_->model_type_ == ModelType_MSLite && IsPartialNode(primitive, context_->get_schema_version())) {
1536 if (IsControlFlowPattern(*node)) {
1537 kernel = ScheduleNodeToKernel(node, prefer_data_type);
1538 auto partial_subgraph_index = GetPartialGraphIndex(primitive, context_->get_schema_version());
1539 MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "control flow scheduler is nullptr.");
1540 control_flow_scheduler_->RecordSubgraphCaller(partial_subgraph_index, kernel);
1541 if (SubGraphHasScheduled(partial_subgraph_index)) {
1542 partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1543 MS_LOG(INFO) << "subgraph has scheduled. ";
1544 } else {
1545 SubGraphMarkScheduled(partial_subgraph_index);
1546 partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1547 subgraphs_to_schedule_.push_back(partial_subgraph_index);
1548 }
1549 } else {
1550 MS_CHECK_TRUE_MSG(
1551 subgraph_index != static_cast<size_t>(GetPartialGraphIndex(node->primitive_, context_->get_schema_version())),
1552 RET_ERROR, "Unreasonable cycles exist in subgraph.");
1553 kernel = SchedulePartialToKernel(node);
1554 }
1555 } else {
1556 kernel = ScheduleNodeToKernel(node, prefer_data_type);
1557 }
1558 if (kernel == nullptr || ret != RET_OK) {
1559 MS_LOG(ERROR) << "schedule node return nullptr, name: " << node->name_
1560 << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
1561 return RET_ERROR;
1562 }
1563 kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index)));
1564 dst_kernels->emplace_back(kernel);
1565 auto litert_kernel = reinterpret_cast<kernel::Kernel *>(kernel->kernel());
1566 if (MS_UNLIKELY(litert_kernel == nullptr)) {
1567 MS_LOG(ERROR) << "nullptr exist in scheduler.";
1568 return RET_ERROR;
1569 }
1570 primitives_.emplace(litert_kernel, static_cast<const schema::Primitive *>(primitive));
1571 }
1572 if (in_tensors != nullptr) {
1573 std::transform(subgraph->input_indices_.begin(), subgraph->input_indices_.end(), std::back_inserter(*in_tensors),
1574 [&](const uint32_t index) { return this->src_tensors_->at(index); });
1575 }
1576 if (out_tensors != nullptr) {
1577 std::transform(subgraph->output_indices_.begin(), subgraph->output_indices_.end(), std::back_inserter(*out_tensors),
1578 [&](const uint32_t index) { return this->src_tensors_->at(index); });
1579 }
1580 return RET_OK;
1581 }
1582
1583 namespace {
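// Data types a kernel may carry and still be merged into a CPU fp32 subgraph.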
1584 bool KernelFitCurrentSubGraphCPUFp32(TypeId data_type) {
1585 return (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat || data_type == kNumberTypeInt8 ||
1586 data_type == kNumberTypeInt || data_type == kNumberTypeInt32 || data_type == kNumberTypeInt64 ||
1587 data_type == kNumberTypeUInt8 || data_type == kNumberTypeBool);
1588 }
1589
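// Whether `kernel` can join an existing subgraph of `subgraph_type`, judged by the kernel's
// arch and data type; kNotSubGraph and kApuSubGraph never accept kernels.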
1590 bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::KernelExec &kernel) {
1591 switch (subgraph_type) {
1592 case kernel::SubGraphType::kNotSubGraph:
1593 case kernel::SubGraphType::kApuSubGraph:
1594 return false;
1595 case kernel::SubGraphType::kGpuFp16SubGraph:
1596 if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1597 return false;
1598 }
1599 return (kernel.desc().data_type != kNumberTypeFloat32);
1600 case kernel::SubGraphType::kGpuFp32SubGraph:
1601 if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1602 return false;
1603 }
1604 return (kernel.desc().data_type != kNumberTypeFloat16);
1605 case kernel::SubGraphType::kNpuSubGraph:
1606 return kernel.desc().arch == kernel::KERNEL_ARCH::kNPU;
1607 case kernel::SubGraphType::kCpuFP16SubGraph: {
1608 auto desc = kernel.desc();
1609 if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1610 return false;
1611 }
1612 #ifdef ENABLE_FP16
1613 if (desc.data_type == kNumberTypeInt8 || desc.data_type == kNumberTypeUInt8) {
1614 return true;
1615 }
1616 #endif
1617 return (desc.data_type == kNumberTypeFloat16);
1618 }
1619 case kernel::SubGraphType::kCpuFP32SubGraph: {
1620 auto desc = kernel.desc();
1621 if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1622 return false;
1623 }
1624 return KernelFitCurrentSubGraphCPUFp32(desc.data_type);
1625 }
1626 default:
1627 return false;
1628 }
1629 }
1630
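// Starting from sorted_kernels[*cur_index], greedily merge consecutive kernels that fit the same
// subgraph type into one subgraph kernel. A delegate kernel, an existing subgraph kernel, or a
// non-fitting kernel stops the merge; on return *cur_index points at the last merged kernel
// (or one past the end when every remaining kernel was merged).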
1631 kernel::KernelExec *FindAllSubGraphKernels(const std::vector<kernel::KernelExec *> &sorted_kernels,
1632 const InnerContext &context, size_t *cur_index, int schema_version) {
1633 std::vector<kernel::KernelExec *> sub_kernels;
1634 sub_kernels.emplace_back(sorted_kernels[*cur_index]);
1635 auto cur_sub_graph_type = GetKernelSubGraphType(sorted_kernels[*cur_index], context);
1636 for (*cur_index = *cur_index + 1; *cur_index < sorted_kernels.size(); ++(*cur_index)) {
1637 auto cur_kernel = sorted_kernels[*cur_index];
1638 MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph);
1639 // already a subgraph or a delegate
1640 if (cur_kernel->desc().arch == kernel::kDelegate) {
1641 --(*cur_index);
1642 break;
1643 }
1644 if (cur_kernel->subgraph_type() != kernel::kNotSubGraph ||
1645 !KernelFitCurrentSubGraph(cur_sub_graph_type, *cur_kernel)) {
1646 --(*cur_index);
1647 break;
1648 }
1649 sub_kernels.emplace_back(cur_kernel);
1650 }
1651 return kernel::KernelExecUtil::CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type, context,
1652 schema_version);
1653 }
1654 } // namespace
1655
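// Group the scheduled kernels into subgraph kernels: delegate kernels and existing subgraph
// kernels are passed through unchanged, the rest are merged via FindAllSubGraphKernels.
// Delegate and non-CPU subgraphs disable infer_along_running_, and every non-delegate subgraph
// kernel has SetFp16Attr() called on it.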
1656 int Scheduler::ConstructNormalSubGraphs(const std::vector<kernel::KernelExec *> &src_kernel,
1657 std::vector<kernel::KernelExec *> *dst_kernel,
1658 std::map<const kernel::KernelExec *, bool> *is_kernel_finish) {
1659 if (src_kernel.empty()) {
1660 return RET_OK;
1661 }
1662
1663 // construct subgraph
1664 for (size_t index = 0; index < src_kernel.size(); index++) {
1665 auto cur_kernel = src_kernel[index];
1666 MS_ASSERT(cur_kernel != nullptr);
1667     // APU is not supported yet
1668 MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph);
1669 if (cur_kernel->desc().arch == kernel::kDelegate) {
1670 dst_kernel->emplace_back(cur_kernel);
1671 continue;
1672 }
1673 // already a subgraph or a delegate
1674 if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
1675 dst_kernel->emplace_back(cur_kernel);
1676 continue;
1677 }
1678 auto subgraph = FindAllSubGraphKernels(src_kernel, *context_, &index, context_->get_schema_version());
1679 if (subgraph == nullptr) {
1680 MS_LOG(ERROR) << "Create SubGraphKernel failed";
1681 return RET_ERROR;
1682 }
1683 dst_kernel->emplace_back(subgraph);
1684 }
1685 for (auto *subgraph : *dst_kernel) {
1686 if (subgraph->desc().arch == kernel::kDelegate) {
1687 *infer_along_running_ = false;
1688 continue;
1689 }
1690 if (subgraph->subgraph_type() != kernel::kCpuFP32SubGraph &&
1691 subgraph->subgraph_type() != kernel::kCpuFP16SubGraph) {
1692 *infer_along_running_ = false;
1693 }
1694 auto subgraph_kernel = static_cast<kernel::SubGraphKernel *>(subgraph);
1695 if (subgraph_kernel == nullptr) {
1696 MS_LOG(ERROR) << "kernel: " << subgraph->name() << " not is subgraph kernel.";
1697 return RET_ERROR;
1698 }
1699     // This is for the train-session CPU fp16 path and should be removed in the future.
1700 auto ret = subgraph_kernel->SetFp16Attr();
1701 if (ret != RET_OK) {
1702 MS_LOG(ERROR) << "Init SubGraph failed: " << ret;
1703 return ret;
1704 }
1705 }
1706 return RET_OK;
1707 }
1708
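// Return the data type of the first input tensor whose type is one of fp32/fp16/int8/int32/bool/
// uint8/string; tensor lists report their element type. If no such tensor exists, fall back to
// the first input's data type (fp32 for tensor lists), or kTypeUnknown when the input list is empty.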
1709 TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
1710 for (const auto &tensor : in_tensors) {
1711 auto dtype = tensor->data_type();
1712 if (dtype == kObjectTypeTensorType) {
1713 return TensorListDataType(tensor);
1714 }
1715 std::unordered_set<TypeId> type_set = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32,
1716 kNumberTypeBool, kNumberTypeUInt8, kObjectTypeString};
1717 if (type_set.find(dtype) != type_set.end()) {
1718 return dtype;
1719 }
1720 }
1721 if (in_tensors.empty()) {
1722 MS_LOG(ERROR) << "in tensor is empty.";
1723 return kTypeUnknown;
1724 }
1725 MS_ASSERT(!in_tensors.empty());
1726 return in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
1727 }
1728
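// If any scheduled kernel runs in fp16, the partial subgraph is CpuFP16; otherwise CpuFP32.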
1729 kernel::SubGraphType Scheduler::PartialSubGraphType(const std::vector<kernel::KernelExec *> &kernels) {
1730 if (std::any_of(kernels.begin(), kernels.end(),
1731 [](const kernel::KernelExec *item) { return item->desc().data_type == kNumberTypeFloat16; })) {
1732 return kernel::kCpuFP16SubGraph;
1733 }
1734 return kernel::kCpuFP32SubGraph;
1735 }
1736
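// For every branch input of the switch node (inputs 1..n), find the not-yet-inferred partial
// node producing it, then run InferPartialShape on each collected partial node; inference
// failures are only logged as warnings.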
1737 int Scheduler::InferSwitchShape(const lite::LiteGraph::Node *switch_node) {
1738 MS_ASSERT(src_model_ != nullptr);
1739 MS_ASSERT(switch_node != nullptr);
1740 std::deque<lite::LiteGraph::Node *> partial_cnode_to_infer{};
1741 for (size_t i = 1; i < switch_node->input_indices_.size(); ++i) {
1742 auto branch_output_index = switch_node->input_indices_.at(i);
1743 for (auto &node : src_model_->graph_.all_nodes_) {
1744 if (IsContain(node->output_indices_, branch_output_index) &&
1745 IsPartialNode(node->primitive_, context_->get_schema_version()) &&
1746 partial_cnode_inferred_.find(node) == partial_cnode_inferred_.end()) {
1747 partial_cnode_inferred_.insert(node);
1748 partial_cnode_to_infer.push_back(node);
1749 break;
1750 }
1751 }
1752 }
1753
1754 while (!partial_cnode_to_infer.empty()) {
1755 auto &node = partial_cnode_to_infer.front();
1756 partial_cnode_to_infer.pop_front();
1757 int ret = InferPartialShape(node);
1758 if (ret != RET_OK) {
1759 MS_LOG(WARNING) << "partial infer not ok, ret: " << ret;
1760 }
1761 }
1762 return RET_OK;
1763 }
1764
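// Return the node whose outputs are exactly `node`'s inputs if it is a Switch or SwitchLayer
// node; otherwise return nullptr.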
1765 LiteGraph::Node *Scheduler::NodeInputIsSwitchType(const lite::LiteGraph::Node *node) {
1766 MS_ASSERT(src_model_ != nullptr);
1767 MS_ASSERT(node != nullptr);
1768 for (auto &iter : src_model_->graph_.all_nodes_) {
1769 if (iter->output_indices_ == node->input_indices_) {
1770 if (IsSwitchNode(iter->primitive_, context_->get_schema_version()) ||
1771 IsSwitchLayerNode(iter->primitive_, context_->get_schema_version())) {
1772 return iter;
1773 } else {
1774 return nullptr;
1775 }
1776 }
1777 }
1778 return nullptr;
1779 }
1780
1781 bool Scheduler::SubGraphHasScheduled(const int &index) {
1782 return scheduled_subgraph_index_.find(index) != scheduled_subgraph_index_.end();
1783 }
1784
1785 void Scheduler::SubGraphMarkScheduled(const int &index) { scheduled_subgraph_index_.insert(index); }
1786
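// Collect every kernel that has not been wrapped into a subgraph yet, wrap them into the
// control-flow main graph subgraph kernel, and place it at the front of the kernel list
// (already-built subgraph kernels keep their relative order behind it).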
1787 int Scheduler::ConstructControlFlowMainGraph(std::vector<kernel::KernelExec *> *kernels) {
1788 auto back_kernels = *kernels;
1789 kernels->clear();
1790 std::vector<kernel::KernelExec *> main_graph_kernels{};
1791 for (auto &kernel : back_kernels) {
1792 if (kernel->subgraph_type() != kernel::kNotSubGraph) {
1793 kernels->push_back(kernel);
1794 } else {
1795 main_graph_kernels.push_back(kernel);
1796 }
1797 }
1798 auto cur_subgraph_type = PartialSubGraphType(main_graph_kernels);
1799 auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1800 main_graph_kernels, nullptr, nullptr, cur_subgraph_type, *context_, context_->get_schema_version());
1801 if (subgraph_kernel == nullptr) {
1802 MS_LOG(ERROR) << "create main graph for control flow model failed.";
1803 return RET_ERROR;
1804 }
1805 kernels->insert(kernels->begin(), subgraph_kernel);
1806 return RET_OK;
1807 }
1808
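// Non-tail call kernels are only collected for control-flow models; otherwise an empty vector
// is returned.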
1809 std::vector<kernel::KernelExec *> Scheduler::NonTailCallNodes() {
1810 std::vector<kernel::KernelExec *> ret{};
1811 if (*is_control_flow_) {
1812 ret = control_flow_scheduler_->GetNonTailCalls();
1813 }
1814 return ret;
1815 }
1816 } // namespace mindspore::lite
1817