1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/scheduler.h"
18 #include <map>
19 #include <unordered_set>
20 #include <queue>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #include "src/tensorlist.h"
25 #include "nnacl/partial_fusion_parameter.h"
26 #include "include/errorcode.h"
27 #include "src/common/graph_util.h"
28 #include "src/common/utils.h"
29 #include "src/litert/kernel_registry.h"
30 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
31 #include "include/registry/register_kernel.h"
32 #endif
33 #include "src/litert/kernel_exec_util.h"
34 #include "src/executor/sub_graph_kernel.h"
35 #include "src/common/ops/populate/populate_register.h"
36 #include "src/common/version_manager.h"
37 #include "src/common/prim_util.h"
38 #include "src/litert/lite_model.h"
39 #include "src/common/tensor_util.h"
40 #include "src/common/context_util.h"
41 #include "src/litert/infer_manager.h"
42 #include "src/litert/runtime_pass.h"
43 #ifndef ENABLE_MULTI_LAYOUT
44 #include "src/litert/pass/format_pass/format_pass.h"
45 #endif
46 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
47 #include "src/litert/sub_graph_split.h"
48 #include "src/litert/pass/online_fusion/online_fusion_pass_registry.h"
49 #endif
50 #include "src/litert/weight_decoder.h"
51 #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
52 #include "nnacl/nnacl_common.h"
53 #if GPU_OPENCL
54 #include "src/litert/kernel/opencl/opencl_subgraph.h"
55 #include "src/litert/kernel/gpu/opencl/opencl_runtime.h"
56 #endif
57 #include "include/registry/register_kernel_interface.h"
58 #include "extendrt/mindir_loader/abstract_base_model.h"
59 #include "src/litert/pack_weight_manager.h"
60 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
61 #include "thread/parallel_thread_pool_manager.h"
62 #endif
63 #ifdef SUPPORT_NNRT
64 #include "src/litert/delegate/nnrt/nnrt_delegate.h"
65 #endif
66 
67 using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
68 
69 namespace mindspore::lite {
70 namespace {
71 constexpr int kMainSubGraphIndex = 0;
72 }  // namespace
73 
74 namespace {
75 // support_fp16: current device and package support float16
76 int CastKernelWeight(const kernel::SubGraphType &belong_subgraph_type, const kernel::KernelExec *kernel,
77                      bool support_fp16) {
78   MS_ASSERT(kernel != nullptr);
79   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
80   if (belong_subgraph_type != kernel::kCpuFP32SubGraph && belong_subgraph_type != kernel::kCpuFP16SubGraph) {
81     return RET_OK;
82   }
83   for (auto *tensor : kernel->in_tensors()) {
84     MS_ASSERT(tensor != nullptr);
85     // only cast const tensor
86     // tensorlist not support fp16 now
87     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
88       continue;
89     }
90     // only support fp32->fp16 or fp16->fp32
91     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
92       continue;
93     }
94     if (tensor->data_type() == kNumberTypeFloat32 && belong_subgraph_type == kernel::kCpuFP16SubGraph) {
95       auto ret = CastConstTensorData(tensor, kNumberTypeFloat16, support_fp16);
96       if (ret != RET_OK) {
97         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
98         return ret;
99       }
100     } else if (tensor->data_type() == kNumberTypeFloat16 && belong_subgraph_type == kernel::kCpuFP32SubGraph) {
101       auto ret = CastConstTensorData(tensor, kNumberTypeFloat32, support_fp16);
102       if (ret != RET_OK) {
103         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
104         return ret;
105       }
106     } else {
107       MS_LOG(DEBUG) << "No need to cast";
108     }
109   }
110   return RET_OK;
111 }
112 
113 int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
114   // packed kernels such as conv don't need to copy because weight will be packed in kernel
115   if (lite::PackWeightManager::GetInstance()->IsCopyTensor(op_type)) {
116     return RET_OK;
117   }
118 
119   for (auto *tensor : tensors) {
120     // only copy non-copied const tensor
121     if (!tensor->IsConst() && tensor->data() != nullptr) {
122       MS_LOG(ERROR) << "Illegitimate tensor : " << tensor->tensor_name();
123       return RET_ERROR;
124     }
125     if (!tensor->IsConst() || tensor->own_data()) {
126       continue;
127     }
128     if (tensor->data_type() == kObjectTypeTensorType) {
129       // tensorlist's data is nullptr since ConvertTensors
130       // we never set or malloc data of tensorlist but malloc tensors in tensorlist
131       MS_ASSERT(tensor->data() == nullptr);
132     } else {
133       auto copy_tensor = Tensor::CopyTensor(*tensor, true);
134       if (copy_tensor == nullptr) {
135         MS_LOG(ERROR) << "Copy tensor failed";
136         return RET_ERROR;
137       }
138       tensor->FreeData();
139       tensor->set_data(copy_tensor->data());
140       tensor->set_own_data(true);
141       copy_tensor->set_data(nullptr);
142       delete (copy_tensor);
143     }
144   }
145   return RET_OK;
146 }
147 }  // namespace
148 
149 // support_fp16: current device and package support float16
150 int Scheduler::HandleBuildinCpuKernelWeight(const kernel::SubGraphType &belong_subgraph_type,
151                                             const kernel::KernelExec *kernel) {
152   MS_ASSERT(kernel != nullptr);
153   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
154   if (is_train_session_ || kernel->type() == schema::PrimitiveType_Custom ||
155       kernel->desc().provider != kernel::kBuiltin) {
156     return RET_OK;
157   }
158   auto ret = CastKernelWeight(belong_subgraph_type, kernel, context_->device_and_pkg_support_fp16_);
159   if (ret != RET_OK) {
160     MS_LOG(DEBUG) << "CastKernelWeight failed: " << ret;
161     return RET_NOT_SUPPORT;
162   }
163   if (!(reinterpret_cast<LiteModel *>(src_model_)->keep_model_buf())) {
164     // we don't need to restore tensor for copy data
165     MS_CHECK_TRUE_RET(kernel->op_parameter() != nullptr, RET_ERROR);
166     ret = CopyConstTensorData(kernel->in_tensors(), kernel->op_parameter()->type_);
167     if (ret != RET_OK) {
168       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
169       return RET_NOT_SUPPORT;
170     }
171   }
172   return RET_OK;
173 }
174 
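// Post-schedule initialization: validate subgraph construction, cast/copy const weights for
// built-in CPU kernels, and run the OpenCL subgraph pass when the GPU backend is in use.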
175 int Scheduler::InitKernels(std::vector<kernel::KernelExec *> &&dst_kernels) {
176   if (is_train_session_) {
177     return RET_OK;
178   }
179   for (auto kernel : dst_kernels) {
180     // delegate graph kernel
181     if (kernel->desc().arch == kernel::kDelegate) {
182       continue;
183     }
184     auto subgraph_type = kernel->subgraph_type();
185     if (subgraph_type == kernel::kNotSubGraph) {
186       MS_LOG(ERROR) << "construct subgraph failed.";
187       return RET_ERROR;
188     }
189     auto subgraph_nodes = reinterpret_cast<kernel::SubGraphKernel *>(kernel)->nodes();
190     for (auto node : subgraph_nodes) {
191       for (auto *tensor : node->out_tensors()) {
192         if (tensor->IsConst()) {
193           MS_CHECK_TRUE_MSG(node->op_parameter() != nullptr, RET_NULL_PTR, "node's op_parameter is invalid.");
194           if (node->op_parameter()->type_ == ::PrimType::PrimType_Inner_ShapeFusion) {
195             continue;
196           }
197           MS_LOG(ERROR) << "Illegitimate kernel output tensor : " << tensor->tensor_name();
198           continue;
199         }
200       }
201       auto ret = HandleBuildinCpuKernelWeight(subgraph_type, node);
202       if (ret != RET_OK) {
203         return ret;
204       }
205     }
206 #if GPU_OPENCL
207     if (kernel->desc().arch == kernel::kGPU) {
208       if (this->GetEnableGLTexture() == true && (kernel == dst_kernels.front() || kernel == dst_kernels.back() - 1)) {
209         kernel->SetOpenGLTextureEnable(true);
210         MS_LOG(INFO) << "Set OpenGLSharingMem for subgraph success!" << std::endl;
211       }
212       auto ret = reinterpret_cast<kernel::OpenCLSubGraph *>(kernel)->RunPass();
213       if (ret != RET_OK) {
214         MS_LOG(ERROR) << "OpenCLSubGraph RunPass failed.";
215         return ret;
216       }
217     }
218 #endif
219   }
220   return RET_OK;
221 }
222 
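// Pre-schedule step: run the online fusion pass, infer shapes for the main subgraph (non-MSLite
// models are inferred elsewhere), and split subgraphs when parallel scheduling is enabled.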
223 int Scheduler::SchedulePreProcess() {
224 #if !defined(AUTO_PARALLEL_CLIP) || !defined(RUNTIME_PASS_CLIP)
225   auto search_sub_graph =
226     SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
227 #endif
228 
229 #if !defined(RUNTIME_PASS_CLIP)
230   OnlineFusionRegistry::GetInstance()->DoOnlineFusionPass(&search_sub_graph);
231 #endif
232 
233   this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
234 
235   if (src_model_->model_type_ != ModelType_MSLite) {
236     // call abstract model infer interface
237     *is_infershape_ = RET_OK;
238   } else {
239     *is_infershape_ = InferSubGraphShape(kMainSubGraphIndex);
240   }
241   if (*is_infershape_ != RET_OK && *is_infershape_ != RET_INFER_INVALID) {
242     MS_LOG(ERROR) << "op infer shape failed.";
243     return *is_infershape_;
244   }
245 
246   if (context_->enable_parallel_) {
247 #ifndef AUTO_PARALLEL_CLIP
248     if (*is_infershape_ != RET_INFER_INVALID) {
249       search_sub_graph.SubGraphSplit();
250     }
251 #else
252     MS_LOG(ERROR) << unsupport_auto_parallel_log;
253     return RET_NOT_SUPPORT;
254 #endif
255   }
256   return RET_OK;
257 }
258 
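// Reject kernels that were scheduled to CPU when the CPU device is not enabled in the context.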
259 int Scheduler::CheckCpuValid(const std::vector<kernel::KernelExec *> *dst_kernels) const {
260   if (context_->IsDeviceTypeEnabled(DT_CPU)) {
261     return RET_OK;
262   }
263   for (auto kernel : *dst_kernels) {
264     if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
265       MS_LOG(ERROR) << "kernel: " << kernel->name() << " only support in CPU.";
266       return RET_ERROR;
267     }
268   }
269   return RET_OK;
270 }
271 
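// Group the scheduled kernels into subgraphs; control-flow models only build the main graph here.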
272 int Scheduler::ConstructSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels) {
273   if (*is_control_flow_) {
274     return ConstructControlFlowMainGraph(dst_kernels);
275   }
276 
277   auto src_kernel = *dst_kernels;
278   dst_kernels->clear();
279   std::map<const kernel::KernelExec *, bool> is_kernel_finish;
280   return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
281 }
282 
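// Unify tensor formats to NHWC via the runtime format pass (skipped when ENABLE_MULTI_LAYOUT is defined).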
283 int Scheduler::ProcessSubGraphTranspose(std::vector<kernel::KernelExec *> *dst_kernels) {
284 #ifndef ENABLE_MULTI_LAYOUT
285   auto ret = pass::RuntimeFormatPass(dst_kernels, src_tensors_, Format::NHWC);
286   if (ret != RET_OK) {
287     MS_LOG(ERROR) << "Run runtime format pass failed.";
288     return RET_ERROR;
289   }
290 #endif
291   return RET_OK;
292 }
293 
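// Remove QuantDTypeCast kernels (recursing into subgraphs): reconnect their producers and
// consumers, force the affected tensors to float32, and keep graph-output bookkeeping consistent.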
294 STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *kernels) {
295   for (auto iter = (*kernels).begin(); iter != (*kernels).end();) {
296     auto cur_kernel = *iter;
297     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
298       auto sub_inner_graph = reinterpret_cast<kernel::SubGraphKernel *>(cur_kernel);
299       auto &subgraph_nodes = sub_inner_graph->nodes();
300       if (DelQuantDTypeCastKernel(&subgraph_nodes) != RET_OK) {
301         MS_LOG(ERROR) << "DeleteRedundantTrans failed in subgraph.";
302         return RET_ERROR;
303       }
304     }
305     if (cur_kernel->type() != schema::PrimitiveType_QuantDTypeCast) {
306       iter++;
307       continue;
308     }
309     auto &post_kernels = cur_kernel->out_kernels();
310     auto &pre_kernels = cur_kernel->in_kernels();
311     if (cur_kernel->in_tensors().size() != 1) {
312       MS_LOG(ERROR) << cur_kernel->name() << " input size error."
313                     << " cur_kernel in tensors size:" << cur_kernel->in_tensors().size();
314       return RET_ERROR;
315     }
316     bool graph_input = pre_kernels.empty();
317     if (!graph_input) {
318       // modify post kernel input to new kernel and new tensor
319       for (auto post_kernel : post_kernels) {
320         auto post_in_kernels = post_kernel->in_kernels();
321         auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
322         *post_input_iter = pre_kernels[0];
323         post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
324         post_kernel->set_in_kernels(post_in_kernels);
325       }
326       auto pre_out_kernels = pre_kernels[0]->out_kernels();
327       auto pre_out_iter = std::find(pre_out_kernels.begin(), pre_out_kernels.end(), cur_kernel);
328       if (pre_out_iter != pre_out_kernels.end()) {
329         pre_out_kernels.erase(pre_out_iter);
330         pre_out_kernels.insert(pre_out_iter, post_kernels.begin(), post_kernels.end());
331         pre_kernels[0]->set_out_kernels(pre_out_kernels);
332       }
333     } else {
334       for (auto post_kernel : post_kernels) {
335         auto post_in_kernels = post_kernel->in_kernels();
336         auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
337         *post_input_iter = {};
338         post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
339         post_kernel->set_in_kernels(post_in_kernels);
340       }
341     }
342 
343     // update data type
344     for (auto tensor : cur_kernel->in_tensors()) {
345       tensor->set_data_type(kNumberTypeFloat32);
346     }
347     for (auto tensor : cur_kernel->out_tensors()) {
348       tensor->set_data_type(kNumberTypeFloat32);
349     }
350 
351     // update model output kernel & tensor
352     if (cur_kernel->is_model_output()) {
353       pre_kernels[0]->set_is_model_output(true);
354       cur_kernel->in_tensors()[0]->set_category(Category::GRAPH_OUTPUT);
355       pre_kernels[0]->set_out_kernels({});
356       // If the current kernel is the output kernel, use the current output tensor as the output tensor of the previous
357       // node.
358       auto pre_out_tensors = pre_kernels[0]->out_tensors();
359       auto tensor_iter = std::find(pre_out_tensors.begin(), pre_out_tensors.end(), cur_kernel->in_tensors()[0]);
360       if (tensor_iter != pre_kernels[0]->out_tensors().end()) {
361         *tensor_iter = cur_kernel->out_tensors()[0];
362       }
363     }
364 
365     // delete cur kernel
366     iter = kernels->erase(iter);
367     MS_LOG(DEBUG) << "Delete kernel: " << cur_kernel->name();
368     delete cur_kernel;
369   }
370   return RET_OK;
371 }
372 
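// Top-level scheduling entry: infer shapes, map graph nodes to kernels, apply delegate and
// runtime passes, construct subgraphs, and initialize the resulting kernels.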
373 int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
374   MS_LOG(DEBUG) << "Start schedule.";
375   int check_input_ret = CheckInputParam(dst_kernels);
376   if (check_input_ret != RET_OK) {
377     MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
378     return check_input_ret;
379   }
380 
381   shape_fusion_pass_ =
382     std::make_shared<ShapeFusionPass>(context_, reinterpret_cast<LiteModel *>(src_model_), src_tensors_);
383   MS_CHECK_TRUE_RET(shape_fusion_pass_ != nullptr, RET_ERROR);
384   int ret = SchedulePreProcess();
385   if (ret != RET_OK) {
386     return ret;
387   }
388 
389   if (*is_control_flow_) {
390     control_flow_scheduler_ = std::make_shared<ControlFlowScheduler>(context_, ms_context_, src_tensors_);
391     MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "new control scheduler failed.");
392   }
393 
394   ret = ScheduleGraphToKernels(dst_kernels);
395   FreeOpParameters();
396   op_parameters_.clear();
397   if (ret != RET_OK) {
398     MS_LOG(ERROR) << "Schedule graph to kernels failed.";
399     return ret;
400   }
401   if (context_->float_mode) {
402     kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
403     ret = DelQuantDTypeCastKernel(dst_kernels);
404     if (ret != RET_OK) {
405       MS_LOG(ERROR) << "Delete quant_dtype_cast kernel failed.";
406       return ret;
407     }
408   }
409   shape_fusion_pass_->StoreStateAndReset();
410 
411   MS_LOG(DEBUG) << "Start to init delegate kernels.";
412   ret = InitDelegateKernels(dst_kernels);
413   if (ret != RET_OK) {
414     MS_LOG(ERROR) << "Repalce delegate kernels failed.";
415     return ret;
416   }
417   MS_LOG(DEBUG) << "Finish to init delegate kernels.";
418 
419   ret = CheckCpuValid(dst_kernels);
420   if (ret != RET_OK) {
421     MS_LOG(ERROR) << "kernels invalid in set devices.";
422     return ret;
423   }
424 
425   kernel::KernelExecUtil::FindAllInoutKernels(*dst_kernels);
426 
427   ret = ConstructSubGraphs(dst_kernels);
428   if (ret != RET_OK) {
429     MS_LOG(ERROR) << "ConstructSubGraphs failed.";
430     return ret;
431   }
432 
433   ret = ProcessSubGraphTranspose(dst_kernels);
434   if (ret != RET_OK) {
435     MS_LOG(ERROR) << "Process SubGraph with multi layout failed.";
436     return ret;
437   }
438 
439   if (*is_control_flow_) {
440     control_flow_scheduler_->SetSubgraphForPartialNode(&partial_kernel_subgraph_index_map_,
441                                                        &subgraph_index_subgraph_kernel_map_);
442     ret = control_flow_scheduler_->Schedule(dst_kernels);
443     MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "control flow schedule failed.");
444   }
445 
446   auto status = RuntimePass(dst_kernels, src_tensors_);
447   if (status != RET_OK) {
448     MS_LOG(ERROR) << "runtime pass failed.";
449     return RET_ERROR;
450   }
451 
452   ret = InitKernels(std::move(*dst_kernels));
453   if (ret != RET_OK) {
454     MS_LOG(ERROR) << "InitKernels failed.";
455     return ret;
456   }
457   shape_fusion_pass_->RestoreState();
458   if (IsPrintDebug()) {
459     MS_LOG(DEBUG) << "schedule kernels success.";
460     for (auto subgraph : *dst_kernels) {
461       MS_LOG(DEBUG) << "[subgraph] : " << subgraph->name() << ",  type:" << subgraph->subgraph_type();
462       if (subgraph->desc().arch == kernel::KERNEL_ARCH::kDelegate) {
463         continue;
464       }
465       std::vector<kernel::KernelExec *> kernel_list = reinterpret_cast<kernel::SubGraphKernel *>(subgraph)->nodes();
466       for (auto kernel : kernel_list) {
467         MS_LOG(DEBUG) << "kernel: [" << kernel->name() << "] "
468                       << "TypeId(" << kernel->desc().data_type << "); "
469                       << "OpType(" << PrimitiveCurVersionTypeName(kernel->desc().type) << "); "
470                       << "format(" << kernel->desc().format << "); "
471                       << "arch(" << kernel->desc().arch << ")";
472       }
473     }
474   }
475   return RET_OK;
476 }
477 
478 int Scheduler::CheckInputParam(const std::vector<kernel::KernelExec *> *dst_kernels) const {
479   if (dst_kernels == nullptr) {
480     return RET_ERROR;
481   }
482   if (src_model_ == nullptr) {
483     MS_LOG(ERROR) << "Input model is nullptr";
484     return RET_PARAM_INVALID;
485   }
486   if (src_model_->graph_.sub_graphs_.empty()) {
487     MS_LOG(ERROR) << "Model should have a subgraph at least";
488     return RET_PARAM_INVALID;
489   }
490   return RET_OK;
491 }
492 
493 #ifndef DELEGATE_CLIP
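// Hand the scheduled kernels to the delegate: kernels the delegate absorbs are wrapped in new
// KernelExec objects, unsupported ones keep their original backend, and replaced kernels are freed.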
494 int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
495   std::vector<kernel::Kernel *> kernels;
496   for (size_t i = 0; i < dst_kernels->size(); i++) {
497     auto litert_kernel = reinterpret_cast<kernel::Kernel *>((*dst_kernels)[i]->kernel());
498     if (MS_UNLIKELY(litert_kernel == nullptr)) {
499       MS_LOG(ERROR) << "nullptr exist in dst_kernels.";
500       return RET_ERROR;
501     }
502     kernels.push_back(litert_kernel);
503   }
504 
505   ms_inputs_ = LiteTensorsToMSTensors(*inputs_);
506   ms_outputs_ = LiteTensorsToMSTensors(*outputs_);
507   auto schema_version = static_cast<SchemaVersion>(context_->get_schema_version());
508   DelegateModel<schema::Primitive> *model =
509     new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
510   if (model == nullptr) {
511     MS_LOG(ERROR) << "New delegate model failed.";
512     return RET_NULL_PTR;
513   }
514 
515 #ifdef SUPPORT_NNRT
516   if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
517     auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
518     delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
519     void *meta_graph = reinterpret_cast<void *>(
520       const_cast<mindspore::schema::MetaGraph *>(mindspore::schema::GetMetaGraph(this->src_model_->buf)));
521     delegate->SetMetaGraph(meta_graph);
522     delegate->SetDequantTensors(this->src_tensors_);
523   }
524 #endif
525 
526   auto ret = delegate_->Build(model);
527   if (ret != mindspore::kSuccess) {
528     delete model;
529     MS_LOG(ERROR) << "Delegate prepare kernels failed.";
530     return RET_ERROR;
531   }
532 
533   auto src_kernels = *dst_kernels;
534   dst_kernels->clear();
535   std::map<const kernel::KernelExec *, bool> delegate_support;
536   for (auto kernel : src_kernels) {
537     delegate_support[kernel] = true;
538   }
539   for (auto kernel : kernels) {
540     size_t index = 0;
541     for (; index < src_kernels.size(); index++) {
542       if (kernel == src_kernels[index]->kernel()) {
543         // Kernels that the delegate does not support keep the original backend
544         dst_kernels->push_back(src_kernels[index]);
545         delegate_support[src_kernels[index]] = false;
546         break;
547       }
548     }
549     if (index == src_kernels.size()) {
550       // New liteKernel to save delegate subgraph
551       std::shared_ptr<kernel::Kernel> shared_kernel(kernel);
552       auto kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
553       if (kernel_exec == nullptr) {
554         delete model;
555         MS_LOG(ERROR) << "New KernelExec for delegate subgraph failed.";
556         return RET_NULL_PTR;
557       }
558       auto delegate_type = kNumberTypeFloat32;
559       for (auto &input : kernel->inputs()) {
560         if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
561           delegate_type = kNumberTypeFloat16;
562           break;
563         }
564       }
565       kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
566       kernel_exec->set_desc(delegate_desc);
567       dst_kernels->push_back(kernel_exec);
568     }
569   }
570   // Release the cpu kernels that have been replaced by the delegate subgraph, as well as their tensor data
571   for (auto kernel : src_kernels) {
572     if (delegate_support[kernel] == true) {
573       auto inputs = kernel->in_tensors();
574       for (auto *tensor : inputs) {
575         MS_ASSERT(tensor != nullptr);
576         if (tensor->IsConst()) {
577           tensor->FreeData();
578         }
579       }
580       delete kernel;
581     }
582   }
583   delete model;
584   return RET_OK;
585 }
586 
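// Apply the delegate, if any: an external delegate replaces kernels directly, while an inner
// delegate replaces consecutive runs of kernels grouped by device priority.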
587 int Scheduler::InitDelegateKernels(std::vector<kernel::KernelExec *> *dst_kernels) {
588   /* no delegate valid */
589   if (delegate_ == nullptr) {
590     return RET_OK;
591   }
592 
593   /* set delegate spin count */
594   context_->thread_pool_->SetSpinCountMinValue();
595 
596   /* external delegate */
597   if (delegate_device_type_ == -1) {
598     auto ret = ReplaceDelegateKernels(dst_kernels);
599     if (ret != RET_OK) {
600       MS_LOG(ERROR) << "external delegate init failed.";
601       return ret;
602     }
603   }
604 
605   /* Inner delegate  :  check Priority */
606   std::vector<kernel::KernelExec *> src_kernels = *dst_kernels;
607   dst_kernels->clear();
608 
609   while (!src_kernels.empty()) {
610     std::vector<kernel::KernelExec *> tmp_kernels;
611     kernel::KernelExec *remain_kernel = nullptr;
612 
613     /* Loop for inner delegate npu and TensorRT subgraph */
614     while (!src_kernels.empty()) {
615       auto kernel = src_kernels.front();
616       VectorErase(&src_kernels, kernel);
617       bool priority_ret =
618         DeviceTypePriority(context_, delegate_device_type_, KernelArchToDeviceType(kernel->desc().arch));
619       if (priority_ret == true) {
620         tmp_kernels.push_back(kernel);
621       } else {
622         remain_kernel = kernel;
623         break;
624       }
625     }
626 
627     /* start current NPU-kernels replace */
628     if (tmp_kernels.empty()) {
629       if (remain_kernel != nullptr) {
630         dst_kernels->push_back(remain_kernel);
631         remain_kernel = nullptr;
632       }
633       continue;
634     }
635     auto ret = ReplaceDelegateKernels(&tmp_kernels);
636     if (ret != RET_OK) {
637       dst_kernels->insert(dst_kernels->end(), src_kernels.begin(), src_kernels.end());
638       dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
639       if (remain_kernel != nullptr) {
640         dst_kernels->push_back(remain_kernel);
641       }
642       MS_LOG(ERROR) << "Inner delegate replace delegate kernels failed.";
643       return ret;
644     }
645 
646     dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
647     tmp_kernels.clear();
648     if (remain_kernel != nullptr) {
649       dst_kernels->push_back(remain_kernel);
650       remain_kernel = nullptr;
651     }
652   }
653 
654   return RET_OK;
655 }
656 #endif
657 
658 void Scheduler::FindNodeInoutTensors(const lite::LiteGraph::Node &node, std::vector<Tensor *> *inputs,
659                                      std::vector<Tensor *> *outputs) {
660   MS_ASSERT(inputs != nullptr);
661   MS_ASSERT(outputs != nullptr);
662   auto in_size = node.input_indices_.size();
663   inputs->reserve(in_size);
664   for (size_t j = 0; j < in_size; ++j) {
665     inputs->emplace_back(src_tensors_->at(node.input_indices_[j]));
666   }
667   auto out_size = node.output_indices_.size();
668   outputs->reserve(out_size);
669   for (size_t j = 0; j < out_size; ++j) {
670     outputs->emplace_back(src_tensors_->at(node.output_indices_[j]));
671   }
672 }
673 
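// Infer one node's output shapes: try the registered (provider) infer-shape first, otherwise
// populate and cache the OpParameter and call the built-in KernelInferShape.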
674 int Scheduler::InferNodeShape(const lite::LiteGraph::Node *node) {
675   MS_ASSERT(node != nullptr);
676   auto primitive = node->primitive_;
677   MS_ASSERT(primitive != nullptr);
678   std::vector<Tensor *> inputs;
679   std::vector<Tensor *> outputs;
680   FindNodeInoutTensors(*node, &inputs, &outputs);
681   auto ret =
682     KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders(), context_->get_schema_version());
683   if (ret != RET_NOT_SUPPORT) {
684     *infer_along_running_ = false;
685     return ret;
686   }
687 
688   auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(
689     GetPrimitiveType(node->primitive_, context_->get_schema_version()), context_->get_schema_version());
690   if (parame_gen == nullptr) {
691     MS_LOG(ERROR) << "parameter generator is nullptr.";
692     FreeOpParameters();
693     return RET_NULL_PTR;
694   }
695   auto parameter = parame_gen(primitive);
696   if (parameter == nullptr) {
697     MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
698                   << GetPrimitiveTypeName(primitive, context_->get_schema_version());
699     FreeOpParameters();
700     return RET_ERROR;
701   }
702 
703   parameter->quant_type_ = node->quant_type_;
704   parameter->thread_num_ = context_->thread_num_;
705   if (node->output_indices_.empty()) {
706     MS_LOG(ERROR) << "The output size is invalid";
707     if (parameter->destroy_func_ != nullptr) {
708       parameter->destroy_func_(parameter);
709     }
710     free(parameter);
711     return RET_ERROR;
712   }
713   if (op_parameters_.find(node->output_indices_.at(0)) != op_parameters_.end()) {
714     if (parameter->destroy_func_ != nullptr) {
715       parameter->destroy_func_(parameter);
716     }
717     free(parameter);
718     parameter = op_parameters_[node->output_indices_.at(0)];
719   } else {
720     op_parameters_[node->output_indices_.at(0)] = parameter;
721   }
722 
723   if (IsCallNode(primitive, context_->get_schema_version())) {
724     return InferCallShape(node);
725   }
726   ret = KernelInferShape(inputs, outputs, parameter, context_->allocator);
727   if (ret != RET_OK && ret != RET_INFER_INVALID) {
728     FreeOpParameters();
729     return RET_ERROR;
730   }
731   for (auto &output : outputs) {
732     output->set_shape_changed(false);
733   }
734   if (*is_control_flow_) {
735     for (auto &output : outputs) {
736       output->set_shape({-1});
737     }
738     return RET_INFER_INVALID;
739   }
740 
741   if (ret == RET_OK) {
742     for (auto &output : outputs) {
743       if (static_cast<size_t>(output->ElementsNum()) >= GetMaxMallocSize() / sizeof(int64_t)) {
744         MS_LOG(ERROR) << "The size of output tensor is too big";
745         FreeOpParameters();
746         return RET_ERROR;
747       }
748     }
749   } else if (ret != RET_INFER_INVALID) {
750     FreeOpParameters();
751     return RET_ERROR;
752   }
753   return ret;
754 }
755 
756 void Scheduler::FreeOpParameters() {
757   for (auto &param : op_parameters_) {
758     if (param.second != nullptr) {
759       if (param.second->destroy_func_ != nullptr) {
760         param.second->destroy_func_(param.second);
761       }
762       free(param.second);
763       param.second = nullptr;
764     }
765   }
766 }
767 
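// Undo CopyPartialShapeToSubGraph: detach the borrowed data pointers from the subgraph inputs.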
768 int Scheduler::RestoreSubGraphInput(const lite::LiteGraph::Node *partial_node) {
769   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
770   MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
771   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
772   for (size_t i = 0; i < subgraph->input_indices_.size(); ++i) {
773     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
774     subgraph_input->set_data(nullptr);
775   }
776   return RET_OK;
777 }
778 
779 void CopyCommonTensor(Tensor *dst_tensor, Tensor *src_tensor) {
780   dst_tensor->set_data_type(src_tensor->data_type());
781   dst_tensor->set_shape(src_tensor->shape());
782   dst_tensor->set_format(src_tensor->format());
783   dst_tensor->set_data(src_tensor->data());
784 }
785 
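// Propagate the partial node's input type/shape/format/data to the corresponding subgraph inputs
// so the subgraph can be shape-inferred; tensorlist inputs are not supported here.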
786 int Scheduler::CopyPartialShapeToSubGraph(const lite::LiteGraph::Node *partial_node) {
787   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, context_->get_schema_version());
788   MS_CHECK_TRUE_MSG(subgraph_index >= 0, RET_NULL_PTR, "subgraph index is negative.");
789   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
790   if (subgraph->input_indices_.size() != partial_node->input_indices_.size()) {
791     MS_LOG(ERROR) << "partial node " << partial_node->name_ << " inputs size: " << partial_node->input_indices_.size()
792                   << " vs "
793                   << " subgraph input size: " << subgraph->input_indices_.size();
794     return RET_PARAM_INVALID;
795   }
796 
797   for (size_t i = 0; i < partial_node->input_indices_.size(); ++i) {
798     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
799     auto &partial_input = src_tensors_->at(partial_node->input_indices_[i]);
800     if (partial_input->data_type() == kObjectTypeTensorType) {
801       return RET_INFER_INVALID;
802     }
803     CopyCommonTensor(subgraph_input, partial_input);
804   }
805 
806   return RET_OK;
807 }
808 
809 int Scheduler::InferPartialShape(const lite::LiteGraph::Node *node) {
810   MS_ASSERT(src_model_ != nullptr);
811   MS_ASSERT(node != nullptr);
812   if (!IsPartialNode(node->primitive_, context_->get_schema_version())) {
813     MS_LOG(ERROR) << "Node is not a partial";
814     return RET_PARAM_INVALID;
815   }
816   CopyPartialShapeToSubGraph(node);
817   int subgraph_index = GetPartialGraphIndex(node->primitive_, context_->get_schema_version());
818   auto ret = InferSubGraphShape(subgraph_index);
819   if (ret != RET_OK) {
820     MS_LOG(WARNING) << "infer subgraph: " << subgraph_index << " failed, ret:" << ret;
821   }
822   RestoreSubGraphInput(node);
823   return ret;
824 }
825 
826 LiteGraph::Node *Scheduler::NodeInputIsPartial(const lite::LiteGraph::Node *node) {
827   MS_ASSERT(src_model_ != nullptr);
828   MS_ASSERT(node != nullptr);
829   for (auto &iter : src_model_->graph_.all_nodes_) {
830     if (iter->output_indices_ == node->input_indices_) {
831       if (IsPartialNode(iter->primitive_, context_->get_schema_version())) {
832         return iter;
833       } else {
834         return nullptr;
835       }
836     }
837   }
838   return nullptr;
839 }
840 
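// A call node takes its shape from the partial or switch node feeding it; a switch input also
// marks the model as control flow.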
841 int Scheduler::InferCallShape(const lite::LiteGraph::Node *node) {
842   MS_ASSERT(src_model_ != nullptr);
843   MS_ASSERT(node != nullptr);
844   if (!IsCallNode(node->primitive_, context_->get_schema_version())) {
845     MS_LOG(ERROR) << "Node is not a call cnode";
846     return RET_PARAM_INVALID;
847   }
848 
849   auto partial_input = NodeInputIsPartial(node);
850   if (partial_input) {
851     return InferPartialShape(partial_input);
852   }
853   auto switch_input = NodeInputIsSwitchType(node);
854   if (switch_input) {
855     *is_control_flow_ = true;
856     return InferSwitchShape(switch_input);
857   }
858 
859   MS_LOG(ERROR) << "call input is not partial and also not switch.";
860   return RET_ERROR;
861 }
862 
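// Infer shapes for every node of a subgraph (each subgraph only once); Shape nodes are folded by
// the shape-fusion pass first, and RET_INFER_INVALID is propagated when any node defers.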
863 int Scheduler::InferSubGraphShape(size_t subgraph_index) {
864   MS_ASSERT(src_model_ != nullptr);
865   MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
866   MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
867   if (find(infer_subgraph_index_.begin(), infer_subgraph_index_.end(), subgraph_index) != infer_subgraph_index_.end()) {
868     MS_LOG(ERROR) << "The subgraph has been infer shape, subgraph index: " << subgraph_index;
869     return RET_INFER_INVALID;
870   }
871   infer_subgraph_index_.push_back(subgraph_index);
872   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
873   int subgraph_infershape_ret = RET_OK;
874   auto node_indexes = subgraph->node_indices_;
875   for (size_t i = 0; i < node_indexes.size(); ++i) {
876     auto node_index = node_indexes[i];
877     auto node = src_model_->graph_.all_nodes_[node_index];
878     MS_ASSERT(node != nullptr);
879     auto *primitive = node->primitive_;
880     if (primitive == nullptr) {
881       MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
882       return RET_ERROR;
883     }
884     if (node->node_type_ == schema::PrimitiveType_Shape) {
885       // convert shape to built-in shape
886       MS_CHECK_TRUE_RET(node->input_indices_.size() == 1, RET_ERROR);
887       shape_fusion_pass_->Run(node, subgraph_index);
888       node_indexes = subgraph->node_indices_;
889     }
890     auto ret = InferNodeShape(node);
891     if (ret == RET_INFER_INVALID) {
892       MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_
893                    << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version())
894                    << ", set infer flag to false.";
895       subgraph_infershape_ret = RET_INFER_INVALID;
896     } else if (ret != RET_OK) {
897       FreeOpParameters();
898       MS_LOG(ERROR) << "InferShape failed, name: " << node->name_
899                     << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
900       return RET_INFER_ERR;
901     }
902   }
903   return subgraph_infershape_ret;
904 }
905 
906 namespace {
907 // support_fp16: current device and package support float16
908 int CastAndRestoreConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors,
909                                   TypeId dst_data_type, bool support_fp16) {
910   MS_ASSERT(tensor != nullptr);
911   MS_ASSERT(tensor->IsConst());
912   MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
913   MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
914   if (tensor->data_type() == dst_data_type) {
915     return RET_OK;
916   }
917   auto origin_data = tensor->data();
918   MS_ASSERT(origin_data != nullptr);
919   auto restore_tensor = Tensor::CopyTensor(*tensor, false);
920   if (restore_tensor == nullptr) {
921     return RET_NULL_PTR;
922   }
923   restore_tensor->set_data(origin_data);
924   restore_tensor->set_own_data(tensor->own_data());
925   tensor->set_data(nullptr);
926   tensor->set_data_type(dst_data_type);
927   auto ret = tensor->MallocData();
928   if (RET_OK != ret) {
929     MS_LOG(ERROR) << "malloc data failed";
930     return ret;
931   }
932   auto new_tensor_data = tensor->data();
933   MS_ASSERT(new_tensor_data != nullptr);
934   if (dst_data_type == kNumberTypeFloat32) {
935     Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
936   } else {  // dst_data_type == kNumberTypeFloat16
937     Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
938   }
939   if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
940     MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
941     delete restore_tensor;
942     return RET_ERROR;
943   }
944   (*restored_origin_tensors)[tensor] = restore_tensor;
945   return RET_OK;
946 }
947 
948 // support_fp16: current device and package support float16
949 int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
950                          TypeId dst_data_type, bool support_fp16) {
951   MS_ASSERT(restored_origin_tensors != nullptr);
952   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
953     MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";
954     return RET_PARAM_INVALID;
955   }
956   for (auto *tensor : tensors) {
957     MS_ASSERT(tensor != nullptr);
958     // only cast const tensor
959     // tensorlist not support fp16 now
960     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
961       continue;
962     }
963     // only support fp32->fp16 or fp16->fp32
964     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
965       continue;
966     }
967     if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
968       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16, support_fp16);
969       if (ret != RET_OK) {
970         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
971         return ret;
972       }
973     } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
974       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32, support_fp16);
975       if (ret != RET_OK) {
976         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
977         return ret;
978       }
979     } else {
980       MS_LOG(DEBUG) << "No need to cast from " << tensor->data_type() << " to " << dst_data_type;
981     }
982   }
983   return RET_OK;
984 }
985 
986 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
987   MS_ASSERT(restored_origin_tensors != nullptr);
988   for (auto &restored_origin_tensor : *restored_origin_tensors) {
989     restored_origin_tensor.second->set_data(nullptr);
990     delete (restored_origin_tensor.second);
991     restored_origin_tensor.second = nullptr;
992   }
993   restored_origin_tensors->clear();
994 }
995 
996 inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
997   MS_ASSERT(restored_origin_tensors != nullptr);
998   for (auto &restored_origin_tensor : *restored_origin_tensors) {
999     auto *origin_tensor = restored_origin_tensor.first;
1000     auto *restored_tensor = restored_origin_tensor.second;
1001     MS_ASSERT(origin_tensor != nullptr);
1002     MS_ASSERT(restored_tensor != nullptr);
1003     origin_tensor->FreeData();
1004     origin_tensor->set_data_type(restored_tensor->data_type());
1005     origin_tensor->set_data(restored_tensor->data());
1006     origin_tensor->set_own_data(restored_tensor->own_data());
1007   }
1008   FreeRestoreTensors(restored_origin_tensors);
1009 }
1010 }  // namespace
1011 
1012 void Scheduler::ResetByExecutionPlan(std::string node_name, TypeId *data_type) {
1013   if (execution_plan_ == nullptr) {
1014     return;
1015   }
1016   auto iter = execution_plan_->find(node_name);
1017   if (iter != execution_plan_->end()) {
1018     *data_type = iter->second;
1019   }
1020   return;
1021 }
1022 
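// Look up a built-in CPU kernel: adjust the desc for fp16 if requested, dequantize weights,
// cast const tensors for train sessions, then query the kernel registry.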
1023 int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1024                              OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
1025                              kernel::KernelExec **kernel) {
1026   MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr.");
1027   auto op_type = op_parameter->type_;
1028   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1029     MS_LOG(INFO) << "unsupported op_type: " << PrimitiveCurVersionTypeName(op_type)
1030                  << ", data_type: " << desc.data_type;
1031     return RET_NOT_SUPPORT;
1032   }
1033   kernel::KernelKey cpu_desc = desc;
1034   if (kernel_data_type == kNumberTypeFloat16) {
1035     if (!context_->IsCpuFloat16Enabled() ||
1036         (cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
1037       return RET_NOT_SUPPORT;
1038     }
1039     cpu_desc.data_type = kNumberTypeFloat16;
1040   }
1041   auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->graph_.version_,
1042                                         context_->float_mode);
1043   if (ret != RET_OK) {
1044     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1045     return RET_NOT_SUPPORT;
1046   }
1047   std::map<Tensor *, Tensor *> restored_origin_tensors;
1048 
1049   if (is_train_session_) {
1050     ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type,
1051                                context_->device_and_pkg_support_fp16_);
1052     if (ret != RET_OK) {
1053       MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
1054       return RET_NOT_SUPPORT;
1055     }
1056   }
1057 
1058 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
1059   // Reset op task num: the number of operator segmentation tasks is not necessarily equal to the number of threads
1060   int thread_num_limit = ParallelThreadPoolManager::GetInstance()->GetTaskNum(config_info_);
1061   if (thread_num_limit != -1 && IsSharedThreadPoolOp(op_type)) {
1062     op_parameter->thread_num_ = thread_num_limit;
1063   }
1064 #endif
1065 
1066   ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, cpu_desc,
1067                                                      op_parameter, kernel);
1068   if (ret == RET_OK) {
1069     MS_LOG(DEBUG) << "Get TypeId(expect = " << kernel_data_type << ", real = " << cpu_desc.data_type
1070                   << ") op success: " << PrimitiveCurVersionTypeName(op_type);
1071     if (is_train_session_) {
1072       ret = (*kernel)->Prepare();
1073       RestoreTensorData(&restored_origin_tensors);
1074     }
1075   }
1076   return ret;
1077 }
1078 
1079 #ifdef GPU_OPENCL
1080 int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1081                              OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::KernelExec **kernel,
1082                              TypeId prefer_data_type) {
1083   MS_ASSERT(op_parameter != nullptr);
1084   MS_ASSERT(kernel != nullptr);
1085   if (!context_->IsDeviceTypeEnabled(DT_GPU)) {
1086     return RET_NOT_SUPPORT;
1087   }
1088 
1089   // support more data type like int32
1090   kernel::KernelKey gpu_desc{kernel::KERNEL_ARCH::kGPU, desc.data_type, NHWC, desc.type};
1091   if (desc.data_type == kNumberTypeFloat32 && context_->IsGpuFloat16Enabled()) {
1092     gpu_desc.data_type = kNumberTypeFloat16;
1093   }
1094   if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kNumberTypeFloat32) {
1095     gpu_desc.data_type = prefer_data_type;
1096   }
1097   // weight dequant
1098   auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32, src_model_->graph_.version_,
1099                                         context_->float_mode);
1100   if (ret != RET_OK) {
1101     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
1102     return RET_NOT_SUPPORT;
1103   }
1104   // we don't need to restore tensor for copy data
1105   ret = CopyConstTensorData(in_tensors, op_parameter->type_);
1106   if (ret != RET_OK) {
1107     MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
1108     return RET_NOT_SUPPORT;
1109   }
1110   ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, gpu_desc,
1111                                                      op_parameter, kernel);
1112   if (ret == RET_OK) {
1113     MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1114   } else {
1115     MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
1116   }
1117   return ret;
1118 }
1119 #endif
1120 
1121 int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
1122                                   const LiteGraph::Node *node, TypeId data_type, kernel::KernelExec **kernel) {
1123 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
1124   MS_ASSERT(kernel != nullptr);
1125   int ret = RET_NOT_SUPPORT;
1126   auto prim_type = GetPrimitiveType(node->primitive_, context_->get_schema_version());
1127   if (prim_type == schema::PrimitiveType_Custom) {
1128     for (auto &&device : context_->device_list_) {
1129       if (!device.provider_.empty() && !device.provider_device_.empty()) {
1130         kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type,       NHWC, prim_type,
1131                                device.provider_device_,   device.provider_};
1132         ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc,
1133                                                            nullptr, kernel, node->primitive_);
1134         if (ret == RET_OK && *kernel != nullptr) {
1135           return ret;
1136         }
1137       }
1138     }
1139 
1140     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, prim_type, "", ""};
1141     ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1142                                                        kernel, node->primitive_);
1143     if (ret == RET_OK && *kernel != nullptr) {
1144       return ret;
1145     }
1146     return RET_NOT_SUPPORT;
1147   }
1148   if (!context_->IsProviderEnabled()) {
1149     return ret;
1150   }
1151   if (context_->get_schema_version() == SCHEMA_V0) {
1152     return ret;
1153   }
1154   for (auto &&device : context_->device_list_) {
1155     if (!device.provider_.empty()) {
1156       kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type,       NHWC, prim_type,
1157                              device.provider_device_,   device.provider_};
1158       ret = KernelRegistry::GetInstance()->GetKernelExec(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
1159                                                          kernel, node->primitive_);
1160       if (ret == RET_OK && *kernel != nullptr) {
1161         return ret;
1162       }
1163     }
1164   }
1165 #endif
1166   return RET_NOT_SUPPORT;
1167 }
1168 
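// Pick a backend kernel for one node: provider kernels first, then GPU, then fp16 CPU, and
// finally fp32 CPU; a failed attempt that freed the OpParameter re-runs InferNodeShape.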
1169 kernel::KernelExec *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
1170                                                  const std::vector<Tensor *> &out_tensors, const LiteGraph::Node *node,
1171                                                  TypeId prefer_data_type) {
1172   MS_ASSERT(node != nullptr);
1173   // why we need this
1174   TypeId data_type;
1175   if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1176     if (in_tensors.front()->data_type() == kNumberTypeBool) {
1177       data_type = kNumberTypeBool;
1178     } else {
1179       data_type = kNumberTypeFloat32;
1180     }
1181   } else {
1182     data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
1183     if (data_type == kTypeUnknown) {
1184       MS_LOG(ERROR) << "GetFirstFp32Fp16OrInt8Type is unknown.";
1185       return nullptr;
1186     }
1187   }
1188   if (context_->float_mode) {
1189     for (auto tensor : out_tensors) {
1190       if (!tensor->quant_params().empty() &&
1191           (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeUInt8)) {
1192         data_type = kNumberTypeFloat32;
1193         tensor->set_data_type(kNumberTypeFloat32);
1194       }
1195     }
1196   }
1197   kernel::KernelExec *kernel = nullptr;
1198   auto status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
1199   if (status == RET_OK && kernel != nullptr) {
1200     return kernel;
1201   }
1202   MS_ASSERT(!node->output_indices_.empty());
1203   OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1204   if (op_parameter == nullptr) {
1205     MS_LOG(ERROR) << "Can not find OpParameter!type: "
1206                   << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1207     return nullptr;
1208   }
1209 
1210 #if (defined GPU_OPENCL) || (defined ENABLE_FP16)
1211   int kernel_thread_count = op_parameter->thread_num_;
1212 #endif
1213   op_parameter->is_train_session_ = is_train_session_;
1214   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, op_parameter->type_};
1215 
1216 #ifdef GPU_OPENCL
1217   bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
1218   bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
1219   if (gpu_priority && use_gpu_kernel) {
1220     status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel, prefer_data_type);
1221     if (status == RET_OK) {
1222       return kernel;
1223     } else {
1224       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1225                     << node->name_;
1226       if (status == RET_ERROR) {
1227         op_parameters_.erase(node->output_indices_.at(0));
1228         auto ret = InferNodeShape(node);
1229         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1230           op_parameter = op_parameters_[node->output_indices_.at(0)];
1231           op_parameter->thread_num_ = kernel_thread_count;
1232         } else {
1233           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1234           return nullptr;
1235         }
1236       }
1237     }
1238   }
1239 #endif
1240 #ifdef ENABLE_FP16
1241   if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
1242       ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
1243     status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
1244     if (status == RET_OK) {
1245       return kernel;
1246     } else {
1247       MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1248                     << node->name_;
1249       if (status == RET_ERROR) {
1250         op_parameters_.erase(node->output_indices_.at(0));
1251         auto ret = InferNodeShape(node);
1252         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1253           op_parameter = op_parameters_[node->output_indices_.at(0)];
1254           op_parameter->thread_num_ = kernel_thread_count;
1255         } else {
1256           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1257           return nullptr;
1258         }
1259       }
1260     }
1261   }
1262 #endif
1263   if (data_type == kNumberTypeFloat16) {
1264     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
1265     desc.data_type = kNumberTypeFloat32;
1266   }
1267   status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
1268   if (status == RET_OK) {
1269     return kernel;
1270   } else if (status == RET_ERROR) {
1271     op_parameters_.erase(node->output_indices_.at(0));
1272     auto ret = InferNodeShape(node);
1273     if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
1274       MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1275     }
1276   }
1277 #ifdef OP_INT8_CLIP
1278   if (desc.data_type == kNumberTypeInt8) {
1279     MS_LOG(ERROR) << unsupport_int8_log;
1280   }
1281 #endif
1282   return nullptr;
1283 }
1284 
1285 namespace {
1286 kernel::SubGraphType GetKernelSubGraphType(const kernel::KernelExec *kernel, const InnerContext &context,
1287                                            bool is_controlflow = false) {
1288   if (kernel == nullptr) {
1289     return kernel::kNotSubGraph;
1290   }
1291 
1292   auto desc = kernel->desc();
1293   if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
1294     if (desc.data_type == kNumberTypeFloat16) {
1295       return kernel::kGpuFp16SubGraph;
1296     } else {
1297       return kernel::kGpuFp32SubGraph;
1298     }
1299   } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) {
1300     return kernel::kNpuSubGraph;
1301   } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) {
1302     return kernel::kApuSubGraph;
1303   } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) {
1304     if (desc.data_type == kNumberTypeFloat16) {
1305       return kernel::kCpuFP16SubGraph;
1306     } else {
1307       return kernel::kCpuFP32SubGraph;
1308     }
1309   } else if (desc.arch == kernel::KERNEL_ARCH::kCustom) {
1310     return kernel::kCustomSubGraph;
1311   }
1312   return kernel::kNotSubGraph;
1313 }
1314 }  // namespace
1315 
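// Schedules the subgraph referenced by a partial node into a single subgraph kernel.
// Returns nullptr if the node is not a partial node, if the target subgraph has already been
// scheduled, or if scheduling fails.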
1316 kernel::KernelExec *Scheduler::SchedulePartialToKernel(const lite::LiteGraph::Node *src_node) {
1317   MS_ASSERT(src_model_ != nullptr);
1318   MS_ASSERT(src_node != nullptr);
1319   auto *primitive = src_node->primitive_;
1320   MS_ASSERT(primitive != nullptr);
1321   if (!IsPartialNode(primitive, context_->get_schema_version())) {
1322     return nullptr;
1323   }
1324   auto subgraph_index = GetPartialGraphIndex(src_node->primitive_, context_->get_schema_version());
1325   if (SubGraphHasScheduled(subgraph_index)) {
1326     MS_LOG(INFO) << "Subgraph has been scheduled.";
1327     return {};
1328   } else {
1329     SubGraphMarkScheduled(subgraph_index);
1330   }
1331   auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1332   if (subgraph_kernel == nullptr) {
1333     MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1334     return {};
1335   }
1336   subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1337   return subgraph_kernel;
1338 }
1339 
1340 #ifdef ENABLE_FP16
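// Decides whether a whole subgraph should prefer fp16: every node must have a registered
// fp16 CPU kernel, must not be weight-quantized, and must take fp32/fp16 inputs; otherwise
// the subgraph falls back to fp32.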
1341 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
1342   if (!context_->IsCpuFloat16Enabled() || context_->GetDelegateMode() == kNNAPI) {
1343     *prefer_data_type = kNumberTypeFloat32;
1344     return RET_OK;
1345   }
1346 
1347   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1348   for (auto node_index : subgraph->node_indices_) {
1349     auto node = src_model_->graph_.all_nodes_[node_index];
1350     MS_ASSERT(node != nullptr);
1351     MS_ASSERT(!node->output_indices_.empty());
1352     OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1353     if (op_parameter == nullptr) {
1354       MS_LOG(ERROR) << "Cannot find OpParameter! type: "
1355                     << GetPrimitiveTypeName(node->primitive_, context_->get_schema_version());
1356       return RET_ERROR;
1357     }
1358     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16, NHWC, op_parameter->type_};
1359     if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1360       *prefer_data_type = kNumberTypeFloat32;
1361       return RET_OK;
1362     }
1363 
1364     std::vector<Tensor *> inputs;
1365     std::vector<Tensor *> outputs;
1366     FindNodeInoutTensors(*node, &inputs, &outputs);
1367     if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1368       *prefer_data_type = kNumberTypeFloat32;
1369       return RET_OK;
1370     }
1371     TypeId data_type = GetFirstFp32Fp16OrInt8Type(inputs);
1372     if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
1373       *prefer_data_type = kNumberTypeFloat32;
1374       return RET_OK;
1375     }
1376   }
1377   *prefer_data_type = kNumberTypeFloat16;
1378   return RET_OK;
1379 }
1380 #endif
1381 
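// Schedules the main subgraph; when NNAPI delegation is enabled the preferred data type is
// pinned to fp32. On failure, all partially created kernels are released.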
1382 std::vector<kernel::KernelExec *> Scheduler::ScheduleMainSubGraphToKernels() {
1383   std::vector<kernel::KernelExec *> kernels;
1384   std::vector<lite::Tensor *> in_tensors;
1385   std::vector<lite::Tensor *> out_tensors;
1386   TypeId prefer_data_type = context_->GetDelegateMode() == kNNAPI ? kNumberTypeFloat32 : kTypeUnknown;
1387   auto ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1388   if (ret != RET_OK) {
1389     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << kMainSubGraphIndex;
1390     for (auto *kernel : kernels) {
1391       delete kernel;
1392       kernel = nullptr;
1393     }
1394     return {};
1395   }
1396   return kernels;
1397 }
1398 
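// Schedules one non-main subgraph into a single SubGraphKernel. The subgraph type is taken
// from the first scheduled kernel; fp16 preference is decided per subgraph when ENABLE_FP16 is set.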
1399 kernel::KernelExec *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
1400   TypeId prefer_data_type = kTypeUnknown;
1401 #ifdef ENABLE_FP16
1402   if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
1403     MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
1404     return nullptr;
1405   }
1406 #endif
1407   std::vector<kernel::KernelExec *> kernels;
1408   std::vector<lite::Tensor *> in_tensors;
1409   std::vector<lite::Tensor *> out_tensors;
1410   auto ret = ScheduleSubGraphToKernels(subgraph_index, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1411   if (ret != RET_OK) {
1412     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << subgraph_index;
1413     return nullptr;
1414   }
1415   kernel::KernelExecUtil::FindAllInoutKernels(kernels);
1416   kernel::SubGraphType cur_sub_graph_type = kernel::kCpuFP32SubGraph;
1417   if (!kernels.empty()) {
1418     cur_sub_graph_type = GetKernelSubGraphType(kernels.front(), *context_, true);
1419   }
1420   MS_LOG(INFO) << "cur_sub_graph_type: " << cur_sub_graph_type;
1421   auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1422     kernels, &in_tensors, &out_tensors, cur_sub_graph_type, *context_, context_->get_schema_version());
1423   if (subgraph_kernel == nullptr) {
1424     MS_LOG(ERROR) << "CreateSubGraphKernel failed, cur_sub_graph_type: " << cur_sub_graph_type;
1425     return nullptr;
1426   }
1427   return subgraph_kernel;
1428 }
1429 
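// Dispatches a subgraph index to the matching scheduling path; for non-main subgraphs the
// resulting subgraph kernel is recorded in subgraph_index_subgraph_kernel_map_.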
1430 std::vector<kernel::KernelExec *> Scheduler::ScheduleSubGraphToSubGraphKernels(const int &subgraph_index) {
1431   if (subgraph_index == kMainSubGraphIndex) {
1432     return ScheduleMainSubGraphToKernels();
1433   }
1434   auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1435   if (subgraph_kernel == nullptr) {
1436     MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1437     return {};
1438   }
1439   subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1440   subgraph_index_subgraph_kernel_map_[subgraph_index] = subgraph_kernel;
1441   return {subgraph_kernel};
1442 }
1443 
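// Schedules a single node: applies any per-node execution-plan override to the preferred data
// type, finds the kernel either through the abstract base model (non-MSLite models) or through
// FindBackendKernel, then fixes up tensor data types, the kernel name, and the kernel config.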
1444 kernel::KernelExec *Scheduler::ScheduleNodeToKernel(const lite::LiteGraph::Node *src_node, TypeId prefer_data_type) {
1445   std::vector<Tensor *> inputs;
1446   std::vector<Tensor *> outputs;
1447   MS_ASSERT(src_node != nullptr);
1448   FindNodeInoutTensors(*src_node, &inputs, &outputs);
1449 
1450   ResetByExecutionPlan(src_node->name_, &prefer_data_type);
1451 
1452   mindspore::kernel::KernelExec *kernel = nullptr;
1453   if (src_model_->model_type_ != mindspore::lite::ModelType_MSLite) {
1454     auto abstract_model_ptr = reinterpret_cast<AbstractBaseModel *>(src_model_);
1455     if (abstract_model_ptr == nullptr) {
1456       MS_LOG(ERROR) << "src model is not an abstract base model, return nullptr.";
1457       return nullptr;
1458     }
1459     kernel = abstract_model_ptr->FindBackendKernel(inputs, outputs, src_node, context_, prefer_data_type);
1460     if (kernel == nullptr) {
1461       MS_LOG(ERROR) << "FindBackendKernel returned nullptr, name: " << src_node->name_;
1462       return nullptr;
1463     }
1464   } else {
1465     kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type);
1466     if (kernel == nullptr) {
1467       MS_LOG(ERROR) << "FindBackendKernel returned nullptr, name: " << src_node->name_
1468                     << ", type: " << GetPrimitiveTypeName(src_node->primitive_, context_->get_schema_version());
1469       return nullptr;
1470     }
1471   }
1472   op_parameters_[src_node->output_indices_.at(0)] = nullptr;
1473   auto ret = kernel::KernelExecUtil::SetKernelTensorDataType(kernel);
1474   if (ret != RET_OK) {
1475     MS_LOG(ERROR) << "Set tensor data type failed for kernel " << kernel->name();
1476     delete kernel;
1477     return nullptr;
1478   }
1479   kernel->set_name(src_node->name_);
1480   if (kernel->kernel() != nullptr) {
1481     kernel->kernel()->SetConfig(config_info_);
1482   }
1483   return kernel;
1484 }
1485 
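// A partial node is part of a control-flow pattern when its output feeds a call, switch,
// or switch-layer node.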
1486 bool Scheduler::IsControlFlowPattern(const lite::LiteGraph::Node &partial_node) {
1487   lite::LiteGraph::Node *partial_node_output = nullptr;
1488   for (auto output_index : partial_node.output_indices_) {
1489     for (auto &node : src_model_->graph_.all_nodes_) {
1490       if (IsContain(node->input_indices_, output_index)) {
1491         partial_node_output = node;
1492         break;
1493       }
1494     }
1495   }
1496 
1497   return partial_node_output != nullptr &&
1498          (IsCallNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1499           IsSwitchNode(partial_node_output->primitive_, context_->get_schema_version()) ||
1500           IsSwitchLayerNode(partial_node_output->primitive_, context_->get_schema_version()));
1501 }
1502 
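// Breadth-first scheduling over subgraphs: starts from the main subgraph and keeps draining
// subgraphs_to_schedule_, to which control-flow partial nodes append their target subgraphs.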
1503 int Scheduler::ScheduleGraphToKernels(std::vector<kernel::KernelExec *> *dst_kernels, TypeId prefer_data_type) {
1504   subgraphs_to_schedule_.push_back(kMainSubGraphIndex);
1505   while (!subgraphs_to_schedule_.empty()) {
1506     auto cur_subgraph_index = subgraphs_to_schedule_.front();
1507     subgraphs_to_schedule_.pop_front();
1508     auto kernels = ScheduleSubGraphToSubGraphKernels(cur_subgraph_index);
1509     if (kernels.empty()) {
1510       MS_LOG(ERROR) << "ScheduleSubGraphToSubGraphKernels failed, subgraph index: " << cur_subgraph_index;
1511       return RET_ERROR;
1512     }
1513     std::copy(kernels.begin(), kernels.end(), std::back_inserter(*dst_kernels));
1514   }
1515   return RET_OK;
1516 }
1517 
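// Schedules every node of one subgraph. Partial nodes in control-flow patterns are scheduled
// as ordinary kernels and their target subgraphs are queued for later scheduling; other partial
// nodes are expanded in place via SchedulePartialToKernel.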
1518 int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::KernelExec *> *dst_kernels,
1519                                          std::vector<lite::Tensor *> *in_tensors,
1520                                          std::vector<lite::Tensor *> *out_tensors, TypeId prefer_data_type) {
1521   MS_ASSERT(src_model_ != nullptr);
1522   MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
1523   MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
1524   MS_ASSERT(dst_kernels != nullptr);
1525   MS_ASSERT(dst_kernels->empty());
1526   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1527   auto ret = RET_OK;
1528   for (auto node_index : subgraph->node_indices_) {
1529     auto node = src_model_->graph_.all_nodes_[node_index];
1530     MS_ASSERT(node != nullptr);
1531     auto *primitive = node->primitive_;
1532     MS_ASSERT(primitive != nullptr);
1533     kernel::KernelExec *kernel = nullptr;
1534 
1535     if (src_model_->model_type_ == ModelType_MSLite && IsPartialNode(primitive, context_->get_schema_version())) {
1536       if (IsControlFlowPattern(*node)) {
1537         kernel = ScheduleNodeToKernel(node, prefer_data_type);
1538         auto partial_subgraph_index = GetPartialGraphIndex(primitive, context_->get_schema_version());
1539         MS_CHECK_TRUE_MSG(control_flow_scheduler_ != nullptr, RET_ERROR, "control flow scheduler is nullptr.");
1540         control_flow_scheduler_->RecordSubgraphCaller(partial_subgraph_index, kernel);
1541         if (SubGraphHasScheduled(partial_subgraph_index)) {
1542           partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1543           MS_LOG(INFO) << "Subgraph has been scheduled.";
1544         } else {
1545           SubGraphMarkScheduled(partial_subgraph_index);
1546           partial_kernel_subgraph_index_map_[kernel] = static_cast<size_t>(partial_subgraph_index);
1547           subgraphs_to_schedule_.push_back(partial_subgraph_index);
1548         }
1549       } else {
1550         MS_CHECK_TRUE_MSG(
1551           subgraph_index != static_cast<size_t>(GetPartialGraphIndex(node->primitive_, context_->get_schema_version())),
1552           RET_ERROR, "Unreasonable cycles exist in subgraph.");
1553         kernel = SchedulePartialToKernel(node);
1554       }
1555     } else {
1556       kernel = ScheduleNodeToKernel(node, prefer_data_type);
1557     }
1558     if (kernel == nullptr || ret != RET_OK) {
1559       MS_LOG(ERROR) << "Schedule node returned nullptr, name: " << node->name_
1560                     << ", type: " << GetPrimitiveTypeName(primitive, context_->get_schema_version());
1561       return RET_ERROR;
1562     }
1563     kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index)));
1564     dst_kernels->emplace_back(kernel);
1565     auto litert_kernel = reinterpret_cast<kernel::Kernel *>(kernel->kernel());
1566     if (MS_UNLIKELY(litert_kernel == nullptr)) {
1567       MS_LOG(ERROR) << "nullptr exists in scheduler.";
1568       return RET_ERROR;
1569     }
1570     primitives_.emplace(litert_kernel, static_cast<const schema::Primitive *>(primitive));
1571   }
1572   if (in_tensors != nullptr) {
1573     std::transform(subgraph->input_indices_.begin(), subgraph->input_indices_.end(), std::back_inserter(*in_tensors),
1574                    [&](const uint32_t index) { return this->src_tensors_->at(index); });
1575   }
1576   if (out_tensors != nullptr) {
1577     std::transform(subgraph->output_indices_.begin(), subgraph->output_indices_.end(), std::back_inserter(*out_tensors),
1578                    [&](const uint32_t index) { return this->src_tensors_->at(index); });
1579   }
1580   return RET_OK;
1581 }
1582 
1583 namespace {
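// Helpers for grouping consecutive kernels: a kernel "fits" the current subgraph only when
// its arch and data type are compatible with the subgraph type opened by the first kernel.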
1584 bool KernelFitCurrentSubGraphCPUFp32(TypeId data_type) {
1585   return (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat || data_type == kNumberTypeInt8 ||
1586           data_type == kNumberTypeInt || data_type == kNumberTypeInt32 || data_type == kNumberTypeInt64 ||
1587           data_type == kNumberTypeUInt8 || data_type == kNumberTypeBool);
1588 }
1589 
1590 bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::KernelExec &kernel) {
1591   switch (subgraph_type) {
1592     case kernel::SubGraphType::kNotSubGraph:
1593     case kernel::SubGraphType::kApuSubGraph:
1594       return false;
1595     case kernel::SubGraphType::kGpuFp16SubGraph:
1596       if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1597         return false;
1598       }
1599       return (kernel.desc().data_type != kNumberTypeFloat32);
1600     case kernel::SubGraphType::kGpuFp32SubGraph:
1601       if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
1602         return false;
1603       }
1604       return (kernel.desc().data_type != kNumberTypeFloat16);
1605     case kernel::SubGraphType::kNpuSubGraph:
1606       return kernel.desc().arch == kernel::KERNEL_ARCH::kNPU;
1607     case kernel::SubGraphType::kCpuFP16SubGraph: {
1608       auto desc = kernel.desc();
1609       if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1610         return false;
1611       }
1612 #ifdef ENABLE_FP16
1613       if (desc.data_type == kNumberTypeInt8 || desc.data_type == kNumberTypeUInt8) {
1614         return true;
1615       }
1616 #endif
1617       return (desc.data_type == kNumberTypeFloat16);
1618     }
1619     case kernel::SubGraphType::kCpuFP32SubGraph: {
1620       auto desc = kernel.desc();
1621       if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
1622         return false;
1623       }
1624       return KernelFitCurrentSubGraphCPUFp32(desc.data_type);
1625     }
1626     default:
1627       return false;
1628   }
1629 }
1630 
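// Greedily collects consecutive kernels starting at *cur_index that fit the same subgraph
// type and wraps them in one SubGraphKernel; *cur_index is adjusted so the caller's loop
// resumes right after the absorbed kernels.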
1631 kernel::KernelExec *FindAllSubGraphKernels(const std::vector<kernel::KernelExec *> &sorted_kernels,
1632                                            const InnerContext &context, size_t *cur_index, int schema_version) {
1633   std::vector<kernel::KernelExec *> sub_kernels;
1634   sub_kernels.emplace_back(sorted_kernels[*cur_index]);
1635   auto cur_sub_graph_type = GetKernelSubGraphType(sorted_kernels[*cur_index], context);
1636   for (*cur_index = *cur_index + 1; *cur_index < sorted_kernels.size(); ++(*cur_index)) {
1637     auto cur_kernel = sorted_kernels[*cur_index];
1638     MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph);
1639     // stop at a delegate kernel, an existing subgraph, or a kernel that does not fit the current subgraph
1640     if (cur_kernel->desc().arch == kernel::kDelegate) {
1641       --(*cur_index);
1642       break;
1643     }
1644     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph ||
1645         !KernelFitCurrentSubGraph(cur_sub_graph_type, *cur_kernel)) {
1646       --(*cur_index);
1647       break;
1648     }
1649     sub_kernels.emplace_back(cur_kernel);
1650   }
1651   return kernel::KernelExecUtil::CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type, context,
1652                                                       schema_version);
1653 }
1654 }  // namespace
1655 
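// Partitions the topologically sorted kernels into subgraph kernels. Delegate kernels and
// existing subgraph kernels pass through untouched; everything else is grouped by
// FindAllSubGraphKernels. Delegates and non-CPU subgraphs disable infer-along-running.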
1656 int Scheduler::ConstructNormalSubGraphs(const std::vector<kernel::KernelExec *> &src_kernel,
1657                                         std::vector<kernel::KernelExec *> *dst_kernel,
1658                                         std::map<const kernel::KernelExec *, bool> *is_kernel_finish) {
1659   if (src_kernel.empty()) {
1660     return RET_OK;
1661   }
1662 
1663   // construct subgraph
1664   for (size_t index = 0; index < src_kernel.size(); index++) {
1665     auto cur_kernel = src_kernel[index];
1666     MS_ASSERT(cur_kernel != nullptr);
1667     // APU is not supported yet
1668     MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph);
1669     if (cur_kernel->desc().arch == kernel::kDelegate) {
1670       dst_kernel->emplace_back(cur_kernel);
1671       continue;
1672     }
1673     // already a subgraph or a delegate
1674     if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
1675       dst_kernel->emplace_back(cur_kernel);
1676       continue;
1677     }
1678     auto subgraph = FindAllSubGraphKernels(src_kernel, *context_, &index, context_->get_schema_version());
1679     if (subgraph == nullptr) {
1680       MS_LOG(ERROR) << "Create SubGraphKernel failed";
1681       return RET_ERROR;
1682     }
1683     dst_kernel->emplace_back(subgraph);
1684   }
1685   for (auto *subgraph : *dst_kernel) {
1686     if (subgraph->desc().arch == kernel::kDelegate) {
1687       *infer_along_running_ = false;
1688       continue;
1689     }
1690     if (subgraph->subgraph_type() != kernel::kCpuFP32SubGraph &&
1691         subgraph->subgraph_type() != kernel::kCpuFP16SubGraph) {
1692       *infer_along_running_ = false;
1693     }
1694     auto subgraph_kernel = static_cast<kernel::SubGraphKernel *>(subgraph);
1695     if (subgraph_kernel == nullptr) {
1696       MS_LOG(ERROR) << "kernel: " << subgraph->name() << " is not a subgraph kernel.";
1697       return RET_ERROR;
1698     }
1699     // This is for the train-session cpu fp16 case and should be removed in the future.
1700     auto ret = subgraph_kernel->SetFp16Attr();
1701     if (ret != RET_OK) {
1702       MS_LOG(ERROR) << "Init SubGraph failed: " << ret;
1703       return ret;
1704     }
1705   }
1706   return RET_OK;
1707 }
1708 
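// Returns the data type of the first input tensor whose type is one of the schedulable types
// (fp32/fp16/int8/int32/bool/uint8/string); a tensor list is resolved to its element type.
// Falls back to the first input's type when nothing matches.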
1709 TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
1710   for (const auto &tensor : in_tensors) {
1711     auto dtype = tensor->data_type();
1712     if (dtype == kObjectTypeTensorType) {
1713       return TensorListDataType(tensor);
1714     }
1715     std::unordered_set<TypeId> type_set = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8,  kNumberTypeInt32,
1716                                            kNumberTypeBool,    kNumberTypeUInt8,   kObjectTypeString};
1717     if (type_set.find(dtype) != type_set.end()) {
1718       return dtype;
1719     }
1720   }
1721   if (in_tensors.empty()) {
1722     MS_LOG(ERROR) << "in tensor is empty.";
1723     return kTypeUnknown;
1724   }
1725   MS_ASSERT(!in_tensors.empty());
1726   return in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
1727 }
1728 
1729 kernel::SubGraphType Scheduler::PartialSubGraphType(const std::vector<kernel::KernelExec *> &kernels) {
1730   if (std::any_of(kernels.begin(), kernels.end(),
1731                   [](const kernel::KernelExec *item) { return item->desc().data_type == kNumberTypeFloat16; })) {
1732     return kernel::kCpuFP16SubGraph;
1733   }
1734   return kernel::kCpuFP32SubGraph;
1735 }
1736 
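// Infers shapes for the partial nodes that produce a switch node's branch inputs, visiting
// each partial node at most once (tracked in partial_cnode_inferred_).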
1737 int Scheduler::InferSwitchShape(const lite::LiteGraph::Node *switch_node) {
1738   MS_ASSERT(src_model_ != nullptr);
1739   MS_ASSERT(switch_node != nullptr);
1740   std::deque<lite::LiteGraph::Node *> partial_cnode_to_infer{};
1741   for (size_t i = 1; i < switch_node->input_indices_.size(); ++i) {
1742     auto branch_output_index = switch_node->input_indices_.at(i);
1743     for (auto &node : src_model_->graph_.all_nodes_) {
1744       if (IsContain(node->output_indices_, branch_output_index) &&
1745           IsPartialNode(node->primitive_, context_->get_schema_version()) &&
1746           partial_cnode_inferred_.find(node) == partial_cnode_inferred_.end()) {
1747         partial_cnode_inferred_.insert(node);
1748         partial_cnode_to_infer.push_back(node);
1749         break;
1750       }
1751     }
1752   }
1753 
1754   while (!partial_cnode_to_infer.empty()) {
1755     auto &node = partial_cnode_to_infer.front();
1756     partial_cnode_to_infer.pop_front();
1757     int ret = InferPartialShape(node);
1758     if (ret != RET_OK) {
1759       MS_LOG(WARNING) << "Partial shape inference failed, ret: " << ret;
1760     }
1761   }
1762   return RET_OK;
1763 }
1764 
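// Returns the switch or switch-layer node that produces this node's inputs, or nullptr if the
// producer is not a switch-type node.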
1765 LiteGraph::Node *Scheduler::NodeInputIsSwitchType(const lite::LiteGraph::Node *node) {
1766   MS_ASSERT(src_model_ != nullptr);
1767   MS_ASSERT(node != nullptr);
1768   for (auto &iter : src_model_->graph_.all_nodes_) {
1769     if (iter->output_indices_ == node->input_indices_) {
1770       if (IsSwitchNode(iter->primitive_, context_->get_schema_version()) ||
1771           IsSwitchLayerNode(iter->primitive_, context_->get_schema_version())) {
1772         return iter;
1773       } else {
1774         return nullptr;
1775       }
1776     }
1777   }
1778   return nullptr;
1779 }
1780 
1781 bool Scheduler::SubGraphHasScheduled(const int &index) {
1782   return scheduled_subgraph_index_.find(index) != scheduled_subgraph_index_.end();
1783 }
1784 
1785 void Scheduler::SubGraphMarkScheduled(const int &index) { scheduled_subgraph_index_.insert(index); }
1786 
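// For control-flow models: wraps all kernels that are not yet subgraph kernels into one
// main-graph SubGraphKernel and places it at the front of the kernel list.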
1787 int Scheduler::ConstructControlFlowMainGraph(std::vector<kernel::KernelExec *> *kernels) {
1788   auto back_kernels = *kernels;
1789   kernels->clear();
1790   std::vector<kernel::KernelExec *> main_graph_kernels{};
1791   for (auto &kernel : back_kernels) {
1792     if (kernel->subgraph_type() != kernel::kNotSubGraph) {
1793       kernels->push_back(kernel);
1794     } else {
1795       main_graph_kernels.push_back(kernel);
1796     }
1797   }
1798   auto cur_subgraph_type = PartialSubGraphType(main_graph_kernels);
1799   auto subgraph_kernel = kernel::KernelExecUtil::CreateSubGraphKernel(
1800     main_graph_kernels, nullptr, nullptr, cur_subgraph_type, *context_, context_->get_schema_version());
1801   if (subgraph_kernel == nullptr) {
1802     MS_LOG(ERROR) << "create main graph for control flow model failed.";
1803     return RET_ERROR;
1804   }
1805   kernels->insert(kernels->begin(), subgraph_kernel);
1806   return RET_OK;
1807 }
1808 
1809 std::vector<kernel::KernelExec *> Scheduler::NonTailCallNodes() {
1810   std::vector<kernel::KernelExec *> ret{};
1811   if (*is_control_flow_) {
1812     ret = control_flow_scheduler_->GetNonTailCalls();
1813   }
1814   return ret;
1815 }
1816 }  // namespace mindspore::lite
1817