1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/scheduler.h"
18 #include <map>
19 #include <unordered_set>
20 #include <queue>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #include "src/tensorlist.h"
25 #include "nnacl/partial_fusion_parameter.h"
26 #include "include/errorcode.h"
27 #include "src/common/graph_util.h"
28 #include "src/common/utils.h"
29 #include "src/litert/kernel_registry.h"
30 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
31 #include "include/registry/register_kernel.h"
32 #endif
33 #include "src/litert/kernel_exec_util.h"
34 #include "src/executor/sub_graph_kernel.h"
35 #include "src/common/ops/populate/populate_register.h"
36 #include "src/common/version_manager.h"
37 #include "src/common/prim_util.h"
38 #include "src/litert/lite_model.h"
39 #include "src/common/tensor_util.h"
40 #include "src/common/context_util.h"
41 #include "src/litert/infer_manager.h"
42 #include "src/litert/runtime_pass.h"
43 #ifndef ENABLE_MULTI_LAYOUT
44 #include "src/litert/pass/format_pass/format_pass.h"
45 #endif
46 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
47 #include "src/litert/sub_graph_split.h"
48 #include "src/litert/pass/online_fusion/online_fusion_pass_registry.h"
49 #endif
50 #include "src/litert/weight_decoder.h"
51 #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
52 #include "nnacl/nnacl_common.h"
53 #if GPU_OPENCL
54 #include "src/litert/kernel/opencl/opencl_subgraph.h"
55 #include "src/litert/kernel/gpu/opencl/opencl_runtime.h"
56 #endif
57 #include "include/registry/register_kernel_interface.h"
58 #include "extendrt/mindir_loader/abstract_base_model.h"
59 #include "src/litert/pack_weight_manager.h"
60 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
61 #include "thread/parallel_thread_pool_manager.h"
62 #endif
63 #ifdef SUPPORT_NNRT
64 #include "src/litert/delegate/nnrt/nnrt_delegate.h"
65 #endif
66 
67 using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
68 
69 namespace mindspore::lite {
70 namespace {
71 constexpr int kMainSubGraphIndex = 0;
72 constexpr int kWhereInputNum = 3;
73 constexpr int kIndex2 = 2;
74 constexpr int kIndex1 = 1;
75 }  // namespace
76 
77 namespace {
78 // support_fp16: current device and package support float16
79 int CastKernelWeight(const kernel::SubGraphType &belong_subgraph_type, const kernel::KernelExec *kernel,
80                      bool support_fp16) {
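  // Added note: for kernels placed into a CPU fp32/fp16 subgraph, cast their const float input
  // tensors so the weight precision matches the subgraph precision; non-const, tensorlist and
  // non-float tensors are skipped by the loop below.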
81   MS_ASSERT(kernel != nullptr);
82   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
83   if (belong_subgraph_type != kernel::kCpuFP32SubGraph && belong_subgraph_type != kernel::kCpuFP16SubGraph) {
84     return RET_OK;
85   }
86   for (auto *tensor : kernel->in_tensors()) {
87     MS_ASSERT(tensor != nullptr);
88     // only cast const tensor
89     // tensorlist does not support fp16 now
90     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
91       continue;
92     }
93     // only support fp32->fp16 or fp16->fp32
94     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
95       continue;
96     }
97     if (tensor->data_type() == kNumberTypeFloat32 && belong_subgraph_type == kernel::kCpuFP16SubGraph) {
98       auto ret = CastConstTensorData(tensor, kNumberTypeFloat16, support_fp16);
99       if (ret != RET_OK) {
100         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
101         return ret;
102       }
103     } else if (tensor->data_type() == kNumberTypeFloat16 && belong_subgraph_type == kernel::kCpuFP32SubGraph) {
104       auto ret = CastConstTensorData(tensor, kNumberTypeFloat32, support_fp16);
105       if (ret != RET_OK) {
106         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
107         return ret;
108       }
109     } else {
110       MS_LOG(DEBUG) << "No need to cast";
111     }
112   }
113   return RET_OK;
114 }
115 
116 int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
117   // packed kernels such as conv don't need to copy because the weight will be packed in the kernel
118   if (lite::PackWeightManager::GetInstance()->IsCopyTensor(op_type)) {
119     return RET_OK;
120   }
121 
122   for (auto *tensor : tensors) {
123     // only copy non-copied const tensor
124     if (!tensor->IsConst() && tensor->data() != nullptr) {
125       MS_LOG(ERROR) << "Illegitimate tensor : " << tensor->tensor_name();
126       return RET_ERROR;
127     }
128     if (!tensor->IsConst() || tensor->own_data()) {
129       continue;
130     }
131     if (tensor->data_type() == kObjectTypeTensorType) {
132       // tensorlist's data is nullptr since ConvertTensors
133       // we never set or malloc data of tensorlist but malloc tensors in tensorlist
134       MS_ASSERT(tensor->data() == nullptr);
135     } else {
136       auto copy_tensor = Tensor::CopyTensor(*tensor, true);
137       if (copy_tensor == nullptr) {
138         MS_LOG(ERROR) << "Copy tensor failed";
139         return RET_ERROR;
140       }
141       tensor->FreeData();
142       tensor->set_data(copy_tensor->data());
143       tensor->set_own_data(true);
144       copy_tensor->set_data(nullptr);
145       delete (copy_tensor);
146     }
147   }
148   return RET_OK;
149 }
150 }  // namespace
151 
152 // support_fp16: current device and package support float16
153 int Scheduler::HandleBuildinCpuKernelWeight(const kernel::SubGraphType &belong_subgraph_type,
154                                             const kernel::KernelExec *kernel) {
155   MS_ASSERT(kernel != nullptr);
156   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
157   if (is_train_session_ || kernel->type() == schema::PrimitiveType_Custom ||
158       kernel->desc().provider != kernel::kBuiltin) {
159     return RET_OK;
160   }
161   auto ret = CastKernelWeight(belong_subgraph_type, kernel, context_->device_and_pkg_support_fp16_);
162   if (ret != RET_OK) {
163     MS_LOG(DEBUG) << "CastKernelWeight failed: " << ret;
164     return RET_NOT_SUPPORT;
165   }
166   if (!(reinterpret_cast<LiteModel *>(src_model_)->keep_model_buf())) {
167     // copying const data does not require restoring tensors afterwards
168     MS_CHECK_TRUE_RET(kernel->op_parameter() != nullptr, RET_ERROR);
169     ret = CopyConstTensorData(kernel->in_tensors(), kernel->op_parameter()->type_);
170     if (ret != RET_OK) {
171       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
172       return RET_NOT_SUPPORT;
173     }
174   }
175   return RET_OK;
176 }
177 
178 int Scheduler::InitKernels(std::vector<kernel::KernelExec *> &&dst_kernels) {
179   if (is_train_session_) {
180     return RET_OK;
181   }
182   for (auto kernel : dst_kernels) {
183     // delegate graph kernel
184     if (kernel->desc().arch == kernel::kDelegate) {
185       continue;
186     }
187     auto subgraph_type = kernel->subgraph_type();
188     if (subgraph_type == kernel::kNotSubGraph) {
189       MS_LOG(ERROR) << "construct subgraph failed.";
190       return RET_ERROR;
191     }
192     auto subgraph_nodes = reinterpret_cast<kernel::SubGraphKernel *>(kernel)->nodes();
193     for (auto node : subgraph_nodes) {
194       for (auto *tensor : node->out_tensors()) {
195         if (tensor->IsConst()) {
196           MS_CHECK_TRUE_MSG(node->op_parameter() != nullptr, RET_NULL_PTR, "node's op_parameter is invalid.");
197           if (node->op_parameter()->type_ == ::PrimType::PrimType_Inner_ShapeFusion) {
198             continue;
199           }
200           MS_LOG(ERROR) << "Illegitimate kernel output tensor : " << tensor->tensor_name();
201           continue;
202         }
203       }
204       auto ret = HandleBuildinCpuKernelWeight(subgraph_type, node);
205       if (ret != RET_OK) {
206         return ret;
207       }
208     }
209 #if GPU_OPENCL
210     if (kernel->desc().arch == kernel::kGPU) {
211       if (this->GetEnableGLTexture() == true && (kernel == dst_kernels.front() || kernel == dst_kernels.back() - 1)) {
212         kernel->SetOpenGLTextureEnable(true);
213         MS_LOG(INFO) << "Set OpenGLSharingMem for subgraph success!" << std::endl;
214       }
215       auto ret = reinterpret_cast<kernel::OpenCLSubGraph *>(kernel)->RunPass();
216       if (ret != RET_OK) {
217         MS_LOG(ERROR) << "OpenCLSubGraph RunPass failed.";
218         return ret;
219       }
220     }
221 #endif
222   }
223   return RET_OK;
224 }
225 
226 int Scheduler::SchedulePreProcess() {
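  // Added summary of this pre-process step: build the sub-graph search helper, run the online
  // fusion pass, collect graph output node indexes, infer shapes of the main subgraph (for
  // MSLite models), and optionally split sub-graphs when parallel execution is enabled.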
227 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
228   auto search_sub_graph =
229     SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
230 #endif
231 
232 #if !defined(RUNTIME_PASS_CLIP)
233   OnlineFusionRegistry::GetInstance()->DoOnlineFusionPass(&search_sub_graph);
234 #endif
235 
236   this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
237 
238   if (src_model_->model_type_ != ModelType_MSLite) {
239     // call abstract model infer interface
240     *is_infershape_ = RET_OK;
241   } else {
242     *is_infershape_ = InferSubGraphShape(kMainSubGraphIndex);
243   }
244   if (*is_infershape_ != RET_OK && *is_infershape_ != RET_INFER_INVALID) {
245     MS_LOG(ERROR) << "op infer shape failed.";
246     return *is_infershape_;
247   }
248 
249   if (context_->enable_parallel_) {
250 #ifndef AUTO_PARALLEL_CLIP
251     if (*is_infershape_ != RET_INFER_INVALID) {
252       search_sub_graph.SubGraphSplit();
253     }
254 #else
255     MS_LOG(ERROR) << unsupport_auto_parallel_log;
256     return RET_NOT_SUPPORT;
257 #endif
258   }
259   return RET_OK;
260 }
261 
262 int Scheduler::CheckCpuValid(const std::vector<kernel::KernelExec *> *dst_kernels) const {
263   if (context_->IsDeviceTypeEnabled(DT_CPU)) {
264     return RET_OK;
265   }
266   for (auto kernel : *dst_kernels) {
267     if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
268       MS_LOG(ERROR) << "kernel: " << kernel->name() << " is only supported on CPU.";
269       return RET_ERROR;
270     }
271   }
272   return RET_OK;
273 }
274 
275 int Scheduler::ConstructSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels) {
276   if (*is_control_flow_) {
277     return ConstructControlFlowMainGraph(dst_kernels);
278   }
279 
280   auto src_kernel = *dst_kernels;
281   dst_kernels->clear();
282   std::map<const kernel::KernelExec *, bool> is_kernel_finish;
283   return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
284 }
285 
286 int Scheduler::ProcessSubGraphTranspose(std::vector<kernel::KernelExec *> *dst_kernels) {
287 #ifndef ENABLE_MULTI_LAYOUT
288   auto ret = pass::RuntimeFormatPass(dst_kernels, src_tensors_, Format::NHWC);
289   if (ret != RET_OK) {
290     MS_LOG(ERROR) << "Run runtime format pass failed.";
291     return RET_ERROR;
292   }
293 #endif
294   return RET_OK;
295 }
296 
297 STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *kernels) {
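  // Added note: this pass walks the kernel list (recursing into sub-graphs), removes every
  // QuantDTypeCast kernel, splices its predecessor and successors together, forces the touched
  // tensors to float32, and patches model-output bookkeeping when the removed kernel produced a
  // graph output.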
298   for (auto iter = (*kernels).begin(); iter != (*kernels).end();) {
299     auto cur_kernel = *iter;
300     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
301       auto sub_inner_graph = reinterpret_cast<kernel::SubGraphKernel *>(cur_kernel);
302       auto &subgraph_nodes = sub_inner_graph->nodes();
303       if (DelQuantDTypeCastKernel(&subgraph_nodes) != RET_OK) {
304         MS_LOG(ERROR) << "DelQuantDTypeCastKernel failed in subgraph.";
305         return RET_ERROR;
306       }
307     }
308     if (cur_kernel->type() != schema::PrimitiveType_QuantDTypeCast) {
309       iter++;
310       continue;
311     }
312     auto &post_kernels = cur_kernel->out_kernels();
313     auto &pre_kernels = cur_kernel->in_kernels();
314     if (cur_kernel->in_tensors().size() != 1) {
315       MS_LOG(ERROR) << cur_kernel->name() << " input size error."
316                     << " cur_kernel in tensors size:" << cur_kernel->in_tensors().size();
317       return RET_ERROR;
318     }
319     bool graph_input = pre_kernels.empty();
320     if (!graph_input) {
321       // modify post kernel input to new kernel and new tensor
322       for (auto post_kernel : post_kernels) {
323         auto post_in_kernels = post_kernel->in_kernels();
324         auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
325         *post_input_iter = pre_kernels[0];
326         post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
327         post_kernel->set_in_kernels(post_in_kernels);
328       }
329       auto pre_out_kernels = pre_kernels[0]->out_kernels();
330       auto pre_out_iter = std::find(pre_out_kernels.begin(), pre_out_kernels.end(), cur_kernel);
331       if (pre_out_iter != pre_out_kernels.end()) {
332         pre_out_kernels.erase(pre_out_iter);
333         pre_out_kernels.insert(pre_out_iter, post_kernels.begin(), post_kernels.end());
334         pre_kernels[0]->set_out_kernels(pre_out_kernels);
335       }
336     } else {
337       for (auto post_kernel : post_kernels) {
338         auto post_in_kernels = post_kernel->in_kernels();
339         auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
340         *post_input_iter = {};
341         post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
342         post_kernel->set_in_kernels(post_in_kernels);
343       }
344     }
345 
346     // update data type
347     for (auto tensor : cur_kernel->in_tensors()) {
348       tensor->set_data_type(kNumberTypeFloat32);
349     }
350     for (auto tensor : cur_kernel->out_tensors()) {
351       tensor->set_data_type(kNumberTypeFloat32);
352     }
353 
354     // update model output kernel & tensor
355     if (cur_kernel->is_model_output()) {
356       pre_kernels[0]->set_is_model_output(true);
357       cur_kernel->in_tensors()[0]->set_category(Category::GRAPH_OUTPUT);
358       pre_kernels[0]->set_out_kernels({});
359       // If the current kernel is the output kernel, use the current output tensor as the output tensor of the previous
360       // node.
361       auto pre_out_tensors = pre_kernels[0]->out_tensors();
362       auto tensor_iter = std::find(pre_out_tensors.begin(), pre_out_tensors.end(), cur_kernel->in_tensors()[0]);
363       if (tensor_iter != pre_kernels[0]->out_tensors().end()) {
364         *tensor_iter = cur_kernel->out_tensors()[0];
365       }
366     }
367 
368     // delete cur kernel
369     iter = kernels->erase(iter);
370     MS_LOG(DEBUG) << "Delete kernel: " << cur_kernel->name();
371     delete cur_kernel;
372   }
373   return RET_OK;
374 }
375 
376 int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
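  // Added summary of the scheduling pipeline below: check inputs -> shape-fusion pass ->
  // SchedulePreProcess (shape inference) -> ScheduleGraphToKernels -> optional QuantDTypeCast
  // removal in float mode -> delegate kernel init -> CPU validity check -> ConstructSubGraphs ->
  // format/transpose pass -> control-flow scheduling -> runtime pass -> InitKernels.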
377   MS_LOG(DEBUG) << "Start schedule.";
378   int check_input_ret = CheckInputParam(dst_kernels);
379   if (check_input_ret != RET_OK) {
380     MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
381     return check_input_ret;
382   }
383 
384   shape_fusion_pass_ =
385     std::make_shared<ShapeFusionPass>(context_, reinterpret_cast<LiteModel *>(src_model_), src_tensors_);
386   MS_CHECK_TRUE_RET(shape_fusion_pass_ != nullptr, RET_ERROR);
387   int ret = SchedulePreProcess();
388   if (ret != RET_OK) {
389     return ret;
390   }
391 
392   if (*is_control_flow_) {
393     control_flow_scheduler_ = std::make_shared<ControlFlowScheduler>(context_, ms_context_, src_tensors_);
394     MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "new control scheduler failed.");
395   }
396 
397   ret = ScheduleGraphToKernels(dst_kernels);
398   FreeOpParameters();
399   op_parameters_.clear();
400   if (ret != RET_OK) {
401     MS_LOG(ERROR) << "Schedule graph to kernels failed.";
402     return ret;
403   }
404   if (context_->float_mode) {
405     kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
406     ret = DelQuantDTypeCastKernel(dst_kernels);
407     if (ret != RET_OK) {
408       MS_LOG(ERROR) << "Delete quant_dtype_cast kernel failed.";
409       return ret;
410     }
411   }
412   shape_fusion_pass_->StoreStateAndReset();
413 
414   MS_LOG(DEBUG) << "Start to init delegate kernels.";
415   ret = InitDelegateKernels(dst_kernels);
416   if (ret != RET_OK) {
417     MS_LOG(ERROR) << "Replace delegate kernels failed.";
418     return ret;
419   }
420   MS_LOG(DEBUG) << "Finish to init delegate kernels.";
421 
422   ret = CheckCpuValid(dst_kernels);
423   if (ret != RET_OK) {
424     MS_LOG(ERROR) << "kernels invalid in set devices.";
425     return ret;
426   }
427 
428   kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
429 
430   ret = ConstructSubGraphs(dst_kernels);
431   if (ret != RET_OK) {
432     MS_LOG(ERROR) << "ConstructSubGraphs failed.";
433     return ret;
434   }
435 
436   ret = ProcessSubGraphTranspose(dst_kernels);
437   if (ret != RET_OK) {
438     MS_LOG(ERROR) << "Process SubGraph with multi layout failed.";
439     return ret;
440   }
441 
442   if (*is_control_flow_) {
443     control_flow_scheduler_->SetSubgraphForPartialNode(&partial_kernel_subgraph_index_map_,
444                                                        &subgraph_index_subgraph_kernel_map_);
445     ret = control_flow_scheduler_->Schedule(dst_kernels);
446     MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "control flow schedule failed.");
447   }
448 
449   auto status = RuntimePass(dst_kernels, src_tensors_);
450   if (status != RET_OK) {
451     MS_LOG(ERROR) << "runtime pass failed.";
452     return RET_ERROR;
453   }
454 
455   ret = InitKernels(std::move(*dst_kernels));
456   if (ret != RET_OK) {
457     MS_LOG(ERROR) << "InitKernels failed.";
458     return ret;
459   }
460   shape_fusion_pass_->RestoreState();
461   if (IsPrintDebug()) {
462     MS_LOG(DEBUG) << "schedule kernels success.";
463     for (auto subgraph : *dst_kernels) {
464       MS_LOG(DEBUG) << "[subgraph] : " << subgraph->name() << ",  type:" << subgraph->subgraph_type();
465       if (subgraph->desc().arch == kernel::KERNEL_ARCH::kDelegate) {
466         continue;
467       }
468       std::vector<kernel::KernelExec *> kernel_list = reinterpret_cast<kernel::SubGraphKernel *>(subgraph)->nodes();
469       for (auto kernel : kernel_list) {
470         MS_LOG(DEBUG) << "kernel: [" << kernel->name() << "] "
471                       << "TypeId(" << kernel->desc().data_type << "); "
472                       << "OpType(" << PrimitiveCurVersionTypeName(kernel->desc().type) << "); "
473                       << "format(" << kernel->desc().format << "); "
474                       << "arch(" << kernel->desc().arch << ")";
475       }
476     }
477   }
478   return RET_OK;
479 }
480 
481 int Scheduler::CheckInputParam(const std::vector<kernel::KernelExec *> *dst_kernels) const {
482   if (dst_kernels == nullptr) {
483     return RET_ERROR;
484   }
485   if (src_model_ == nullptr) {
486     MS_LOG(ERROR) << "Input model is nullptr";
487     return RET_PARAM_INVALID;
488   }
489   if (src_model_->graph_.sub_graphs_.empty()) {
490     MS_LOG(ERROR) << "Model should have at least one subgraph";
491     return RET_PARAM_INVALID;
492   }
493   return RET_OK;
494 }
495 
496 #ifndef DELEGATE_CLIP
497 int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
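  // Added note: wrap the scheduled kernels into a DelegateModel and let the delegate build it;
  // kernels the delegate rejects keep their original backend, kernels it accepts are re-wrapped
  // as delegate KernelExecs, and the replaced CPU kernels (plus their const inputs) are released.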
498   std::vector<kernel::Kernel *> kernels;
499   for (size_t i = 0; i < dst_kernels->size(); i++) {
500     auto litert_kernel = reinterpret_cast<kernel::Kernel *>((*dst_kernels)[i]->kernel());
501     if (MS_UNLIKELY(litert_kernel == nullptr)) {
502       MS_LOG(ERROR) << "nullptr exist in dst_kernels.";
503       return RET_ERROR;
504     }
505     kernels.push_back(litert_kernel);
506   }
507 
508   ms_inputs_ = LiteTensorsToMSTensors(*inputs_);
509   ms_outputs_ = LiteTensorsToMSTensors(*outputs_);
510   auto schema_version = static_cast<SchemaVersion>(context_->get_schema_version());
511   DelegateModel<schema::Primitive> *model =
512     new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
513   if (model == nullptr) {
514     MS_LOG(ERROR) << "New delegate model failed.";
515     return RET_NULL_PTR;
516   }
517 
518 #ifdef SUPPORT_NNRT
519   if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
520     auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
521     delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
522     void *meta_graph = reinterpret_cast<void *>(
523       const_cast<mindspore::schema::MetaGraph *>(mindspore::schema::GetMetaGraph(this->src_model_->buf)));
524     delegate->SetMetaGraph(meta_graph);
525     delegate->SetDequantTensors(this->src_tensors_);
526   }
527 #endif
528 
529   auto ret = delegate_->Build(model);
530   if (ret != mindspore::kSuccess) {
531     delete model;
532     MS_LOG(ERROR) << "Delegate prepare kernels failed.";
533     return RET_ERROR;
534   }
535 
536   auto src_kernels = *dst_kernels;
537   dst_kernels->clear();
538   std::map<const kernel::KernelExec *, bool> delegate_support;
539   for (auto kernel : src_kernels) {
540     delegate_support[kernel] = true;
541   }
542   for (auto kernel : kernels) {
543     size_t index = 0;
544     for (; index < src_kernels.size(); index++) {
545       if (kernel == src_kernels[index]->kernel()) {
546         // Kernels that the delegate does not support keep the original backend
547         dst_kernels->push_back(src_kernels[index]);
548         delegate_support[src_kernels[index]] = false;
549         break;
550       }
551     }
552     if (index == src_kernels.size()) {
553       // New liteKernel to save delegate subgraph
554       std::shared_ptr<kernel::Kernel> shared_kernel(kernel);
555       auto kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
556       if (kernel_exec == nullptr) {
557         delete model;
558         MS_LOG(ERROR) << "New KernelExec for delegate subgraph failed.";
559         return RET_NULL_PTR;
560       }
561       auto delegate_type = kNumberTypeFloat32;
562       for (auto &input : kernel->inputs()) {
563         if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
564           delegate_type = kNumberTypeFloat16;
565           break;
566         }
567       }
568       kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
569       kernel_exec->set_desc(delegate_desc);
570       dst_kernels->push_back(kernel_exec);
571     }
572   }
573   // Release the cpu kernels that have been replaced by the delegate subgraph, as well as their tensor data
574   for (auto kernel : src_kernels) {
575     if (delegate_support[kernel] == true) {
576       auto inputs = kernel->in_tensors();
577       for (auto *tensor : inputs) {
578         MS_ASSERT(tensor != nullptr);
579         if (tensor->IsConst()) {
580           tensor->FreeData();
581         }
582       }
583       delete kernel;
584     }
585   }
586   delete model;
587   return RET_OK;
588 }
589 
590 int Scheduler::InitDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
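  // Added note: an external delegate (delegate_device_type_ == -1) is offered all kernels up
  // front; the loop below then groups consecutive kernels whose device type the delegate takes
  // priority over (per DeviceTypePriority) and offers each group to the delegate separately.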
591   /* no delegate valid */
592   if (delegate_ == nullptr) {
593     return RET_OK;
594   }
595 
596   /* set delegate spin count */
597   context_->thread_pool_->SetSpinCountMinValue();
598 
599   /* external delegate */
600   if (delegate_device_type_ == -1) {
601     auto ret = ReplaceDelegateKernels(dst_kernels);
602     if (ret != RET_OK) {
603       MS_LOG(ERROR) << "external delegate init failed.";
604       return ret;
605     }
606   }
607 
608   /* Inner delegate  :  check Priority */
609   std::vector<kernel::KernelExec *> src_kernels = *dst_kernels;
610   dst_kernels->clear();
611 
612   while (!src_kernels.empty()) {
613     std::vector<kernel::KernelExec *> tmp_kernels;
614     kernel::KernelExec *remain_kernel = nullptr;
615 
616     /* Loop for inner delegate npu and TensorRT subgraph */
617     while (!src_kernels.empty()) {
618       auto kernel = src_kernels.front();
619       VectorErase(&src_kernels, kernel);
620       bool priority_ret =
621         DeviceTypePriority(context_, delegate_device_type_, KernelArchToDeviceType(kernel->desc().arch));
622       if (priority_ret == true) {
623         tmp_kernels.push_back(kernel);
624       } else {
625         remain_kernel = kernel;
626         break;
627       }
628     }
629 
630     /* start current NPU-kernels replace */
631     if (tmp_kernels.empty()) {
632       if (remain_kernel != nullptr) {
633         dst_kernels->push_back(remain_kernel);
634         remain_kernel = nullptr;
635       }
636       continue;
637     }
638     auto ret = ReplaceDelegateKernels(&tmp_kernels);
639     if (ret != RET_OK) {
640       dst_kernels->insert(dst_kernels->end(), src_kernels.begin(), src_kernels.end());
641       dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
642       if (remain_kernel != nullptr) {
643         dst_kernels->push_back(remain_kernel);
644       }
645       MS_LOG(ERROR) << "Inner delegate failed to replace delegate kernels.";
646       return ret;
647     }
648 
649     dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
650     tmp_kernels.clear();
651     if (remain_kernel != nullptr) {
652       dst_kernels->push_back(remain_kernel);
653       remain_kernel = nullptr;
654     }
655   }
656 
657   return RET_OK;
658 }
659 #endif
660 
661 void Scheduler::FindNodeInoutTensors(const lite::LiteGraph::Node &node, std::vector<Tensor *> *inputs,
662                                      std::vector<Tensor *> *outputs) {
663   MS_ASSERT(inputs != nullptr);
664   MS_ASSERT(outputs != nullptr);
665   auto in_size = node.input_indices_.size();
666   inputs->reserve(in_size);
667   for (size_t j = 0; j < in_size; ++j) {
668     inputs->emplace_back(src_tensors_->at(node.input_indices_[j]));
669   }
670   auto out_size = node.output_indices_.size();
671   outputs->reserve(out_size);
672   for (size_t j = 0; j < out_size; ++j) {
673     outputs->emplace_back(src_tensors_->at(node.output_indices_[j]));
674   }
675 }
676 
677 int Scheduler::InferNodeShape(const lite::LiteGraph::Node *node) {
678   MS_ASSERT(node != nullptr);
679   auto primitive = node->primitive_;
680   MS_ASSERT(primitive != nullptr);
681   std::vector<Tensor *> inputs;
682   std::vector<Tensor *> outputs;
683   FindNodeInoutTensors(*node, &inputs, &outputs);
684   auto ret =
685     KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders(), context_->get_schema_version());
686   if (ret != RET_NOT_SUPPORT) {
687     *infer_along_running_ = false;
688     return ret;
689   }
690 
691   auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(
692     GetPrimitiveType(node->primitive_, context_->get_schema_version()), context_->get_schema_version());
693   if (parame_gen == nullptr) {
694     MS_LOG(ERROR) << "parameter generator is nullptr.";
695     FreeOpParameters();
696     return RET_NULL_PTR;
697   }
698   auto parameter = parame_gen(primitive);
699   if (parameter == nullptr) {
700     MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
701                   << GetPrimitiveTypeName(primitive, context_->get_schema_version());
702     FreeOpParameters();
703     return RET_ERROR;
704   }
705 
706   parameter->quant_type_ = node->quant_type_;
707   parameter->thread_num_ = context_->thread_num_;
708   if (node->output_indices_.empty()) {
709     MS_LOG(ERROR) << "The output size is invalid";
710     if (parameter->destroy_func_ != nullptr) {
711       parameter->destroy_func_(parameter);
712     }
713     free(parameter);
714     return RET_ERROR;
715   }
716   if (op_parameters_.find(node->output_indices_.at(0)) != op_parameters_.end()) {
717     if (parameter->destroy_func_ != nullptr) {
718       parameter->destroy_func_(parameter);
719     }
720     free(parameter);
721     parameter = op_parameters_[node->output_indices_.at(0)];
722   } else {
723     op_parameters_[node->output_indices_.at(0)] = parameter;
724   }
725 
726   if (IsCallNode(primitive, context_->get_schema_version())) {
727     return InferCallShape(node);
728   }
729   ret = KernelInferShape(inputs, outputs, parameter, context_->allocator);
730   if (ret != RET_OK && ret != RET_INFER_INVALID) {
731     FreeOpParameters();
732     return RET_ERROR;
733   }
734   for (auto &output : outputs) {
735     output->set_shape_changed(false);
736   }
737   if (*is_control_flow_) {
738     for (auto &output : outputs) {
739       output->set_shape({-1});
740     }
741     return RET_INFER_INVALID;
742   }
743 
744   if (ret == RET_OK) {
745     for (auto &output : outputs) {
746       if (static_cast<size_t>(output->ElementsNum()) >= GetMaxMallocSize() / sizeof(int64_t)) {
747         MS_LOG(ERROR) << "The size of output tensor is too big";
748         FreeOpParameters();
749         return RET_ERROR;
750       }
751     }
752   } else if (ret != RET_INFER_INVALID) {
753     FreeOpParameters();
754     return RET_ERROR;
755   }
756   return ret;
757 }
758 
759 void Scheduler::FreeOpParameters() {
760   for (auto &param : op_parameters_) {
761     if (param.second != nullptr) {
762       if (param.second->destroy_func_ != nullptr) {
763         param.second->destroy_func_(param.second);
764       }
765       free(param.second);
766       param.second = nullptr;
767     }
768   }
769 }
770 
771 int Scheduler::RestoreSubGraphInput(const lite::LiteGraph::Node *partial_node) {
772   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
773   MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
774   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
775   for (size_t i = 0; i < subgraph->input_indices_.size(); ++i) {
776     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
777     subgraph_input->set_data(nullptr);
778   }
779   return RET_OK;
780 }
781 
782 void CopyCommonTensor(Tensor *dst_tensor, Tensor *src_tensor) {
783   dst_tensor->set_data_type(src_tensor->data_type());
784   dst_tensor->set_shape(src_tensor->shape());
785   dst_tensor->set_format(src_tensor->format());
786   dst_tensor->set_data(src_tensor->data());
787 }
788 
789 int Scheduler::CopyPartialShapeToSubGraph(const lite::LiteGraph::Node *partial_node) {
790   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
791   MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
792   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
793   if (subgraph->input_indices_.size() != partial_node->input_indices_.size()) {
794     MS_LOG(ERROR) << "partial node " << partial_node->name_ << " inputs size: " << partial_node->input_indices_.size()
795                   << " vs "
796                   << " subgraph input size: " << subgraph->input_indices_.size();
797     return RET_PARAM_INVALID;
798   }
799 
800   for (size_t i = 0; i < partial_node->input_indices_.size(); ++i) {
801     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
802     auto &partial_input = src_tensors_->at(partial_node->input_indices_[i]);
803     if (partial_input->data_type() == kObjectTypeTensorType) {
804       return RET_INFER_INVALID;
805     }
806     CopyCommonTensor(subgraph_input, partial_input);
807   }
808 
809   return RET_OK;
810 }
811 
812 int Scheduler::InferPartialShape(const lite::LiteGraph::Node *node) {
813   MS_ASSERT(src_model_ != nullptr);
814   MS_ASSERT(node != nullptr);
815   if (!IsPartialNode(node->primitive_, context_->get_schema_version())) {
816     MS_LOG(ERROR) << "Node is not a partial";
817     return RET_PARAM_INVALID;
818   }
819   CopyPartialShapeToSubGraph(node);
820   int subgraph_index = GetPartialGraphIndex(node->primitive_, context_->get_schema_version());
821   auto ret = InferSubGraphShape(subgraph_index);
822   if (ret != RET_OK) {
823     MS_LOG(WARNING) << "infer subgraph: " << subgraph_index << " failed, ret:" << ret;
824   }
825   RestoreSubGraphInput(node);
826   return ret;
827 }
828 
829 LiteGraph::Node *Scheduler::NodeInputIsPartial(const lite::LiteGraph::Node *node) {
830   MS_ASSERT(src_model_ != nullptr);
831   MS_ASSERT(node != nullptr);
832   for (auto &iter : src_model_->graph_.all_nodes_) {
833     if (iter->output_indices_ == node->input_indices_) {
834       if (IsPartialNode(iter->primitive_, context_->get_schema_version())) {
835         return iter;
836       } else {
837         return nullptr;
838       }
839     }
840   }
841   return nullptr;
842 }
843 
844 int Scheduler::InferCallShape(const lite::LiteGraph::Node *node) {
845   MS_ASSERT(src_model_ != nullptr);
846   MS_ASSERT(node != nullptr);
847   if (!IsCallNode(node->primitive_, context_->get_schema_version())) {
848     MS_LOG(ERROR) << "Node is not a call cnode";
849     return RET_PARAM_INVALID;
850   }
851 
852   auto partial_input = NodeInputIsPartial(node);
853   if (partial_input) {
854     return InferPartialShape(partial_input);
855   }
856   auto switch_input = NodeInputIsSwitchType(node);
857   if (switch_input) {
858     *is_control_flow_ = true;
859     return InferSwitchShape(switch_input);
860   }
861 
862   MS_LOG(ERROR) << "call input is not partial and also not switch.";
863   return RET_ERROR;
864 }
865 
866 int Scheduler::InferSubGraphShape(size_t subgraph_index) {
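  // Added note: each subgraph is inferred at most once (tracked via infer_subgraph_index_);
  // Shape nodes are first rewritten by the shape-fusion pass, and RET_INFER_INVALID from any
  // node marks the whole subgraph as needing runtime inference without aborting scheduling.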
867   MS_ASSERT(src_model_ != nullptr);
868   MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
869   MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
870   if (find(infer_subgraph_index_.begin(), infer_subgraph_index_.end(), subgraph_index) != infer_subgraph_index_.end()) {
871     MS_LOG(ERROR) << "The subgraph has already been infer-shaped, subgraph index: " << subgraph_index;
872     return RET_INFER_INVALID;
873   }
874   infer_subgraph_index_.push_back(subgraph_index);
875   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
876   int subgraph_infershape_ret = RET_OK;
877   auto node_indexes = subgraph->node_indices_;
878   for (size_t i = 0; i < node_indexes.size(); ++i) {
879     auto node_index = node_indexes[i];
880     auto node = src_model_->graph_.all_nodes_[node_index];
881     MS_ASSERT(node != nullptr);
882     auto *primitive = node->primitive_;
883     if (primitive == nullptr) {
884       MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
885       return RET_ERROR;
886     }
887     if (node->node_type_ == schema::PrimitiveType_Shape) {
888       // convert shape to built-in shape
889       MS_CHECK_TRUE_RET(node->input_indices_.size() == 1, RET_ERROR);
890       shape_fusion_pass_->Run(node, subgraph_index);
891       node_indexes = subgraph->node_indices_;
892     }
893     auto ret = InferNodeShape(node);
894     if (ret == RET_INFER_INVALID) {
895       MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_
896                    << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version())
897                    << ", set infer flag to false.";
898       subgraph_infershape_ret = RET_INFER_INVALID;
899     } else if (ret != RET_OK) {
900       FreeOpParameters();
901       MS_LOG(ERROR) << "InferShape failed, name: " << node->name_
902                     << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
903       return RET_INFER_ERR;
904     }
905   }
906   return subgraph_infershape_ret;
907 }
908 
909 namespace {
910 // support_fp16: current device and package support float16
911 int CastAndRestoreConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors,
912                                   TypeId dst_data_type, bool support_fp16) {
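  // Added note: cast the const tensor's data to dst_data_type in place, and stash a shallow
  // copy that still points at the original buffer into restored_origin_tensors so
  // RestoreTensorData() can put the original data and dtype back later.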
913   MS_ASSERT(tensor != nullptr);
914   MS_ASSERT(tensor->IsConst());
915   MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
916   MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
917   if (tensor->data_type() == dst_data_type) {
918     return RET_OK;
919   }
920   auto origin_data = tensor->data();
921   MS_ASSERT(origin_data != nullptr);
922   auto restore_tensor = Tensor::CopyTensor(*tensor, false);
923   if (restore_tensor == nullptr) {
924     return RET_NULL_PTR;
925   }
926   restore_tensor->set_data(origin_data);
927   restore_tensor->set_own_data(tensor->own_data());
928   tensor->set_data(nullptr);
929   tensor->set_data_type(dst_data_type);
930   auto ret = tensor->MallocData();
931   if (RET_OK != ret) {
932     MS_LOG(ERROR) << "malloc data failed";
933     return ret;
934   }
935   auto new_tensor_data = tensor->data();
936   MS_ASSERT(new_tensor_data != nullptr);
937   if (dst_data_type == kNumberTypeFloat32) {
938     Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
939   } else {  // dst_data_type == kNumberTypeFloat16
940     Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
941   }
942   if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
943     MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already stored";
944     delete restore_tensor;
945     return RET_ERROR;
946   }
947   (*restored_origin_tensors)[tensor] = restore_tensor;
948   return RET_OK;
949 }
950 
951 // support_fp16: current device and package support float16
952 int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
953                          TypeId dst_data_type, bool support_fp16) {
954   MS_ASSERT(restored_origin_tensors != nullptr);
955   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
956     MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";
957     return RET_PARAM_INVALID;
958   }
959   for (auto *tensor : tensors) {
960     MS_ASSERT(tensor != nullptr);
961     // only cast const tensor
962     // tensorlist does not support fp16 now
963     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
964       continue;
965     }
966     // only support fp32->fp16 or fp16->fp32
967     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
968       continue;
969     }
970     if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
971       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16, support_fp16);
972       if (ret != RET_OK) {
973         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
974         return ret;
975       }
976     } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
977       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32, support_fp16);
978       if (ret != RET_OK) {
979         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
980         return ret;
981       }
982     } else {
983       MS_LOG(DEBUG) << "No need to cast from " << tensor->data_type() << " to " << dst_data_type;
984     }
985   }
986   return RET_OK;
987 }
988 
989 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
990   MS_ASSERT(restored_origin_tensors != nullptr);
991   for (auto &restored_origin_tensor : *restored_origin_tensors) {
992     restored_origin_tensor.second->set_data(nullptr);
993     delete (restored_origin_tensor.second);
994     restored_origin_tensor.second = nullptr;
995   }
996   restored_origin_tensors->clear();
997 }
998 
999 inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
1000   MS_ASSERT(restored_origin_tensors != nullptr);
1001   for (auto &restored_origin_tensor : *restored_origin_tensors) {
1002     auto *origin_tensor = restored_origin_tensor.first;
1003     auto *restored_tensor = restored_origin_tensor.second;
1004     MS_ASSERT(origin_tensor != nullptr);
1005     MS_ASSERT(restored_tensor != nullptr);
1006     origin_tensor->FreeData();
1007     origin_tensor->set_data_type(restored_tensor->data_type());
1008     origin_tensor->set_data(restored_tensor->data());
1009     origin_tensor->set_own_data(restored_tensor->own_data());
1010   }
1011   FreeRestoreTensors(restored_origin_tensors);
1012 }
1013 }  // namespace
1014 
1015 void Scheduler::ResetByExecutionPlan(std::string node_name, TypeId *data_type) {
1016   if (execution_plan_ == nullptr) {
1017     return;
1018   }
1019   auto iter = execution_plan_->find(node_name);
1020   if (iter != execution_plan_->end()) {
1021     *data_type = iter->second;
1022   }
1023   return;
1024 }
1025 
1026 int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1027                              OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
1028                              kernel::KernelExec **kernel) {
1029   MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr.");
1030   auto op_type = op_parameter->type_;
1031   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1032     MS_LOG(INFO) << "unsupported op_type: " << PrimitiveCurVersionTypeName(op_type)
1033                  << ", data_type: " << desc.data_type;
1034     return RET_NOT_SUPPORT;
1035   }
1036   kernel::KernelKey cpu_desc = desc;
1037   if (kernel_data_type == kNumberTypeFloat16) {
1038     if (!context_->IsCpuFloat16Enabled() ||
1039         (cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
1040       return RET_NOT_SUPPORT;
1041     }
1042     cpu_desc.data_type = kNumberTypeFloat16;
1043   }
1044   auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->graph_.version_,
1045                                         context_->float_mode);
1046   if (ret != RET_OK) {
1047     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1048     return RET_NOT_SUPPORT;
1049   }
1050   std::map<Tensor *, Tensor *> restored_origin_tensors;
1051 
1052   if (is_train_session_) {
1053     ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type,
1054                                context_->device_and_pkg_support_fp16_);
1055     if (ret != RET_OK) {
1056       MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
1057       return RET_NOT_SUPPORT;
1058     }
1059   }
1060 
1061 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
1062   // reset op task num; the number of operator segmentation tasks is not necessarily equal to the number of threads
1063   int thread_num_limit = ParallelThreadPoolManager::GetInstance()->GetTaskNum(config_info_);
1064   if (thread_num_limit != -1 && IsSharedThreadPoolOp(op_type)) {
1065     op_parameter->thread_num_ = thread_num_limit;
1066   }
1067 #endif
1068 
1069   ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, cpu_desc,
1070                                                      op_parameter, kernel);
1071   if (ret == RET_OK) {
1072     MS_LOG(DEBUG) << "Get TypeId(expect = " << kernel_data_type << ", real = " << cpu_desc.data_type
1073                   << ") op success: " << PrimitiveCurVersionTypeName(op_type);
1074     if (is_train_session_) {
1075       ret = (*kernel)->Prepare();
1076       RestoreTensorData(&restored_origin_tensors);
1077     }
1078   }
1079   return ret;
1080 }
1081 
1082 #ifdef GPU_OPENCL
1083 int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1084                              OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::KernelExec **kernel,
1085                              TypeId prefer_data_type) {
1086   MS_ASSERT(op_parameter != nullptr);
1087   MS_ASSERT(kernel != nullptr);
1088   if (!context_->IsDeviceTypeEnabled(DT_GPU)) {
1089     return RET_NOT_SUPPORT;
1090   }
1091 
1092   // support more data types like int32
1093   kernel::KernelKey gpu_desc{kernel::KERNEL_ARCH::kGPU, desc.data_type, NHWC, desc.type};
1094   if (desc.data_type == kNumberTypeFloat32 && context_->IsGpuFloat16Enabled()) {
1095     gpu_desc.data_type = kNumberTypeFloat16;
1096   }
1097   if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kNumberTypeFloat32) {
1098     gpu_desc.data_type = prefer_data_type;
1099   }
1100   // weight dequant
1101   auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32, src_model_->graph_.version_,
1102                                         context_->float_mode);
1103   if (ret != RET_OK) {
1104     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1105     return RET_NOT_SUPPORT;
1106   }
1107   // copying const data does not require restoring tensors afterwards
1108   ret = CopyConstTensorData(in_tensors, op_parameter->type_);
1109   if (ret != RET_OK) {
1110     MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
1111     return RET_NOT_SUPPORT;
1112   }
1113   ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, gpu_desc,
1114                                                      op_parameter, kernel);
1115   if (ret == RET_OK) {
1116     MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1117   } else {
1118     MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1119   }
1120   return ret;
1121 }
1122 #endif
1123 
1124 int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1125                                   const LiteGraph::Node *node, TypeId data_type, kernel::KernelExec **kernel) {
1126 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
1127   MS_ASSERT(kernel != nullptr);
1128   int ret = RET_NOT_SUPPORT;
1129   auto prim_type = GetPrimitiveType(node->primitive_, context_->get_schema_version());
1130   if (prim_type == schema::PrimitiveType_Custom) {
1131     for (auto &&device : context_->device_list_) {
1132       if (!device.provider_.empty() && !device.provider_device_.empty()) {
1133         kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type,       NHWC, prim_type,
1134                                device.provider_device_,   device.provider_};
1135         ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc,
1136                                                            nullptr, kernel, node->primitive_);
1137         if (ret == RET_OK && *kernel != nullptr) {
1138           return ret;
1139         }
1140       }
1141     }
1142 
1143     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, prim_type, "", ""};
1144     ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1145                                                        kernel, node->primitive_);
1146     if (ret == RET_OK && *kernel != nullptr) {
1147       return ret;
1148     }
1149     return RET_NOT_SUPPORT;
1150   }
1151   if (!context_->IsProviderEnabled()) {
1152     return ret;
1153   }
1154   if (context_->get_schema_version() == SCHEMA_V0) {
1155     return ret;
1156   }
1157   for (auto &&device : context_->device_list_) {
1158     if (!device.provider_.empty()) {
1159       kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type,       NHWC, prim_type,
1160                              device.provider_device_,   device.provider_};
1161       ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1162                                                          kernel, node->primitive_);
1163       if (ret == RET_OK && *kernel != nullptr) {
1164         return ret;
1165       }
1166     }
1167   }
1168 #endif
1169   return RET_NOT_SUPPORT;
1170 }
1171 
1172 kernel::KernelExec *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
1173                                                  const std::vector<Tensor *> &out_tensors, const LiteGraph::Node *node,
1174                                                  TypeId prefer_data_type) {
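  // Added note on the selection order below: provider/custom kernels first, then (if enabled)
  // a GPU kernel, then a CPU fp16 kernel, and finally a CPU fp32 kernel; when a lookup returns
  // RET_ERROR the op parameter is dropped and shape inference retried before falling back.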
1175   MS_ASSERT(node != nullptr);
1176   // why we need this
1177   TypeId data_type;
1178   if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1179     if (in_tensors.front()->data_type() == kNumberTypeBool) {
1180       data_type = kNumberTypeBool;
1181     } else {
1182       data_type = kNumberTypeFloat32;
1183     }
1184   } else {
1185     if (static_cast<std::string>(GetPrimitiveTypeName(node->primitive_, context_->get_schema_version())) == "Where" &&
1186         in_tensors.size() == kWhereInputNum) {
1187       // mslite uses the first input tensor dtype as the kernel dtype, but op 'where' is an outlier: its first input
1188       // 'condition' is always 'bool'-typed, so here we use where's second/third input dtype as the kernel dtype.
1189       std::unordered_set<TypeId> type_set = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8,  kNumberTypeInt32,
1190                                              kNumberTypeBool,    kNumberTypeUInt8,   kObjectTypeString};
1191       if (type_set.find(in_tensors[kIndex1]->data_type()) != type_set.end()) {
1192         data_type = in_tensors[kIndex1]->data_type();
1193       } else {
1194         data_type = in_tensors[kIndex2]->data_type();
1195       }
1196     } else {
1197       data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
1198     }
1199 
1200     if (data_type == kTypeUnknown) {
1201       MS_LOG(ERROR) << "GetFirstFp32Fp16OrInt8Type is unknown.";
1202       return nullptr;
1203     }
1204   }
1205   if (context_->float_mode) {
1206     for (auto tensor : out_tensors) {
1207       if (!tensor->quant_params().empty() &&
1208           (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeUInt8)) {
1209         data_type = kNumberTypeFloat32;
1210         tensor->set_data_type(kNumberTypeFloat32);
1211       }
1212     }
1213   }
1214   kernel::KernelExec *kernel = nullptr;
1215   auto status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
1216   if (status == RET_OK && kernel != nullptr) {
1217     return kernel;
1218   }
1219   MS_ASSERT(!node->output_indices_.empty());
1220   OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1221   if (op_parameter == nullptr) {
1222     MS_LOG(ERROR) << "Cannot find OpParameter! type: "
1223                   << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1224     return nullptr;
1225   }
1226 
1227 #if (defined GPU_OPENCL) || (defined ENABLE_FP16)
1228   int kernel_thread_count = op_parameter->thread_num_;
1229 #endif
1230   op_parameter->is_train_session_ = is_train_session_;
1231   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, op_parameter->type_};
1232 
1233 #ifdef GPU_OPENCL
1234   bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
1235   bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
1236   if (gpu_priority && use_gpu_kernel) {
1237     status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel, prefer_data_type);
1238     if (status == RET_OK) {
1239       return kernel;
1240     } else {
1241       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1242                     << node->name_;
1243       if (status == RET_ERROR) {
1244         op_parameters_.erase(node->output_indices_.at(0));
1245         auto ret = InferNodeShape(node);
1246         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1247           op_parameter = op_parameters_[node->output_indices_.at(0)];
1248           op_parameter->thread_num_ = kernel_thread_count;
1249         } else {
1250           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1251           return nullptr;
1252         }
1253       }
1254     }
1255   }
1256 #endif
1257 #ifdef ENABLE_FP16
1258   if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
1259       ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
1260     status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
1261     if (status == RET_OK) {
1262       return kernel;
1263     } else {
1264       MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1265                     << node->name_;
1266       if (status == RET_ERROR) {
1267         op_parameters_.erase(node->output_indices_.at(0));
1268         auto ret = InferNodeShape(node);
1269         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1270           op_parameter = op_parameters_[node->output_indices_.at(0)];
1271           op_parameter->thread_num_ = kernel_thread_count;
1272         } else {
1273           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1274           return nullptr;
1275         }
1276       }
1277     }
1278   }
1279 #endif
1280   if (data_type == kNumberTypeFloat16) {
1281     MS_LOG(DEBUG) << "Get fp16 op failed, fall back to fp32 op.";
1282     desc.data_type = kNumberTypeFloat32;
1283   }
1284   status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
1285   if (status == RET_OK) {
1286     return kernel;
1287   } else if (status == RET_ERROR) {
1288     op_parameters_.erase(node->output_indices_.at(0));
1289     auto ret = InferNodeShape(node);
1290     if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
1291       MS_LOG(ERROR) << "Repeat infer shape failed: " << node->name_;
1292     }
1293   }
1294 #ifdef OP_INT8_CLIP
1295   if (desc.data_type == kNumberTypeInt8) {
1296     MS_LOG(ERROR) << unsupport_int8_log;
1297   }
1298 #endif
1299   return nullptr;
1300 }
1301 
1302 namespace {
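// Maps a kernel's KernelKey (arch and data type) to the subgraph type it should be grouped into;
// unrecognized arches yield kNotSubGraph.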
1303 kernel::SubGraphType GetKernelSubGraphType(const kernel::KernelExec *kernel, const InnerContext &context,
1304                                            bool is_controlflow = false) {
1305   if (kernel == nullptr) {
1306     return kernel::kNotSubGraph;
1307   }
1308 
1309   auto desc = kernel->desc();
1310   if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
1311     if (desc.data_type == kNumberTypeFloat16) {
1312       return kernel::kGpuFp16SubGraph;
1313     } else {
1314       return kernel::kGpuFp32SubGraph;
1315     }
1316   } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) {
1317     return kernel::kNpuSubGraph;
1318   } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) {
1319     return kernel::kApuSubGraph;
1320   } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) {
1321     if (desc.data_type == kNumberTypeFloat16) {
1322       return kernel::kCpuFP16SubGraph;
1323     } else {
1324       return kernel::kCpuFP32SubGraph;
1325     }
1326   } else if (desc.arch == kernel::KERNEL_ARCH::kCustom) {
1327     return kernel::kCustomSubGraph;
1328   }
1329   return kernel::kNotSubGraph;
1330 }
1331 }  // namespace
1332 
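// Schedules the subgraph referenced by a partial node into one subgraph kernel. A subgraph that has
// already been scheduled is skipped and nullptr is returned, so each subgraph is materialized only once.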
1333 kernel::KernelExec *Scheduler::SchedulePartialToKernel(const lite::LiteGraph::Node *src_node) {
1334   MS_ASSERT(src_model_ != nullptr);
1335   MS_ASSERT(src_node != nullptr);
1336   auto *primitive = src_node->primitive_;
1337   MS_ASSERT(primitive != nullptr);
1338   if (!IsPartialNode(primitive, context_->get_schema_version())) {
1339     return nullptr;
1340   }
1341   auto subgraph_index = GetPartialGraphIndex(src_node->primitive_, context_->get_schema_version());
1342   if (SubGraphHasScheduled(subgraph_index)) {
1343     MS_LOG(INFO) << "Subgraph has been scheduled.";
1344     return {};
1345   } else {
1346     SubGraphMarkScheduled(subgraph_index);
1347   }
1348   auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1349   if (subgraph_kernel == nullptr) {
1350     MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1351     return {};
1352   }
1353   subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1354   return subgraph_kernel;
1355 }
1356 
1357 #ifdef ENABLE_FP16
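// Decides the preferred data type for a whole subgraph: fp16 only when CPU fp16 is enabled, the
// delegate mode is not NNAPI, every node has a registered fp16 CPU kernel, no node is weight-quantized,
// and each node's first float-like input is fp32 or fp16; otherwise fp32.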
1358 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
1359   if (!context_->IsCpuFloat16Enabled() || context_->GetDelegateMode() == kNNAPI) {
1360     *prefer_data_type = kNumberTypeFloat32;
1361     return RET_OK;
1362   }
1363 
1364   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1365   for (auto node_index : subgraph->node_indices_) {
1366     auto node = src_model_->graph_.all_nodes_[node_index];
1367     MS_ASSERT(node != nullptr);
1368     MS_ASSERT(!node->output_indices_.empty());
1369     OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1370     if (op_parameter == nullptr) {
1371       MS_LOG(ERROR) << "Cannot find OpParameter, type: "
1372                     << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1373       return RET_ERROR;
1374     }
1375     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16, NHWC, op_parameter->type_};
1376     if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1377       *prefer_data_type = kNumberTypeFloat32;
1378       return RET_OK;
1379     }
1380 
1381     std::vector<Tensor *> inputs;
1382     std::vector<Tensor *> outputs;
1383     FindNodeInoutTensors(*node, &inputs, &outputs);
1384     if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1385       *prefer_data_type = kNumberTypeFloat32;
1386       return RET_OK;
1387     }
1388     TypeId data_type = GetFirstFp32Fp16OrInt8Type(inputs);
1389     if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
1390       *prefer_data_type = kNumberTypeFloat32;
1391       return RET_OK;
1392     }
1393   }
1394   *prefer_data_type = kNumberTypeFloat16;
1395   return RET_OK;
1396 }
1397 #endif
1398 
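// Schedules the main subgraph. In NNAPI delegate mode the preferred data type is pinned to fp32;
// on failure all partially created kernels are released and an empty vector is returned.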
1399 std::vector<kernel::KernelExec *> Scheduler::ScheduleMainSubGraphToKernels() {
1400   std::vector<kernel::KernelExec *> kernels;
1401   std::vector<lite::Tensor *> in_tensors;
1402   std::vector<lite::Tensor *> out_tensors;
1403   TypeId prefer_data_type = context_->GetDelegateMode() == kNNAPI ? kNumberTypeFloat32 : kTypeUnknown;
1404   auto ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1405   if (ret != RET_OK) {
1406     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << kMainSubGraphIndex;
1407     for (auto *kernel : kernels) {
1408       delete kernel;
1409       kernel = nullptr;
1410     }
1411     return {};
1412   }
1413   return kernels;
1414 }
1415 
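// Schedules one non-main subgraph and wraps the resulting kernels into a single SubGraphKernel. The
// subgraph type is taken from the first scheduled kernel and defaults to CPU fp32 when the list is empty.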
1416 kernel::KernelExec *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
1417   TypeId prefer_data_type = kTypeUnknown;
1418 #ifdef ENABLE_FP16
1419   if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
1420     MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
1421     return nullptr;
1422   }
1423 #endif
1424   std::vector<kernel::KernelExec *> kernels;
1425   std::vector<lite::Tensor *> in_tensors;
1426   std::vector<lite::Tensor *> out_tensors;
1427   auto ret = ScheduleSubGraphToKernels(subgraph_index, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1428   if (ret != RET_OK) {
1429     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << subgraph_index;
1430     return nullptr;
1431   }
1432   kernel::KernelExecUtil::FindAllInoutKernels(kernels);
1433   kernel::SubGraphType cur_sub_graph_type = kernel::kCpuFP32SubGraph;
1434   if (!kernels.empty()) {
1435     cur_sub_graph_type = GetKernelSubGraphType(kernels.front(), *context_, true);
1436   }
1437   MS_LOG(INFO) << "cur_sub_graph_type: " << cur_sub_graph_type;
1438   auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1439     kernels, &in_tensors, &out_tensors, cur_sub_graph_type, *context_, context_->get_schema_version());
1440   if (subgraph_kernel == nullptr) {
1441     MS_LOG(ERROR) << "CreateSubGraphKernel failed, cur_sub_graph_type: " << cur_sub_graph_type;
1442     return nullptr;
1443   }
1444   return subgraph_kernel;
1445 }
1446 
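// The main subgraph may expand into several kernels; any other subgraph becomes exactly one subgraph
// kernel and is recorded in subgraph_index_subgraph_kernel_map_.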
1447 std::vector<kernel::KernelExec *> Scheduler::ScheduleSubGraphToSubGraphKernels(const int &subgraph_index) {
1448   if (subgraph_index == kMainSubGraphIndex) {
1449     return ScheduleMainSubGraphToKernels();
1450   }
1451   auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1452   if (subgraph_kernel == nullptr) {
1453     MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1454     return {};
1455   }
1456   subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1457   subgraph_index_subgraph_kernel_map_[subgraph_index] = subgraph_kernel;
1458   return {subgraph_kernel};
1459 }
1460 
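// Schedules a single node to a backend kernel. Non-MSLite models resolve kernels through
// AbstractBaseModel, MSLite models through Scheduler::FindBackendKernel; the node's OpParameter slot is
// then cleared, presumably because ownership of the parameter has moved to the created kernel.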
1461 kernel::KernelExec *Scheduler::ScheduleNodeToKernel(const lite::LiteGraph::Node *src_node, TypeId prefer_data_type) {
1462   std::vector<Tensor *> inputs;
1463   std::vector<Tensor *> outputs;
1464   MS_ASSERT(src_node != nullptr);
1465   FindNodeInoutTensors(*src_node, &inputs, &outputs);
1466 
1467   ResetByExecutionPlan(src_node->name_, &prefer_data_type);
1468 
1469   mindspore::kernel::KernelExec *kernel = nullptr;
1470   if (src_model_->model_type_ != mindspore::lite::ModelType_MSLite) {
1471     auto abstract_model_ptr = reinterpret_cast<AbstractBaseModel *>(src_model_);
1472     if (abstract_model_ptr == nullptr) {
1473       MS_LOG(ERROR) << "src model is not an AbstractBaseModel, return nullptr.";
1474       return nullptr;
1475     }
1476     kernel = abstract_model_ptr->FindBackendKernel(inputs, outputs, src_node, context_, prefer_data_type);
1477     if (kernel == nullptr) {
1478       MS_LOG(ERROR) << "FindBackendKernel returned nullptr, name: " << src_node->name_;
1479       return nullptr;
1480     }
1481   } else {
1482     kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type);
1483     if (kernel == nullptr) {
1484       MS_LOG(ERROR) << "FindBackendKernel returned nullptr, name: " << src_node->name_
1485                     << ", type: " << GetPrimitiveTypeName(src_node->primitive_, context_->get_schema_version());
1486       return nullptr;
1487     }
1488   }
1489   op_parameters_[src_node->output_indices_.at(0)] = nullptr;
1490   auto ret = kernel::KernelExecUtil::SetKernelTensorDataType(kernel);
1491   if (ret != RET_OK) {
1492     MS_LOG(ERROR) << "Set tensor data type failed for kernel " << kernel->name();
1493     delete kernel;
1494     return nullptr;
1495   }
1496   kernel->set_name(src_node->name_);
1497   if (kernel->kernel() != nullptr) {
1498     kernel->kernel()->SetConfig(config_info_);
1499   }
1500   return kernel;
1501 }
1502 
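// A partial node is part of a control-flow pattern when one of its outputs feeds a Call, Switch, or
// SwitchLayer node.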
1503 bool Scheduler::IsControlFlowPattern(const lite::LiteGraph::Node &partial_node) {
1504   lite::LiteGraph::Node *partial_node_output = nullptr;
1505   for (auto output_index : partial_node.output_indices_) {
1506     for (auto &node : src_model_->graph_.all_nodes_) {
1507       if (IsContain(node->input_indices_, output_index)) {
1508         partial_node_output = node;
1509         break;
1510       }
1511     }
1512   }
1513 
1514   return partial_node_output != nullptr &&
1515          (IsCallNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1516           IsSwitchNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1517           IsSwitchLayerNode(partial_node_output->primitive_, context_->get_schema_version()));
1518 }
1519 
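// Breadth-first scheduling over subgraphs: starts with the main subgraph and keeps draining
// subgraphs_to_schedule_, which control-flow partial nodes append to while their callers are scheduled.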
1520 int Scheduler::ScheduleGraphToKernels(std::vector<kernel::KernelExec *> *dst_kernels, TypeId prefer_data_type) {
1521   subgraphs_to_schedule_.push_back(kMainSubGraphIndex);
1522   while (!subgraphs_to_schedule_.empty()) {
1523     auto cur_subgraph_index = subgraphs_to_schedule_.front();
1524     subgraphs_to_schedule_.pop_front();
1525     auto kernels = ScheduleSubGraphToSubGraphKernels(cur_subgraph_index);
1526     if (kernels.empty()) {
1527       MS_LOG(ERROR) << "ScheduleSubGraphToSubGraphKernels failed";
1528       return RET_ERROR;
1529     }
1530     std::copy(kernels.begin(), kernels.end(), std::back_inserter(*dst_kernels));
1531   }
1532   return RET_OK;
1533 }
1534 
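// Core per-subgraph scheduling loop. Partial nodes inside control-flow patterns are scheduled as
// call-site kernels and their target subgraphs queued; other partial nodes are expanded in place. A
// partial node referencing its own subgraph is rejected as a scheduling cycle.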
1535 int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::KernelExec *> *dst_kernels,
1536                                          std::vector<lite::Tensor *> *in_tensors,
1537                                          std::vector<lite::Tensor *> *out_tensors, TypeId prefer_data_type) {
1538   MS_ASSERT(src_model_ != nullptr);
1539   MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
1540   MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
1541   MS_ASSERT(dst_kernels != nullptr);
1542   MS_ASSERT(dst_kernels->empty());
1543   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1544   auto ret = RET_OK;
1545   for (auto node_index : subgraph->node_indices_) {
1546     auto node = src_model_->graph_.all_nodes_[node_index];
1547     MS_ASSERT(node != nullptr);
1548     auto *primitive = node->primitive_;
1549     MS_ASSERT(primitive != nullptr);
1550     kernel::KernelExec *kernel = nullptr;
1551 
1552     if (src_model_->model_type_ == ModelType_MSLite && IsPartialNode(primitive, context_->get_schema_version())) {
1553       if (IsControlFlowPattern(*node)) {
1554         kernel = ScheduleNodeToKernel(node, prefer_data_type);
1555         auto partial_subgraph_index = GetPartialGraphIndex(primitive, context_->get_schema_version());
1556         MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "control flow scheduler is nullptr.");
1557         control_flow_scheduler_->RecordSubgraphCaller(partial_subgraph_index, kernel);
1558         if (SubGraphHasScheduled(partial_subgraph_index)) {
1559           partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1560           MS_LOG(INFO) << "Subgraph has been scheduled.";
1561         } else {
1562           SubGraphMarkScheduled(partial_subgraph_index);
1563           partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1564           subgraphs_to_schedule_.push_back(partial_subgraph_index);
1565         }
1566       } else {
1567         MS_CHECK_TRUE_MSG(
1568           subgraph_index != static_cast<size_t>(GetPartialGraphIndex(node->primitive_, context_->get_schema_version())),
1569           RET_ERROR, "Unreasonable cycles exist in subgraph.");
1570         kernel = SchedulePartialToKernel(node);
1571       }
1572     } else {
1573       kernel = ScheduleNodeToKernel(node, prefer_data_type);
1574     }
1575     if (kernel == nullptr || ret != RET_OK) {
1576       MS_LOG(ERROR) << "schedule node returned nullptr, name: " << node->name_
1577                     << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
1578       return RET_ERROR;
1579     }
1580     kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index)));
1581     dst_kernels->emplace_back(kernel);
1582     auto litert_kernel = reinterpret_cast<kernel::Kernel *>(kernel->kernel());
1583     if (MS_UNLIKELY(litert_kernel == nullptr)) {
1584       MS_LOG(ERROR) << "nullptr kernel exists in scheduler.";
1585       return RET_ERROR;
1586     }
1587     primitives_.emplace(litert_kernel, static_cast<const schema::Primitive *>(primitive));
1588   }
1589   if (in_tensors != nullptr) {
1590     std::transform(subgraph->input_indices_.begin(), subgraph->input_indices_.end(), std::back_inserter(*in_tensors),
1591                    [&](const uint32_t index) { return this->src_tensors_->at(index); });
1592   }
1593   if (out_tensors != nullptr) {
1594     std::transform(subgraph->output_indices_.begin(), subgraph->output_indices_.end(), std::back_inserter(*out_tensors),
1595                    [&](const uint32_t index) { return this->src_tensors_->at(index); });
1596   }
1597   return RET_OK;
1598 }
1599 
1600 namespace {
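// Helpers that decide whether an already scheduled kernel can be merged into the subgraph currently
// being built, based on the target subgraph type and the kernel's arch and data type.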
1601 bool KernelFitCurrentSubGraphCPUFp32(TypeId data_type) {
1602   return (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat || data_type == kNumberTypeInt8 ||
1603           data_type == kNumberTypeInt || data_type == kNumberTypeInt32 || data_type == kNumberTypeInt64 ||
1604           data_type == kNumberTypeUInt8 || data_type == kNumberTypeBool);
1605 }
1606 
1607 bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::KernelExec &kernel) {
1608   switch (subgraph_type) {
1609     case kernel::SubGraphType::kNotSubGraph:
1610     case kernel::SubGraphType::kApuSubGraph:
1611       return false;
1612     case kernel::SubGraphType::kGpuFp16SubGraph:
1613       if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1614         return false;
1615       }
1616       return (kernel.desc().data_type != kNumberTypeFloat32);
1617     case kernel::SubGraphType::kGpuFp32SubGraph:
1618       if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1619         return false;
1620       }
1621       return (kernel.desc().data_type != kNumberTypeFloat16);
1622     case kernel::SubGraphType::kNpuSubGraph:
1623       return kernel.desc().arch == kernel::KERNEL_ARCH::kNPU;
1624     case kernel::SubGraphType::kCpuFP16SubGraph: {
1625       auto desc = kernel.desc();
1626       if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1627         return false;
1628       }
1629 #ifdef ENABLE_FP16
1630       if (desc.data_type == kNumberTypeInt8 || desc.data_type == kNumberTypeUInt8) {
1631         return true;
1632       }
1633 #endif
1634       return (desc.data_type == kNumberTypeFloat16);
1635     }
1636     case kernel::SubGraphType::kCpuFP32SubGraph: {
1637       auto desc = kernel.desc();
1638       if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1639         return false;
1640       }
1641       return KernelFitCurrentSubGraphCPUFp32(desc.data_type);
1642     }
1643     default:
1644       return false;
1645   }
1646 }
1647 
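// Greedily extends the subgraph starting at *cur_index with consecutive kernels that fit its type; it
// stops (stepping *cur_index back by one) at a delegate kernel, an existing subgraph kernel, or an
// incompatible kernel, then wraps the collected kernels into one subgraph kernel.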
1648 kernel::KernelExec *FindAllSubGraphKernels(const std::vector<kernel::KernelExec *> &sorted_kernels,
1649                                            const InnerContext &context, size_t *cur_index, int schema_version) {
1650   std::vector<kernel::KernelExec *> sub_kernels;
1651   sub_kernels.emplace_back(sorted_kernels[*cur_index]);
1652   auto cur_sub_graph_type = GetKernelSubGraphType(sorted_kernels[*cur_index], context);
1653   for (*cur_index = *cur_index + 1; *cur_index < sorted_kernels.size(); ++(*cur_index)) {
1654     auto cur_kernel = sorted_kernels[*cur_index];
1655     MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph);
1656     // stop at a delegate kernel, an existing subgraph kernel, or a kernel that does not fit this subgraph
1657     if (cur_kernel->desc().arch == kernel::kDelegate) {
1658       --(*cur_index);
1659       break;
1660     }
1661     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph ||
1662         !KernelFitCurrentSubGraph(cur_sub_graph_type, *cur_kernel)) {
1663       --(*cur_index);
1664       break;
1665     }
1666     sub_kernels.emplace_back(cur_kernel);
1667   }
1668   return kernel::KernelExecUtil::CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type, context,
1669                                                       schema_version);
1670 }
1671 }  // namespace
1672 
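// Partitions the topologically sorted kernel list into subgraph kernels. Delegate kernels and kernels
// that are already subgraphs pass through untouched; a delegate or any non-CPU subgraph clears
// *infer_along_running_.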
1673 int Scheduler::ConstructNormalSubGraphs(const std::vector<kernel::KernelExec *> &src_kernel,
1674                                         std::vector<kernel::KernelExec *> *dst_kernel,
1675                                         std::map<const kernel::KernelExec *, bool> *is_kernel_finish) {
1676   if (src_kernel.empty()) {
1677     return RET_OK;
1678   }
1679 
1680   // construct subgraph
1681   for (size_t index = 0; index < src_kernel.size(); index++) {
1682     auto cur_kernel = src_kernel[index];
1683     MS_ASSERT(cur_kernel != nullptr);
1684     // APU is not supported yet
1685     MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph);
1686     if (cur_kernel->desc().arch == kernel::kDelegate) {
1687       dst_kernel->emplace_back(cur_kernel);
1688       continue;
1689     }
1690     // already wrapped into a subgraph kernel
1691     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
1692       dst_kernel->emplace_back(cur_kernel);
1693       continue;
1694     }
1695     auto subgraph = FindAllSubGraphKernels(src_kernel, *context_, &index, context_->get_schema_version());
1696     if (subgraph == nullptr) {
1697       MS_LOG(ERROR) << "Create SubGraphKernel failed";
1698       return RET_ERROR;
1699     }
1700     dst_kernel->emplace_back(subgraph);
1701   }
1702   for (auto *subgraph : *dst_kernel) {
1703     if (subgraph->desc().arch == kernel::kDelegate) {
1704       *infer_along_running_ = false;
1705       continue;
1706     }
1707     if (subgraph->subgraph_type() != kernel::kCpuFP32SubGraph &&
1708         subgraph->subgraph_type() != kernel::kCpuFP16SubGraph) {
1709       *infer_along_running_ = false;
1710     }
1711     auto subgraph_kernel = static_cast<kernel::SubGraphKernel *>(subgraph);
1712     if (subgraph_kernel == nullptr) {
1713       MS_LOG(ERROR) << "kernel: " << subgraph->name() << " is not a subgraph kernel.";
1714       return RET_ERROR;
1715     }
1716     // This is for the train-session CPU fp16 path and should be removed in the future.
1717     auto ret = subgraph_kernel->SetFp16Attr();
1718     if (ret != RET_OK) {
1719       MS_LOG(ERROR) << "Init SubGraph failed: " << ret;
1720       return ret;
1721     }
1722   }
1723   return RET_OK;
1724 }
1725 
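// Picks the data type used for kernel selection: the first input whose type is one of the supported
// types (or the element type of a tensor list input), falling back to the first input's type.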
1726 TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
1727   for (const auto &tensor : in_tensors) {
1728     auto dtype = tensor->data_type();
1729     if (dtype == kObjectTypeTensorType) {
1730       return TensorListDataType(tensor);
1731     }
1732     std::unordered_set<TypeId> type_set = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8,  kNumberTypeInt32,
1733                                            kNumberTypeBool,    kNumberTypeUInt8,   kObjectTypeString};
1734     if (type_set.find(dtype) != type_set.end()) {
1735       return dtype;
1736     }
1737   }
1738   if (in_tensors.empty()) {
1739     MS_LOG(ERROR) << "in_tensors is empty.";
1740     return kTypeUnknown;
1741   }
1742   MS_ASSERT(!in_tensors.empty());
1743   return in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
1744 }
1745 
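// A partial subgraph is typed CPU fp16 as soon as any of its kernels was resolved with an fp16
// KernelKey; otherwise it is a CPU fp32 subgraph.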
1746 kernel::SubGraphType Scheduler::PartialSubGraphType(const std::vector<kernel::KernelExec *> &kernels) {
1747   if (std::any_of(kernels.begin(), kernels.end(),
1748                   [](const kernel::KernelExec *item) { return item->desc().data_type == kNumberTypeFloat16; })) {
1749     return kernel::kCpuFP16SubGraph;
1750   }
1751   return kernel::kCpuFP32SubGraph;
1752 }
1753 
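// Infers shapes for the partial branches feeding a switch node. Each not-yet-inferred partial producer
// of a branch input is queued and inferred once; failures only log a warning and do not abort scheduling.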
1754 int Scheduler::InferSwitchShape(const lite::LiteGraph::Node *switch_node) {
1755   MS_ASSERT(src_model_ != nullptr);
1756   MS_ASSERT(switch_node != nullptr);
1757   std::deque<lite::LiteGraph::Node *> partial_cnode_to_infer{};
1758   for (size_t i = 1; i < switch_node->input_indices_.size(); ++i) {
1759     auto branch_output_index = switch_node->input_indices_.at(i);
1760     for (auto &node : src_model_->graph_.all_nodes_) {
1761       if (IsContain(node->output_indices_, branch_output_index) &&
1762           IsPartialNode(node->primitive_, context_->get_schema_version()) &&
1763           partial_cnode_inferred_.find(node) == partial_cnode_inferred_.end()) {
1764         partial_cnode_inferred_.insert(node);
1765         partial_cnode_to_infer.push_back(node);
1766         break;
1767       }
1768     }
1769   }
1770 
1771   while (!partial_cnode_to_infer.empty()) {
1772     auto &node = partial_cnode_to_infer.front();
1773     partial_cnode_to_infer.pop_front();
1774     int ret = InferPartialShape(node);
1775     if (ret != RET_OK) {
1776       MS_LOG(WARNING) << "Infer partial shape failed, ret: " << ret;
1777     }
1778   }
1779   return RET_OK;
1780 }
1781 
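// Returns the Switch or SwitchLayer node whose output indices exactly match this node's input indices,
// or nullptr when the producer is a different kind of node or none is found.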
1782 LiteGraph::Node *Scheduler::NodeInputIsSwitchType(const lite::LiteGraph::Node *node) {
1783   MS_ASSERT(src_model_ != nullptr);
1784   MS_ASSERT(node != nullptr);
1785   for (auto &iter : src_model_->graph_.all_nodes_) {
1786     if (iter->output_indices_ == node->input_indices_) {
1787       if (IsSwitchNode(iter->primitive_, context_->get_schema_version()) ||
1788           IsSwitchLayerNode(iter->primitive_, context_->get_schema_version())) {
1789         return iter;
1790       } else {
1791         return nullptr;
1792       }
1793     }
1794   }
1795   return nullptr;
1796 }
1797 
1798 bool Scheduler::SubGraphHasScheduled(const int &index) {
1799   return scheduled_subgraph_index_.find(index) != scheduled_subgraph_index_.end();
1800 }
1801 
1802 void Scheduler::SubGraphMarkScheduled(const int &index) { scheduled_subgraph_index_.insert(index); }
1803 
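// For control-flow models: kernels that are not yet wrapped in a subgraph form the main graph and are
// packed into one subgraph kernel, which is placed at the front of the kernel list.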
1804 int Scheduler::ConstructControlFlowMainGraph(std::vector<kernel::KernelExec *> *kernels) {
1805   auto back_kernels = *kernels;
1806   kernels->clear();
1807   std::vector<kernel::KernelExec *> main_graph_kernels{};
1808   for (auto &kernel : back_kernels) {
1809     if (kernel->subgraph_type() != kernel::kNotSubGraph) {
1810       kernels->push_back(kernel);
1811     } else {
1812       main_graph_kernels.push_back(kernel);
1813     }
1814   }
1815   auto cur_subgraph_type = PartialSubGraphType(main_graph_kernels);
1816   auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1817     main_graph_kernels, nullptr, nullptr, cur_subgraph_type, *context_, context_->get_schema_version());
1818   if (subgraph_kernel == nullptr) {
1819     MS_LOG(ERROR) << "create main graph for control flow model failed.";
1820     return RET_ERROR;
1821   }
1822   kernels->insert(kernels->begin(), subgraph_kernel);
1823   return RET_OK;
1824 }
1825 
1826 std::vector<kernel::KernelExec *> Scheduler::NonTailCallNodes() {
1827   std::vector<kernel::KernelExec *> ret{};
1828   if (*is_control_flow_) {
1829     ret = control_flow_scheduler_->GetNonTailCalls();
1830   }
1831   return ret;
1832 }
1833 }  // namespace mindspore::lite
1834