1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/scheduler.h"
18 #include <map>
19 #include <set>
20 #include <queue>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #ifndef CONTROLFLOW_TENSORLIST_CLIP
25 #include "src/tensorlist.h"
26 #endif
27 #include "include/errorcode.h"
28 #include "src/common/graph_util.h"
29 #include "src/common/utils.h"
30 #include "src/kernel_registry.h"
31 #include "include/registry/register_kernel.h"
32 #include "src/lite_kernel_util.h"
33 #include "src/sub_graph_kernel.h"
34 #include "src/ops/populate/populate_register.h"
35 #include "src/common/version_manager.h"
36 #include "src/common/prim_util.h"
37 #include "src/lite_model.h"
38 #include "src/common/tensor_util.h"
39 #include "src/common/context_util.h"
40 #include "src/runtime/infer_manager.h"
41 #ifndef RUNTIME_PASS_CLIP
42 #include "src/runtime/runtime_pass.h"
43 #endif
44 #ifndef AUTO_PARALLEL_CLIP
45 #include "src/sub_graph_split.h"
46 #endif
47 #ifndef WEIGHT_DECODE_CLIP
48 #include "src/weight_decoder.h"
49 #endif
50 #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
51 #include "nnacl/nnacl_common.h"
52 #if GPU_OPENCL
53 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
54 #include "src/runtime/gpu/opencl/opencl_runtime.h"
55 #endif
56 #include "include/registry/register_kernel_interface.h"
57 #ifndef CONTROLFLOW_TENSORLIST_CLIP
58 #include "src/runtime/kernel/arm/base/partial_fusion.h"
59 #endif
60 #ifdef SUPPORT_NNRT
61 #include "src/delegate/nnrt/nnrt_delegate.h"
62 #endif
63 
64 namespace mindspore::lite {
65 namespace {
66 constexpr int kMainSubGraphIndex = 0;
67 kernel::SubGraphKernel *CreateCustomSubGraph(std::vector<kernel::LiteKernel *> &&input_kernels,
68                                              std::vector<kernel::LiteKernel *> &&output_kernels,
69                                              const std::vector<kernel::LiteKernel *> &kernels, kernel::Kernel *kernel) {
70   auto sub_kernel = new (std::nothrow) kernel::CustomSubGraph(input_kernels, output_kernels, kernels, kernel);
71   if (sub_kernel == nullptr) {
72     MS_LOG(ERROR) << "create custom subgraph failed!";
73     delete kernel;
74     return nullptr;
75   }
76   return sub_kernel;
77 }
78 }  // namespace
79 
80 namespace {
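// CastConstTensorData converts a const fp32/fp16 tensor to dst_data_type in place: it allocates a
// new buffer, converts the payload, and frees the original data if the tensor owned it.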
81 // support_fp16: current device and package support float16
82 int CastConstTensorData(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
83   MS_ASSERT(tensor != nullptr);
84   MS_ASSERT(tensor->IsConst());
85   MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
86   MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
87   if (tensor->data_type() == dst_data_type) {
88     return RET_OK;
89   }
90   auto origin_own_data = tensor->own_data();
91   auto origin_dt = tensor->data_type();
92   auto origin_data = tensor->data();
93   MS_ASSERT(origin_data != nullptr);
94   tensor->set_data(nullptr);
95   tensor->set_data_type(dst_data_type);
96   auto ret = tensor->MallocData();
97   if (RET_OK != ret) {
98     MS_LOG(ERROR) << "malloc data failed";
99     // reset tensor
100     tensor->set_data(origin_data);
101     tensor->set_data_type(origin_dt);
102     tensor->set_own_data(origin_own_data);
103     return ret;
104   }
105   auto new_tensor_data = tensor->data();
106   MS_ASSERT(new_tensor_data != nullptr);
107   if (dst_data_type == kNumberTypeFloat32) {
108     Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
109   } else {  // dst_data_type == kNumberTypeFloat16
110     Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
111   }
112   if (origin_data != nullptr && origin_own_data) {
113     if (tensor->allocator() == nullptr) {
114       free(origin_data);
115     } else {
116       tensor->allocator()->Free(origin_data);
117     }
118   }
119   return RET_OK;
120 }
121 
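// CastKernelWeight casts the const float inputs of a node so they match the data type of the
// CPU subgraph (fp32 or fp16) that the node was assigned to.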
122 // support_fp16: current device and package support float16
123 int CastKernelWeight(kernel::SubGraphType belong_subgraph_type, kernel::LiteKernel *kernel, bool support_fp16) {
124   MS_ASSERT(kernel != nullptr);
125   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
126   if (belong_subgraph_type != kernel::kCpuFP32SubGraph && belong_subgraph_type != kernel::kCpuFP16SubGraph) {
127     return RET_OK;
128   }
129   for (auto *tensor : kernel->in_tensors()) {
130     MS_ASSERT(tensor != nullptr);
131     // only cast const tensor
132     // tensorlist not support fp16 now
133     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
134       continue;
135     }
136     // only support fp32->fp16 or fp16->fp32
137     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
138       continue;
139     }
140     if (tensor->data_type() == kNumberTypeFloat32 && belong_subgraph_type == kernel::kCpuFP16SubGraph) {
141       auto ret = CastConstTensorData(tensor, kNumberTypeFloat16, support_fp16);
142       if (ret != RET_OK) {
143         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
144         return ret;
145       }
146     } else if (tensor->data_type() == kNumberTypeFloat16 && belong_subgraph_type == kernel::kCpuFP32SubGraph) {
147       auto ret = CastConstTensorData(tensor, kNumberTypeFloat32, support_fp16);
148       if (ret != RET_OK) {
149         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
150         return ret;
151       }
152     } else {
153       MS_LOG(DEBUG) << "No need to cast";
154     }
155   }
156   return RET_OK;
157 }
158 
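// CopyConstTensorData makes non-owning const tensors own their data by copying it; tensorlists are
// left untouched since their payload lives in the element tensors.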
159 int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
160   // Packed kernels such as conv don't need a copy: their weights are repacked inside the kernel.
161   if (IsPackedOp(op_type)) {
162     return RET_OK;
163   }
164   for (auto *tensor : tensors) {
165     // only copy non-copied const tensor
166     if (!tensor->IsConst() || tensor->own_data()) {
167       continue;
168     }
169     if (tensor->data_type() == kObjectTypeTensorType) {
170       // tensorlist's data is nullptr since ConvertTensors
171       // we never set or malloc data of tensorlist but malloc tensors in tensorlist
172       MS_ASSERT(tensor->data() == nullptr);
173     } else {
174       auto copy_tensor = Tensor::CopyTensor(*tensor, true);
175       if (copy_tensor == nullptr) {
176         MS_LOG(ERROR) << "Copy tensor failed";
177         return RET_ERROR;
178       }
179       tensor->FreeData();
180       tensor->set_data(copy_tensor->data());
181       tensor->set_own_data(true);
182       copy_tensor->set_data(nullptr);
183       delete (copy_tensor);
184     }
185   }
186   return RET_OK;
187 }
188 }  // namespace
189 
190 // support_fp16: current device and package support float16
191 int Scheduler::HandleBuildinCpuKernelWeight(kernel::SubGraphType belong_subgraph_type, kernel::LiteKernel *kernel) {
192   MS_ASSERT(kernel != nullptr);
193   MS_ASSERT(kernel->subgraph_type() == kernel::kNotSubGraph);
194   if (is_train_session_ || kernel->type() == schema::PrimitiveType_Custom ||
195       kernel->desc().provider != kernel::kBuiltin) {
196     return RET_OK;
197   }
198   auto ret = CastKernelWeight(belong_subgraph_type, kernel, context_->device_and_pkg_support_fp16());
199   if (ret != RET_OK) {
200     MS_LOG(DEBUG) << "CastKernelWeight failed: " << ret;
201     return RET_NOT_SUPPORT;
202   }
203   if (!(reinterpret_cast<LiteModel *>(src_model_)->keep_model_buf())) {
204     // we don't need to restore tensor for copy data
205     ret = CopyConstTensorData(kernel->in_tensors(), kernel->op_parameter()->type_);
206     if (ret != RET_OK) {
207       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
208       return RET_NOT_SUPPORT;
209     }
210   }
211   return RET_OK;
212 }
213 
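// InitKernels walks every scheduled subgraph (delegate kernels excluded), prepares the weights of
// built-in CPU nodes via HandleBuildinCpuKernelWeight, and then calls Init on each node.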
214 int Scheduler::InitKernels(std::vector<kernel::LiteKernel *> dst_kernels) {
215   if (is_train_session_) {
216     return RET_OK;
217   }
218   for (auto kernel : dst_kernels) {
219 #ifndef DELEGATE_CLIP
220     // delegate graph kernel
221     if (kernel->desc().arch == kernel::kDelegate) {
222       continue;
223     }
224 #endif
225     auto subgraph_type = kernel->subgraph_type();
226     if (subgraph_type == kernel::kNotSubGraph) {
227       MS_LOG(ERROR) << "Construct subgraph failed: got a non-subgraph kernel.";
228       return RET_ERROR;
229     }
230     auto subgraph_nodes = reinterpret_cast<kernel::SubGraphKernel *>(kernel)->nodes();
231     for (auto node : subgraph_nodes) {
232       auto ret = HandleBuildinCpuKernelWeight(subgraph_type, node);
233       if (ret != RET_OK) {
234         return ret;
235       }
236       ret = node->Init();
237       if (ret != RET_OK) {
238         MS_LOG(ERROR) << "Kernel " << node->name() << " Init failed.";
239         return ret;
240       }
241     }
242   }
243   return RET_OK;
244 }
245 
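// SchedulePreProcess records the graph output nodes, runs infer-shape on the main subgraph, and,
// when parallel execution is enabled and shapes are known, searches for a subgraph split.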
246 int Scheduler::SchedulePreProcess() {
247   this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
248 
249   int infershape_ret = InferSubGraphShape(kMainSubGraphIndex);
250   if (infershape_ret != RET_OK && infershape_ret != RET_INFER_INVALID) {
251     MS_LOG(ERROR) << "op infer shape failed.";
252     return infershape_ret;
253   }
254 
255   if (context_->enable_parallel_ && infershape_ret != RET_INFER_INVALID) {
256 #ifndef AUTO_PARALLEL_CLIP
257     auto search_sub_graph =
258       SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
259     search_sub_graph.SubGraphSplit();
260 #else
261     MS_LOG(ERROR) << unsupport_auto_parallel_log;
262     return RET_NOT_SUPPORT;
263 #endif
264   }
265   return RET_OK;
266 }
267 
268 int Scheduler::CheckCpuValid(std::vector<kernel::LiteKernel *> *dst_kernels) {
269   if (context_->IsCpuEnabled() == true) {
270     return RET_OK;
271   }
272   for (auto kernel : *dst_kernels) {
273     if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
274       MS_LOG(ERROR) << "kernel: " << kernel->name() << " is only supported on CPU, but CPU is not enabled.";
275       return RET_ERROR;
276     }
277   }
278   return RET_OK;
279 }
280 
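// Schedule is the main entry point: it validates the inputs, infers shapes, lowers the graph to
// kernels, hands eligible kernels to delegates, groups the rest into subgraphs (with a dedicated
// path for control-flow models), applies runtime passes, and finally initializes every kernel.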
281 int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
282   int check_input_ret = CheckInputParam(dst_kernels);
283   if (check_input_ret != RET_OK) {
284     MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
285     return check_input_ret;
286   }
287 
288   schema_version_ = reinterpret_cast<LiteModel *>(src_model_)->GetSchemaVersion();
289 
290   int ret = SchedulePreProcess();
291   if (ret != RET_OK) {
292     return ret;
293   }
294 
295   ret = ScheduleGraphToKernels(dst_kernels);
296   FreeOpParameters();
297   op_parameters_.clear();
298   if (ret != RET_OK) {
299     MS_LOG(ERROR) << "Schedule graph to kernels failed.";
300     return ret;
301   }
302 
303 #ifndef CONTROLFLOW_TENSORLIST_CLIP
304   SetSubgraphForPartialNode();
305 #endif
306 
307 #ifndef DELEGATE_CLIP
308   ret = InitDelegateKernels(dst_kernels);
309   if (ret != RET_OK) {
310     MS_LOG(ERROR) << "Replace delegate kernels failed.";
311     return ret;
312   }
313   context_->thread_pool()->SetSpinCountMinValue();
314 #endif
315 
316   ret = CheckCpuValid(dst_kernels);
317   if (ret != RET_OK) {
318     MS_LOG(ERROR) << "Some kernels are invalid for the configured devices.";
319     return ret;
320   }
321 
322   kernel::LiteKernelUtil::FindAllInoutKernels(*dst_kernels);
323 
324 #ifndef CONTROLFLOW_TENSORLIST_CLIP
325   if (IsControlFlowParttern(*dst_kernels)) {
326     ret = ConstructControlFlowMainGraph(dst_kernels);
327     if (ret != RET_OK) {
328       MS_LOG(ERROR) << "ConstructControlFlowMainGraph failed.";
329       return ret;
330     }
331   } else {
332 #endif
333     auto src_kernel = *dst_kernels;
334     dst_kernels->clear();
335     std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
336     ret = ConstructSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
337     if (ret != RET_OK) {
338       MS_LOG(ERROR) << "ConstructSubGraphs failed.";
339       return ret;
340     }
341 #ifndef CONTROLFLOW_TENSORLIST_CLIP
342   }
343 #endif
344 
345 #ifndef RUNTIME_PASS_CLIP
346   RuntimePass(dst_kernels, src_tensors_);
347 #endif
348 
349   ret = InitKernels(*dst_kernels);
350   if (ret != RET_OK) {
351     MS_LOG(ERROR) << "InitKernels failed.";
352     return ret;
353   }
354 
355   MS_LOG(DEBUG) << "schedule kernels success.";
356   for (auto subgraph : *dst_kernels) {
357     MS_LOG(DEBUG) << "[subgraph] : " << subgraph->name() << ",  type:" << subgraph->subgraph_type();
358     if (subgraph->desc().arch == kernel::KERNEL_ARCH::kDelegate) {
359       continue;
360     }
361     std::vector<kernel ::LiteKernel *> kernel_list = reinterpret_cast<kernel::SubGraphKernel *>(subgraph)->nodes();
362     for (auto kernel : kernel_list) {
363       MS_LOG(DEBUG) << "kernel: [" << kernel->name() << "] get TypeId(" << kernel->desc().data_type
364                     << ") op success. op_type: " << PrimitiveCurVersionTypeName(kernel->desc().type)
365                     << ", arch: " << kernel->desc().arch;
366     }
367   }
368   return RET_OK;
369 }
370 
371 int Scheduler::CheckInputParam(std::vector<kernel::LiteKernel *> *dst_kernels) {
372   if (dst_kernels == nullptr) {
373     return RET_ERROR;
374   }
375   if (src_model_ == nullptr) {
376     MS_LOG(ERROR) << "Input model is nullptr";
377     return RET_PARAM_INVALID;
378   }
379   if (src_model_->graph_.sub_graphs_.empty()) {
380     MS_LOG(ERROR) << "Model should have a subgraph at least";
381     return RET_PARAM_INVALID;
382   }
383   return RET_OK;
384 }
385 
386 #ifndef DELEGATE_CLIP
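// ReplaceDelegateKernels builds a DelegateModel from the current kernels, lets the delegate claim
// the nodes it supports, and wraps each delegate subgraph in a LiteKernel; unclaimed kernels keep
// their original backend, claimed ones are released.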
387 int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels) {
388   std::vector<kernel::Kernel *> kernels;
389   for (size_t i = 0; i < dst_kernels->size(); i++) {
390     kernels.push_back((*dst_kernels)[i]->kernel());
391   }
392 
393   ms_inputs_ = LiteTensorsToMSTensors(inputs_);
394   ms_outputs_ = LiteTensorsToMSTensors(outputs_);
395   auto schema_version = static_cast<SchemaVersion>(schema_version_);
396   DelegateModel<schema::Primitive> *model =
397     new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
398   if (model == nullptr) {
399     MS_LOG(ERROR) << "New delegate model failed.";
400     return RET_NULL_PTR;
401   }
402 #ifdef SUPPORT_NNRT
403   if (context_->IsNNRtEnabled()) {
404     auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
405     delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
406   }
407 #endif
408   auto ret = delegate_->Build(model);
409   if (ret != mindspore::kSuccess) {
410     delete model;
411     MS_LOG(ERROR) << "Delegate prepare kernels failed.";
412     return RET_ERROR;
413   }
414   MS_LOG(INFO) << "Delegate build end.";
415   auto src_kernels = *dst_kernels;
416   dst_kernels->clear();
417   std::map<const kernel::LiteKernel *, bool> delegate_support;
418   for (auto kernel : src_kernels) {
419     delegate_support[kernel] = true;
420   }
421   for (auto kernel : kernels) {
422     size_t index = 0;
423     for (; index < src_kernels.size(); index++) {
424       if (kernel == src_kernels[index]->kernel()) {
425         // Kernels that the delegate does not support keep the original backend
426         dst_kernels->push_back(src_kernels[index]);
427         delegate_support[src_kernels[index]] = false;
428         break;
429       }
430     }
431     if (index == src_kernels.size()) {
432       // New liteKernel to save delegate subgraph
433       std::shared_ptr<kernel::Kernel> shared_kernel(kernel);
434       auto lite_kernel = new (std::nothrow) kernel::LiteKernel(shared_kernel);
435       if (lite_kernel == nullptr) {
436         delete model;
437         MS_LOG(ERROR) << "New LiteKernel for delegate subgraph failed.";
438         return RET_NULL_PTR;
439       }
440       auto delegate_type = kNumberTypeFloat32;
441       for (auto &input : kernel->inputs()) {
442         if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
443           delegate_type = kNumberTypeFloat16;
444           break;
445         }
446       }
447       kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, schema::PrimitiveType_NONE, "", ""};
448       lite_kernel->set_desc(delegate_desc);
449       dst_kernels->push_back(lite_kernel);
450     }
451   }
452   // Release the CPU kernels that have been replaced by the delegate subgraph
453   for (auto kernel : src_kernels) {
454     if (delegate_support[kernel] == true) {
455       delete kernel;
456     }
457   }
458   delete model;
459   return RET_OK;
460 }
461 
462 int Scheduler::InitDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels) {
463   /* no delegate valid */
464   if (delegate_ == nullptr) {
465     return RET_OK;
466   }
467   /* external delegate */
468   if (delegate_device_type_ == -1) {
469     auto ret = ReplaceDelegateKernels(dst_kernels);
470     if (ret != RET_OK) {
471       MS_LOG(ERROR) << "external delegate init failed.";
472       return ret;
473     }
474   }
475 
476   /* Inner delegate: check priority */
477   std::vector<kernel::LiteKernel *> src_kernels = *dst_kernels;
478   dst_kernels->clear();
479 
480   while (!src_kernels.empty()) {
481     std::vector<kernel::LiteKernel *> tmp_kernels;
482     kernel::LiteKernel *remain_kernel = nullptr;
483 
484     /* Loop for inner delegate npu and TensorRT subgraph */
485     while (!src_kernels.empty()) {
486       auto kernel = src_kernels.front();
487       VectorErase(&src_kernels, kernel);
488       bool priority_ret =
489         DeviceTypePriority(context_, delegate_device_type_, KernelArchToDeviceType(kernel->desc().arch));
490       if (priority_ret == true) {
491         tmp_kernels.push_back(kernel);
492       } else {
493         remain_kernel = kernel;
494         break;
495       }
496     }
497 
498     /* replace the collected kernels for which the delegate device has priority */
499     if (tmp_kernels.empty()) {
500       if (remain_kernel != nullptr) {
501         dst_kernels->push_back(remain_kernel);
502         remain_kernel = nullptr;
503       }
504       continue;
505     }
506     auto ret = ReplaceDelegateKernels(&tmp_kernels);
507     if (ret != RET_OK) {
508       MS_LOG(ERROR) << "NPU delegate replace kernels failed.";
509       return ret;
510     }
511 
512     dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
513     tmp_kernels.clear();
514     if (remain_kernel != nullptr) {
515       dst_kernels->push_back(remain_kernel);
516       remain_kernel = nullptr;
517     }
518   }
519 
520   return RET_OK;
521 }
522 #endif
523 
524 void Scheduler::FindNodeInoutTensors(const lite::LiteGraph::Node &node, std::vector<Tensor *> *inputs,
525                                      std::vector<Tensor *> *outputs) {
526   MS_ASSERT(inputs != nullptr);
527   MS_ASSERT(outputs != nullptr);
528   auto in_size = node.input_indices_.size();
529   inputs->reserve(in_size);
530   for (size_t j = 0; j < in_size; ++j) {
531     inputs->emplace_back(src_tensors_->at(node.input_indices_[j]));
532   }
533   auto out_size = node.output_indices_.size();
534   outputs->reserve(out_size);
535   for (size_t j = 0; j < out_size; ++j) {
536     outputs->emplace_back(src_tensors_->at(node.output_indices_[j]));
537   }
538 }
539 
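// InferNodeShape populates (and caches, keyed by the node's first output tensor index) the
// OpParameter for the node and runs infer-shape; tensorlist inputs make the result RET_INFER_INVALID.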
540 int Scheduler::InferNodeShape(const lite::LiteGraph::Node *node) {
541   MS_ASSERT(node != nullptr);
542   auto primitive = node->primitive_;
543   MS_ASSERT(primitive != nullptr);
544   std::vector<Tensor *> inputs;
545   std::vector<Tensor *> outputs;
546   FindNodeInoutTensors(*node, &inputs, &outputs);
547   int ret;
548 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
549   ret = KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders(), schema_version_);
550   if (ret != RET_NOT_SUPPORT) {
551     return ret;
552   }
553 #endif
554 
555   auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(
556     GetPrimitiveType(node->primitive_, schema_version_), schema_version_);
557   if (parame_gen == nullptr) {
558     MS_LOG(ERROR) << "parameter generator is nullptr.";
559     FreeOpParameters();
560     return RET_NULL_PTR;
561   }
562   auto parameter = parame_gen(primitive);
563   if (parameter == nullptr) {
564     MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << GetPrimitiveTypeName(primitive, schema_version_);
565     FreeOpParameters();
566     return RET_ERROR;
567   }
568 
569   parameter->quant_type_ = node->quant_type_;
570   parameter->thread_num_ = context_->thread_num_;
571 
572   if (op_parameters_.find(node->output_indices_.at(0)) != op_parameters_.end()) {
573     free(parameter);
574     parameter = op_parameters_[node->output_indices_.at(0)];
575   } else {
576     op_parameters_[node->output_indices_.at(0)] = parameter;
577   }
578 
579   if (IsCallNode(primitive, schema_version_)) {
580     return InferCallShape(node);
581   }
582   ret = KernelInferShape(inputs, outputs, parameter);
583 
584 #ifndef CONTROLFLOW_TENSORLIST_CLIP
585   bool not_able_to_infer = false;
586   for (auto &input : inputs) {
587     if (input->data_type() == kObjectTypeTensorType) {
588       not_able_to_infer = true;
589       break;
590     }
591   }
592 
593   if (not_able_to_infer) {
594     for (auto &output : outputs) {
595       output->set_shape({-1});
596     }
597     MS_LOG(ERROR) << "RET_INFER_INVALID";
598     return RET_INFER_INVALID;
599   }
600 #endif
601 
602   if (ret == RET_OK) {
603     for (auto &output : outputs) {
604       if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
605         MS_LOG(ERROR) << "The size of output tensor is too big";
606         FreeOpParameters();
607         return RET_ERROR;
608       }
609     }
610   } else if (ret != RET_INFER_INVALID) {
611     FreeOpParameters();
612     MS_LOG(ERROR) << "RET_INFER_INVALID";
613     return RET_ERROR;
614   }
615   return ret;
616 }
617 
618 void Scheduler::FreeOpParameters() {
619   for (auto &param : op_parameters_) {
620     if (param.second != nullptr) {
621       free(param.second);
622       param.second = nullptr;
623     }
624   }
625 }
626 
627 int Scheduler::RestoreSubGraphInput(const lite::LiteGraph::Node *partial_node) {
628   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, schema_version_);
629   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
630   for (size_t i = 0; i < subgraph->input_indices_.size(); ++i) {
631     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
632     subgraph_input->set_data(nullptr);
633   }
634   return RET_OK;
635 }
636 
637 void CopyCommonTensor(Tensor *dst_tensor, Tensor *src_tensor) {
638   dst_tensor->set_data_type(src_tensor->data_type());
639   dst_tensor->set_shape(src_tensor->shape());
640   dst_tensor->set_format(src_tensor->format());
641   dst_tensor->set_data(src_tensor->data());
642 }
643 
644 int Scheduler::CopyPartialShapeToSubGraph(const lite::LiteGraph::Node *partial_node) {
645   auto subgraph_index = GetPartialGraphIndex(partial_node->primitive_, schema_version_);
646   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
647   if (subgraph->input_indices_.size() != partial_node->input_indices_.size()) {
648     MS_LOG(ERROR) << "partial node " << partial_node->name_ << " inputs size: " << partial_node->input_indices_.size()
649                   << " vs "
650                   << " subgraph input size: " << subgraph->input_indices_.size();
651     return RET_PARAM_INVALID;
652   }
653 
654   for (size_t i = 0; i < partial_node->input_indices_.size(); ++i) {
655     auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
656     auto &partial_input = src_tensors_->at(partial_node->input_indices_[i]);
657     switch (partial_input->data_type()) {
658       case kObjectTypeTensorType: {
659         return RET_INFER_INVALID;
660       }
661       default: {
662         CopyCommonTensor(subgraph_input, partial_input);
663         break;
664       }
665     }
666   }
667 
668   return RET_OK;
669 }
670 
671 int Scheduler::InferPartialShape(const lite::LiteGraph::Node *node) {
672   MS_ASSERT(src_model_ != nullptr);
673   MS_ASSERT(node != nullptr);
674   if (!IsPartialNode(node->primitive_, schema_version_)) {
675     MS_LOG(ERROR) << "Node is not a partial";
676     return RET_PARAM_INVALID;
677   }
678   CopyPartialShapeToSubGraph(node);
679   int subgraph_index = GetPartialGraphIndex(node->primitive_, schema_version_);
680   auto ret = InferSubGraphShape(subgraph_index);
681   if (ret != RET_OK) {
682     MS_LOG(WARNING) << "infer subgraph: " << subgraph_index << " failed, ret:" << ret;
683   }
684   RestoreSubGraphInput(node);
685   return ret;
686 }
687 
688 LiteGraph::Node *Scheduler::NodeInputIsPartial(const lite::LiteGraph::Node *node) {
689   MS_ASSERT(src_model_ != nullptr);
690   MS_ASSERT(node != nullptr);
691   for (auto &iter : src_model_->graph_.all_nodes_) {
692     if (iter->output_indices_ == node->input_indices_) {
693       if (IsPartialNode(iter->primitive_, schema_version_)) {
694         return iter;
695       } else {
696         return nullptr;
697       }
698     }
699   }
700   return nullptr;
701 }
702 
703 int Scheduler::InferCallShape(const lite::LiteGraph::Node *node) {
704   MS_ASSERT(src_model_ != nullptr);
705   MS_ASSERT(node != nullptr);
706   if (!IsCallNode(node->primitive_, schema_version_)) {
707     MS_LOG(ERROR) << "Node is not a call cnode";
708     return RET_PARAM_INVALID;
709   }
710 
711   auto partial_input = NodeInputIsPartial(node);
712   if (partial_input) {
713     return InferPartialShape(partial_input);
714   }
715 #ifndef CONTROLFLOW_TENSORLIST_CLIP
716   auto switch_input = NodeInputIsSwitch(node);
717   if (switch_input) {
718     return InferSwitchShape(switch_input);
719   }
720 #endif
721 
722   MS_LOG(ERROR) << "call input is not partial and also not switch.";
723   return RET_ERROR;
724 }
725 
726 int Scheduler::InferSubGraphShape(size_t subgraph_index) {
727   MS_ASSERT(src_model_ != nullptr);
728   MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
729   MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
730   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
731   int subgraph_infershape_ret = RET_OK;
732   for (auto node_index : subgraph->node_indices_) {
733     auto node = src_model_->graph_.all_nodes_[node_index];
734     MS_ASSERT(node != nullptr);
735     auto *primitive = node->primitive_;
736     if (primitive == nullptr) {
737       MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
738       return RET_ERROR;
739     }
740     auto ret = InferNodeShape(node);
741     if (ret == RET_INFER_INVALID) {
742       MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_
743                    << ", type: " << GetPrimitiveTypeName(primitive, schema_version_) << ", set infer flag to false.";
744       subgraph_infershape_ret = RET_INFER_INVALID;
745     } else if (ret != RET_OK) {
746       MS_LOG(ERROR) << "InferShape failed, name: " << node->name_
747                     << ", type: " << GetPrimitiveTypeName(primitive, schema_version_);
748       return RET_INFER_ERR;
749     }
750   }
751   return subgraph_infershape_ret;
752 }
753 
754 namespace {
755 // support_fp16: current device and package support float16
756 int CastAndRestoreConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors,
757                                   TypeId dst_data_type, bool support_fp16) {
758   MS_ASSERT(tensor != nullptr);
759   MS_ASSERT(tensor->IsConst());
760   MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
761   MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
762   if (tensor->data_type() == dst_data_type) {
763     return RET_OK;
764   }
765   auto origin_data = tensor->data();
766   MS_ASSERT(origin_data != nullptr);
767   auto restore_tensor = Tensor::CopyTensor(*tensor, false);
768   restore_tensor->set_data(origin_data);
769   restore_tensor->set_own_data(tensor->own_data());
770   tensor->set_data(nullptr);
771   tensor->set_data_type(dst_data_type);
772   auto ret = tensor->MallocData();
773   if (RET_OK != ret) {
774     MS_LOG(ERROR) << "malloc data failed";
775     return ret;
776   }
777   auto new_tensor_data = tensor->data();
778   MS_ASSERT(new_tensor_data != nullptr);
779   if (dst_data_type == kNumberTypeFloat32) {
780     Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
781   } else {  // dst_data_type == kNumberTypeFloat16
782     Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
783   }
784   if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
785     MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
786     return RET_ERROR;
787   }
788   (*restored_origin_tensors)[tensor] = restore_tensor;
789   return RET_OK;
790 }
791 
792 // support_fp16: current device and package support float16
793 int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
794                          TypeId dst_data_type, bool support_fp16) {
795   MS_ASSERT(restored_origin_tensors != nullptr);
796   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
797     MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";
798     return RET_PARAM_INVALID;
799   }
800   for (auto *tensor : tensors) {
801     MS_ASSERT(tensor != nullptr);
802     // only cast const tensor
803     // tensorlist not support fp16 now
804     if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
805       continue;
806     }
807     // only support fp32->fp16 or fp16->fp32
808     if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
809       continue;
810     }
811     if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
812       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16, support_fp16);
813       if (ret != RET_OK) {
814         MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
815         return ret;
816       }
817     } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
818       auto ret = CastAndRestoreConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32, support_fp16);
819       if (ret != RET_OK) {
820         MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
821         return ret;
822       }
823     } else {
824       MS_LOG(DEBUG) << "No need to cast from " << tensor->data_type() << " to " << dst_data_type;
825     }
826   }
827   return RET_OK;
828 }
829 
830 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
831   MS_ASSERT(restored_origin_tensors != nullptr);
832   for (auto &restored_origin_tensor : *restored_origin_tensors) {
833     restored_origin_tensor.second->set_data(nullptr);
834     delete (restored_origin_tensor.second);
835   }
836   restored_origin_tensors->clear();
837 }
838 
839 inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
840   MS_ASSERT(restored_origin_tensors != nullptr);
841   for (auto &restored_origin_tensor : *restored_origin_tensors) {
842     auto *origin_tensor = restored_origin_tensor.first;
843     auto *restored_tensor = restored_origin_tensor.second;
844     MS_ASSERT(origin_tensor != nullptr);
845     MS_ASSERT(restored_tensor != nullptr);
846     origin_tensor->FreeData();
847     origin_tensor->set_data_type(restored_tensor->data_type());
848     origin_tensor->set_data(restored_tensor->data());
849     origin_tensor->set_own_data(restored_tensor->own_data());
850   }
851   FreeRestoreTensors(restored_origin_tensors);
852 }
853 }  // namespace
854 
855 void Scheduler::ResetByExecutionPlan(std::string node_name, TypeId *data_type) {
856   if (execution_plan_ == nullptr) {
857     return;
858   }
859   auto iter = execution_plan_->find(node_name);
860   if (iter != execution_plan_->end()) {
861     *data_type = iter->second;
862   }
863   return;
864 }
865 
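// FindCpuKernel dequantizes weights if needed, optionally casts const inputs for train sessions,
// and asks the kernel registry for a CPU kernel matching the requested fp32/fp16 data type.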
866 int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
867                              OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
868                              kernel::LiteKernel **kernel) {
869   MS_ASSERT(op_parameter != nullptr);
870   auto op_type = op_parameter->type_;
871   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
872     return RET_NOT_SUPPORT;
873   }
874   kernel::KernelKey cpu_desc = desc;
875   if (kernel_data_type == kNumberTypeFloat16) {
876     if (!context_->IsCpuFloat16Enabled() ||
877         (cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
878       return RET_NOT_SUPPORT;
879     }
880     cpu_desc.data_type = kNumberTypeFloat16;
881   }
882   int ret;
883 #ifndef WEIGHT_DECODE_CLIP
884   ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
885   if (ret != RET_OK) {
886     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
887     return RET_NOT_SUPPORT;
888   }
889 #endif
890   std::map<Tensor *, Tensor *> restored_origin_tensors;
891 
892   if (is_train_session_) {
893     ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type,
894                                context_->device_and_pkg_support_fp16());
895     if (ret != RET_OK) {
896       MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
897       return RET_NOT_SUPPORT;
898     }
899   }
900   ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, ms_context_, cpu_desc, op_parameter,
901                                                  kernel);
902   if (ret == RET_OK) {
903     MS_LOG(DEBUG) << "Get TypeId(expect = " << kernel_data_type << ", real = " << cpu_desc.data_type
904                   << ") op success: " << PrimitiveCurVersionTypeName(op_type);
905     if (is_train_session_) {
906       (*kernel)->Init();
907       RestoreTensorData(&restored_origin_tensors);
908     }
909   }
910   return ret;
911 }
912 
913 #ifdef GPU_OPENCL
914 int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
915                              OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel,
916                              TypeId prefer_data_type) {
917   MS_ASSERT(op_parameter != nullptr);
918   MS_ASSERT(kernel != nullptr);
919   if (!context_->IsGpuEnabled()) {
920     return RET_NOT_SUPPORT;
921   }
922 
923   // support more data type like int32
924   kernel::KernelKey gpu_desc{kernel::KERNEL_ARCH::kGPU, desc.data_type, desc.type};
925   if (desc.data_type == kNumberTypeFloat32 && context_->IsGpuFloat16Enabled()) {
926     gpu_desc.data_type = kNumberTypeFloat16;
927   }
928   if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kNumberTypeFloat32) {
929     gpu_desc.data_type = prefer_data_type;
930   }
931   int ret;
932 #ifndef WEIGHT_DECODE_CLIP
933   // weight dequant
934   ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
935   if (ret != RET_OK) {
936     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
937     return RET_NOT_SUPPORT;
938   }
939 #endif
940   // we don't need to restore tensor for copy data
941   ret = CopyConstTensorData(in_tensors, op_parameter->type_);
942   if (ret != RET_OK) {
943     MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
944     return RET_NOT_SUPPORT;
945   }
946   ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, ms_context_, gpu_desc, op_parameter,
947                                                  kernel);
948   if (ret == RET_OK) {
949     MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
950   } else {
951     MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
952   }
953   return ret;
954 }
955 #endif
956 
957 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
958 int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
959                                   const LiteGraph::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
960   MS_ASSERT(kernel != nullptr);
961   int ret = RET_NOT_SUPPORT;
962   auto prim_type = GetPrimitiveType(node->primitive_, schema_version_);
963   if (prim_type == schema::PrimitiveType_Custom) {
964     for (auto &&device : context_->device_list_) {
965       if (!device.provider_.empty() && !device.provider_device_.empty()) {
966         kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, prim_type, device.provider_device_,
967                                device.provider_};
968         ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
969                                                        kernel, node->primitive_);
970         if (ret == RET_OK && *kernel != nullptr) {
971           return ret;
972         }
973       }
974     }
975 
976     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, prim_type, "", ""};
977     ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
978                                                    kernel, node->primitive_);
979     if (ret == RET_OK && *kernel != nullptr) {
980       return ret;
981     }
982     return RET_NOT_SUPPORT;
983   }
984   if (!context_->IsProviderEnabled()) {
985     return ret;
986   }
987   if (schema_version_ == SCHEMA_V0) {
988     return ret;
989   }
990   for (auto &&device : context_->device_list_) {
991     if (!device.provider_.empty()) {
992       kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, prim_type, device.provider_device_,
993                              device.provider_};
994       ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, ms_context_, desc, nullptr,
995                                                      kernel, node->primitive_);
996       if (ret == RET_OK && *kernel != nullptr) {
997         return ret;
998       }
999     }
1000   }
1001 
1002   return RET_NOT_SUPPORT;
1003 }
1004 #endif
1005 
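// FindBackendKernel tries provider kernels first, then (when enabled) a GPU kernel, then an fp16
// CPU kernel, and finally falls back to an fp32 CPU kernel; a failed attempt may re-run infer-shape
// to rebuild the OpParameter before the next candidate is tried.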
1006 kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
1007                                                  const std::vector<Tensor *> &out_tensors, const LiteGraph::Node *node,
1008                                                  TypeId prefer_data_type) {
1009   MS_ASSERT(node != nullptr);
1010   // Weight-quantized nodes are scheduled as fp32; otherwise take the first fp32/fp16/int8 input type.
1011   TypeId data_type =
1012     (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
1013   kernel::LiteKernel *kernel = nullptr;
1014   int status;
1015 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
1016   status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
1017   if (status == RET_OK && kernel != nullptr) {
1018     return kernel;
1019   }
1020 #endif
1021   MS_ASSERT(!node->output_indices_.empty());
1022   OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1023   if (op_parameter == nullptr) {
1024     MS_LOG(ERROR) << "Cannot find OpParameter! type: " << GetPrimitiveTypeName(node->primitive_, schema_version_);
1025     return nullptr;
1026   }
1027 
1028 #ifdef WEIGHT_DECODE_CLIP
1029   if ((op_parameter->quant_type_ == schema::QuantType_QUANT_WEIGHT) ||
1030       (node->quant_type_ == schema::QuantType_QUANT_WEIGHT)) {
1031     MS_LOG(ERROR) << unsupport_weight_decode_log;
1032     return nullptr;
1033   }
1034 #endif
1035 
1036 #if (defined GPU_OPENCL) || (defined ENABLE_FP16)
1037   int kernel_thread_count = op_parameter->thread_num_;
1038 #endif
1039   op_parameter->is_train_session_ = is_train_session_;
1040   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
1041 
1042 #ifdef GPU_OPENCL
1043   bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
1044   bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
1045   if (gpu_priority && use_gpu_kernel) {
1046     status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel, prefer_data_type);
1047     if (status == RET_OK) {
1048       return kernel;
1049     } else {
1050       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1051                     << node->name_;
1052       if (status == RET_ERROR) {
1053         op_parameters_.erase(node->output_indices_.at(0));
1054         auto ret = InferNodeShape(node);
1055         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1056           op_parameter = op_parameters_[node->output_indices_.at(0)];
1057           op_parameter->thread_num_ = kernel_thread_count;
1058         } else {
1059           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1060           return nullptr;
1061         }
1062       }
1063     }
1064   }
1065 #endif
1066 #ifdef ENABLE_FP16
1067   if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
1068       ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
1069     status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
1070     if (status == RET_OK) {
1071       return kernel;
1072     } else {
1073       MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
1074                     << node->name_;
1075       if (status == RET_ERROR) {
1076         op_parameters_.erase(node->output_indices_.at(0));
1077         auto ret = InferNodeShape(node);
1078         if (ret == RET_INFER_INVALID || ret == RET_OK) {
1079           op_parameter = op_parameters_[node->output_indices_.at(0)];
1080           op_parameter->thread_num_ = kernel_thread_count;
1081         } else {
1082           MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1083           return nullptr;
1084         }
1085       }
1086     }
1087   }
1088 #endif
1089   if (data_type == kNumberTypeFloat16) {
1090     MS_LOG(DEBUG) << "Get fp16 op failed, falling back to fp32 op.";
1091     desc.data_type = kNumberTypeFloat32;
1092   }
1093   status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
1094   if (status == RET_OK) {
1095     return kernel;
1096   } else if (status == RET_ERROR) {
1097     op_parameters_.erase(node->output_indices_.at(0));
1098     auto ret = InferNodeShape(node);
1099     if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
1100       MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
1101     }
1102   } else if (status == RET_NOT_SUPPORT) {
1103     free(op_parameter);
1104   }
1105   return nullptr;
1106 }
1107 
1108 namespace {
1109 kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
1110                                              const std::vector<lite::Tensor *> *in_tensors,
1111                                              const std::vector<lite::Tensor *> *out_tensors, kernel::SubGraphType type,
1112                                              const InnerContext &context, int schema_version) {
1113   if (type == kernel::kApuSubGraph) {
1114     return nullptr;
1115   }
1116   std::vector<Tensor *> input_tensors;
1117   std::vector<Tensor *> output_tensors;
1118   if (in_tensors != nullptr) {
1119     input_tensors = *in_tensors;
1120   } else {
1121     input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels);
1122   }
1123   if (out_tensors != nullptr) {
1124     output_tensors = *out_tensors;
1125   } else {
1126     output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels);
1127   }
1128   auto innerkernel = new (std::nothrow) kernel::InnerKernel(nullptr, input_tensors, output_tensors, &context);
1129   if (innerkernel == nullptr) {
1130     return nullptr;
1131   }
1132   std::vector<kernel::LiteKernel *> input_kernels = kernel::LiteKernelUtil::SubgraphInputNodes(kernels);
1133   std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputNodes(kernels);
1134   kernel::SubGraphKernel *sub_graph = nullptr;
1135   if (type == kernel::kCustomSubGraph) {
1136     sub_graph = CreateCustomSubGraph(std::move(input_kernels), std::move(output_kernels), kernels, innerkernel);
1137   }
1138   if (type == kernel::kGpuFp16SubGraph || type == kernel::kGpuFp32SubGraph) {
1139 #if GPU_OPENCL
1140     sub_graph = new (std::nothrow) kernel::OpenCLSubGraph(input_kernels, output_kernels, kernels, innerkernel);
1141     if (sub_graph == nullptr) {
1142       MS_LOG(ERROR) << "Create OpenCLSubGraph failed";
1143       delete innerkernel;
1144       return nullptr;
1145     }
1146 #else
1147     delete innerkernel;
1148     return nullptr;
1149 #endif
1150   }
1151   if (type == kernel::kCpuFP16SubGraph) {
1152 #ifdef ENABLE_FP16
1153     sub_graph = new (std::nothrow) kernel::CpuFp16SubGraph(input_kernels, output_kernels, kernels, innerkernel);
1154     if (sub_graph == nullptr) {
1155       MS_LOG(ERROR) << "Create FP16 subgraph failed.";
1156       delete innerkernel;
1157       return nullptr;
1158     }
1159     for (auto out_tensor : output_tensors) {
1160       if (out_tensor->data_type() == kNumberTypeFloat32) {
1161         out_tensor->set_data_type(kNumberTypeFloat16);
1162       }
1163     }
1164 #else
1165     delete innerkernel;
1166     MS_LOG(ERROR) << "FP16 subgraph is not supported!";
1167     return nullptr;
1168 #endif
1169   }
1170   if (type == kernel::kCpuFP32SubGraph) {
1171     sub_graph = new (std::nothrow) kernel::CpuFp32SubGraph(input_kernels, output_kernels, kernels, innerkernel);
1172     if (sub_graph == nullptr) {
1173       MS_LOG(ERROR) << "Create FP32 subgraph failed.";
1174       delete innerkernel;
1175       return nullptr;
1176     }
1177   }
1178   if (sub_graph == nullptr) {
1179     MS_LOG(ERROR) << "create sub graph failed.";
1180     return nullptr;
1181   }
1182   sub_graph->set_context(&context);
1183   sub_graph->SetSchemaVersion(schema_version);
1184   return sub_graph;
1185 }
1186 
1187 namespace {
1188 kernel::SubGraphType GetCustomKernelSubGraphType(const kernel::LiteKernel *kernel) {
1189   auto desc = kernel->desc();
1190   if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
1191     if (desc.data_type == kNumberTypeFloat16) {
1192       return kernel::kGpuFp16SubGraph;
1193     }
1194     return kernel::kGpuFp32SubGraph;
1195   }
1196   return kernel::kCustomSubGraph;
1197 }
1198 }  // namespace
1199 
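// GetKernelSubGraphType maps a kernel's arch and data type to a subgraph type (custom/GPU, NPU,
// APU, or CPU fp16/fp32); int32 CPU kernels join the fp16 subgraph only when fp16 is enabled and
// the kernel is not part of a control-flow subgraph.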
1200 kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel, const InnerContext &context,
1201                                            bool is_controlflow = false) {
1202   if (kernel == nullptr) {
1203     return kernel::kNotSubGraph;
1204   }
1205 
1206   auto desc = kernel->desc();
1207   if (desc.provider != kernel::kBuiltin) {
1208     return GetCustomKernelSubGraphType(kernel);
1209   }
1210   if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
1211     if (desc.data_type == kNumberTypeFloat16) {
1212       return kernel::kGpuFp16SubGraph;
1213     } else {
1214       return kernel::kGpuFp32SubGraph;
1215     }
1216   } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) {
1217     return kernel::kNpuSubGraph;
1218   } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) {
1219     return kernel::kApuSubGraph;
1220   } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) {
1221     if (desc.data_type == kNumberTypeFloat16) {
1222       return kernel::kCpuFP16SubGraph;
1223     } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 ||
1224                desc.data_type == kNumberTypeInt64 || desc.data_type == kNumberTypeUInt8 ||
1225                desc.data_type == kNumberTypeBool) {
1226       return kernel::kCpuFP32SubGraph;
1227     } else if (desc.data_type == kNumberTypeInt32) {
1228       if (context.IsCpuFloat16Enabled() && !is_controlflow) {
1229         return kernel::kCpuFP16SubGraph;
1230       } else {
1231         return kernel::kCpuFP32SubGraph;
1232       }
1233     }
1234   }
1235   return kernel::kNotSubGraph;
1236 }
1237 }  // namespace
1238 
1239 kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::LiteGraph::Node *src_node) {
1240   MS_ASSERT(src_model_ != nullptr);
1241   MS_ASSERT(src_node != nullptr);
1242   auto *primitive = src_node->primitive_;
1243   MS_ASSERT(primitive != nullptr);
1244   if (!IsPartialNode(primitive, schema_version_)) {
1245     return nullptr;
1246   }
1247   auto subgraph_index = GetPartialGraphIndex(src_node->primitive_, schema_version_);
1248   auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
1249   if (subgraph_kernel == nullptr) {
1250     MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
1251     return {};
1252   }
1253   subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
1254   return subgraph_kernel;
1255 }
1256 
1257 #ifdef ENABLE_FP16
1258 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
1259   if (!context_->IsCpuFloat16Enabled()) {
1260     *prefer_data_type = kNumberTypeFloat32;
1261     return RET_OK;
1262   }
1263 
1264   auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
1265   for (auto node_index : subgraph->node_indices_) {
1266     auto node = src_model_->graph_.all_nodes_[node_index];
1267     MS_ASSERT(node != nullptr);
1268     MS_ASSERT(!node->output_indices_.empty());
1269     OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
1270     if (op_parameter == nullptr) {
1271       MS_LOG(ERROR) << "Cannot find OpParameter! type: " << GetPrimitiveTypeName(node->primitive_, schema_version_);
1272       return RET_ERROR;
1273     }
1274     kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16,
1275                            static_cast<schema::PrimitiveType>(op_parameter->type_)};
1276     if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
1277       *prefer_data_type = kNumberTypeFloat32;
1278       return RET_OK;
1279     }
1280 
1281     std::vector<Tensor *> inputs;
1282     std::vector<Tensor *> outputs;
1283     FindNodeInoutTensors(*node, &inputs, &outputs);
1284 #ifndef WEIGHT_DECODE_CLIP
1285     if (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
1286       *prefer_data_type = kNumberTypeFloat32;
1287       return RET_OK;
1288     }
1289 #endif
1290     TypeId data_type = GetFirstFp32Fp16OrInt8Type(inputs);
1291     if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
1292       *prefer_data_type = kNumberTypeFloat32;
1293       return RET_OK;
1294     }
1295   }
1296   *prefer_data_type = kNumberTypeFloat16;
1297   return RET_OK;
1298 }
1299 #endif
1300 
1301 std::vector<kernel::LiteKernel *> Scheduler::ScheduleMainSubGraphToKernels() {
1302   std::vector<kernel::LiteKernel *> kernels;
1303   std::vector<lite::Tensor *> in_tensors;
1304   std::vector<lite::Tensor *> out_tensors;
1305   auto ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, &kernels, &in_tensors, &out_tensors);
1306   if (ret != RET_OK) {
1307     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << kMainSubGraphIndex;
1308     return {};
1309   }
1310   return kernels;
1311 }
1312 
1313 kernel::LiteKernel *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
1314   TypeId prefer_data_type = kTypeUnknown;
1315 #ifdef ENABLE_FP16
1316   if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
1317     MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
1318     return nullptr;
1319   }
1320 #endif
1321   std::vector<kernel::LiteKernel *> kernels;
1322   std::vector<lite::Tensor *> in_tensors;
1323   std::vector<lite::Tensor *> out_tensors;
1324   auto ret = ScheduleSubGraphToKernels(subgraph_index, &kernels, &in_tensors, &out_tensors, prefer_data_type);
1325   if (ret != RET_OK) {
1326     MS_LOG(ERROR) << "Schedule subgraph failed, index: " << subgraph_index;
1327     return nullptr;
1328   }
1329   kernel::LiteKernelUtil::FindAllInoutKernels(kernels);
1330   kernel::SubGraphType cur_sub_graph_type = kernel::kCpuFP32SubGraph;
1331   if (!kernels.empty()) {
1332     cur_sub_graph_type = GetKernelSubGraphType(kernels.front(), *context_, true);
1333   }
1334   MS_LOG(INFO) << "cur_sub_graph_type: " << cur_sub_graph_type;
1335   auto subgraph_kernel =
1336     CreateSubGraphKernel(kernels, &in_tensors, &out_tensors, cur_sub_graph_type, *context_, schema_version_);
1337   if (subgraph_kernel == nullptr) {
1338     MS_LOG(ERROR) << "CreateSubGraphKernel failed, cur_sub_graph_type: " << cur_sub_graph_type;
1339     return nullptr;
1340   }
1341   return subgraph_kernel;
1342 }
1343 
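// Schedule one subgraph by index: the main subgraph expands into its kernel list, any other subgraph becomes a
// single named SubGraphKernel that is also recorded in subgraph_index_subgraph_kernel_map_.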
std::vector<kernel::LiteKernel *> Scheduler::ScheduleSubGraphToSubGraphKernels(const int &subgraph_index) {
  if (subgraph_index == kMainSubGraphIndex) {
    return ScheduleMainSubGraphToKernels();
  }
  auto subgraph_kernel = SchedulePartialToSubGraphKernel(subgraph_index);
  if (subgraph_kernel == nullptr) {
    MS_LOG(ERROR) << "SchedulePartialToSubGraphKernel failed, subgraph_index: " << subgraph_index;
    return {};
  }
  subgraph_kernel->set_name("subgraph_" + std::to_string(subgraph_index));
  subgraph_index_subgraph_kernel_map_[subgraph_index] = subgraph_kernel;
  return {subgraph_kernel};
}

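// Select a backend kernel for a single node: gather its in/out tensors, let the execution plan override the
// preferred data type, pick a kernel via FindBackendKernel, clear the node's cached OpParameter entry, and align
// tensor data types with the chosen kernel.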
kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::LiteGraph::Node *src_node, TypeId prefer_data_type) {
  std::vector<Tensor *> inputs;
  std::vector<Tensor *> outputs;
  MS_ASSERT(src_node != nullptr);
  FindNodeInoutTensors(*src_node, &inputs, &outputs);

  ResetByExecutionPlan(src_node->name_, &prefer_data_type);

  auto *kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type);
  op_parameters_[src_node->output_indices_.at(0)] = nullptr;
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_
                  << ", type: " << GetPrimitiveTypeName(src_node->primitive_, schema_version_);
    return nullptr;
  }

  SetKernelTensorDataType(kernel);
  kernel->set_name(src_node->name_);
  return kernel;
}

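// A partial node is treated as part of a control-flow pattern when one of its outputs feeds a call or switch node.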
bool Scheduler::IsControlFlowPattern(const lite::LiteGraph::Node &partial_node) {
  lite::LiteGraph::Node *partial_node_output = nullptr;
  for (auto output_index : partial_node.output_indices_) {
    for (auto &node : src_model_->graph_.all_nodes_) {
      if (IsContain(node->input_indices_, output_index)) {
        partial_node_output = node;
        break;
      }
    }
  }

  return partial_node_output == nullptr ? false
                                        : (IsCallNode(partial_node_output->primitive_, schema_version_) ||
                                           IsSwitchNode(partial_node_output->primitive_, schema_version_));
}

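// Scheduling driver: start from the main subgraph and keep scheduling until the work list of subgraph indices
// (extended by control-flow partial nodes) is drained.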
int Scheduler::ScheduleGraphToKernels(std::vector<kernel::LiteKernel *> *dst_kernels, TypeId prefer_data_type) {
  subgraphs_to_schedule_.push_back(kMainSubGraphIndex);
  while (!subgraphs_to_schedule_.empty()) {
    auto cur_subgraph_index = subgraphs_to_schedule_.front();
    subgraphs_to_schedule_.pop_front();
    auto kernels = ScheduleSubGraphToSubGraphKernels(cur_subgraph_index);
    if (kernels.empty()) {
      MS_LOG(ERROR) << "ScheduleSubGraphToSubGraphKernel failed";
      return RET_ERROR;
    }
    std::copy(kernels.begin(), kernels.end(), std::back_inserter(*dst_kernels));
  }
  return RET_OK;
}

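// Schedule every node of one subgraph into kernels. Partial nodes inside a control-flow pattern become kernels whose
// target subgraphs are queued for later scheduling; other partial nodes are expanded via SchedulePartialToKernel.
// Optionally reports the subgraph's input and output tensors.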
int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels,
                                         std::vector<lite::Tensor *> *in_tensors,
                                         std::vector<lite::Tensor *> *out_tensors, TypeId prefer_data_type) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
  MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
  MS_ASSERT(dst_kernels != nullptr);
  MS_ASSERT(dst_kernels->empty());
  auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
  auto ret = RET_OK;
  for (auto node_index : subgraph->node_indices_) {
    auto node = src_model_->graph_.all_nodes_[node_index];
    MS_ASSERT(node != nullptr);
    auto *primitive = node->primitive_;
    MS_ASSERT(primitive != nullptr);
    kernel::LiteKernel *kernel = nullptr;

    if (IsPartialNode(primitive, schema_version_)) {
      if (IsControlFlowPattern(*node)) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
        kernel = ScheduleNodeToKernel(node, prefer_data_type);
        auto partial_subgraph_index = GetPartialGraphIndex(primitive, schema_version_);
        if (SubGraphHasScheduled(partial_subgraph_index)) {
          partial_kernel_subgraph_index_map_[kernel] = partial_subgraph_index;
          MS_LOG(INFO) << "subgraph has scheduled. ";
        } else {
          SubGraphMarkScheduled(partial_subgraph_index);
          partial_kernel_subgraph_index_map_[kernel] = partial_subgraph_index;
          subgraphs_to_schedule_.push_back(partial_subgraph_index);
        }
#else
        MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
        return RET_ERROR;
#endif
      } else {
        kernel = SchedulePartialToKernel(node);
      }
    } else {
      kernel = ScheduleNodeToKernel(node, prefer_data_type);
    }
    if (kernel == nullptr || ret != RET_OK) {
      MS_LOG(ERROR) << "schedule node return nullptr, name: " << node->name_
                    << ", type: " << GetPrimitiveTypeName(primitive, schema_version_);
      return RET_ERROR;
    }
    kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index)));
    dst_kernels->emplace_back(kernel);
    primitives_.emplace(kernel->kernel(), static_cast<const schema::Primitive *>(primitive));
  }
  if (in_tensors != nullptr) {
    std::transform(subgraph->input_indices_.begin(), subgraph->input_indices_.end(), std::back_inserter(*in_tensors),
                   [&](const uint32_t index) { return this->src_tensors_->at(index); });
  }
  if (out_tensors != nullptr) {
    std::transform(subgraph->output_indices_.begin(), subgraph->output_indices_.end(), std::back_inserter(*out_tensors),
                   [&](const uint32_t index) { return this->src_tensors_->at(index); });
  }
  return RET_OK;
}

namespace {
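// Data types that a kernel may have while still fitting into a CPU fp32 subgraph.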
bool KernelFitCurrentSubGraphCPUFp32(TypeId data_type) {
  return (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat || data_type == kNumberTypeInt8 ||
          data_type == kNumberTypeInt || data_type == kNumberTypeInt32 || data_type == kNumberTypeInt64 ||
          data_type == kNumberTypeUInt8 || data_type == kNumberTypeBool);
}

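// Check whether a kernel can be merged into the subgraph type currently under construction, based on its arch and
// data type.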
bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::LiteKernel &kernel) {
  switch (subgraph_type) {
    case kernel::SubGraphType::kNotSubGraph:
    case kernel::SubGraphType::kApuSubGraph:
      return false;
    case kernel::SubGraphType::kGpuFp16SubGraph:
      if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
        return false;
      }
      return (kernel.desc().data_type != kNumberTypeFloat32);
    case kernel::SubGraphType::kGpuFp32SubGraph:
      if (kernel.desc().arch != kernel::KERNEL_ARCH::kGPU) {
        return false;
      }
      return (kernel.desc().data_type != kNumberTypeFloat16);
    case kernel::SubGraphType::kNpuSubGraph:
      return kernel.desc().arch == kernel::KERNEL_ARCH::kNPU;
    case kernel::SubGraphType::kCpuFP16SubGraph: {
      auto desc = kernel.desc();
      if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
        return false;
      }
      return (desc.data_type == kNumberTypeFloat16 || desc.data_type == kNumberTypeInt32 ||
              desc.data_type == kNumberTypeInt || desc.data_type == kNumberTypeBool);
    }
    case kernel::SubGraphType::kCpuFP32SubGraph: {
      auto desc = kernel.desc();
      if (desc.arch != kernel::KERNEL_ARCH::kCPU) {
        return false;
      }
      return KernelFitCurrentSubGraphCPUFp32(desc.data_type);
    }
    default:
      return false;
  }
}

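// Starting at *cur_index, greedily collect consecutive kernels that fit the first kernel's subgraph type and wrap
// them into one SubGraphKernel; stops before delegate kernels and kernels that are already subgraphs, leaving
// *cur_index at the last kernel consumed.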
kernel::LiteKernel *FindAllSubGraphKernels(const std::vector<kernel::LiteKernel *> &sorted_kernels,
                                           const InnerContext &context, size_t *cur_index, int schema_version) {
  std::vector<kernel::LiteKernel *> sub_kernels;
  sub_kernels.emplace_back(sorted_kernels[*cur_index]);
  auto cur_sub_graph_type = GetKernelSubGraphType(sorted_kernels[*cur_index], context);
  for (*cur_index = *cur_index + 1; *cur_index < sorted_kernels.size(); ++(*cur_index)) {
    auto cur_kernel = sorted_kernels[*cur_index];
    MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph);
    // already a subgraph or a delegate
#ifndef DELEGATE_CLIP
    if (cur_kernel->desc().arch == kernel::kDelegate) {
      --(*cur_index);
      break;
    }
#endif
    if (cur_kernel->subgraph_type() != kernel::kNotSubGraph ||
        !KernelFitCurrentSubGraph(cur_sub_graph_type, *cur_kernel)) {
      --(*cur_index);
      break;
    }
    sub_kernels.emplace_back(cur_kernel);
  }
  return CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type, context, schema_version);
}
}  // namespace

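// Group the topologically sorted kernels into subgraph kernels: delegate kernels and existing subgraphs pass through
// unchanged, the rest are packed via FindAllSubGraphKernels, and every non-delegate subgraph is then initialized.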
int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
                                  std::vector<kernel::LiteKernel *> *dst_kernel,
                                  std::map<const kernel::LiteKernel *, bool> *is_kernel_finish) {
  if (src_kernel.empty()) {
    return RET_OK;
  }

  // construct subgraph
  for (size_t index = 0; index < src_kernel.size(); index++) {
    auto cur_kernel = src_kernel[index];
    MS_ASSERT(cur_kernel != nullptr);
    // Not support APU now
    MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph);
#ifndef DELEGATE_CLIP
    if (cur_kernel->desc().arch == kernel::kDelegate) {
      dst_kernel->emplace_back(cur_kernel);
      continue;
    }
#endif
    // already a subgraph or a delegate
    if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
      dst_kernel->emplace_back(cur_kernel);
      continue;
    }
    auto subgraph = FindAllSubGraphKernels(src_kernel, *context_, &index, schema_version_);
    if (subgraph == nullptr) {
      MS_LOG(ERROR) << "Create SubGraphKernel failed";
      return RET_ERROR;
    }
    dst_kernel->emplace_back(subgraph);
  }
  for (auto *subgraph : *dst_kernel) {
#ifndef DELEGATE_CLIP
    if (subgraph->desc().arch != kernel::kDelegate) {
#endif
      auto ret = subgraph->Init();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Init SubGraph failed: " << ret;
        return ret;
      }
#ifndef DELEGATE_CLIP
    }
#endif
  }
  return RET_OK;
}

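// Return the data type of the first input tensor that is fp32/fp16/int8/int32/bool; string tensors force fp32 and
// tensor lists contribute their element type. Falls back to the first tensor's type (fp32 for tensor lists).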
TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
  for (const auto &tensor : in_tensors) {
    auto dtype = tensor->data_type();
    if (dtype == kObjectTypeString) {
      return kNumberTypeFloat32;
    }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    if (dtype == kObjectTypeTensorType) {
      auto tensor_list = reinterpret_cast<TensorList *>(tensor);
      auto tensor_list_dtype = tensor_list->tensors_data_type();
      if (tensor_list_dtype == kNumberTypeFloat32 || tensor_list_dtype == kNumberTypeFloat16 ||
          tensor_list_dtype == kNumberTypeInt8 || tensor_list_dtype == kNumberTypeInt32 ||
          tensor_list_dtype == kNumberTypeBool) {
        return tensor_list_dtype;
      }
    }
#endif
    if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8 ||
        dtype == kNumberTypeInt32 || dtype == kNumberTypeBool) {
      return dtype;
    }
  }
  MS_ASSERT(!in_tensors.empty());
  return in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
}

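// For CPU kernels, align tensor data types with the kernel: fp16 kernels mark fp32 outputs as fp16, while fp32
// kernels mark non-const fp16 inputs and fp16 outputs as fp32.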
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
  MS_ASSERT(kernel != nullptr);
  if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return;
  }
  if (kernel->desc().data_type == kNumberTypeFloat16) {
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat32) {
        tensor->set_data_type(kNumberTypeFloat16);
      }
    }
  } else if (kernel->desc().data_type == kNumberTypeFloat32) {
    for (auto tensor : kernel->in_tensors()) {
      if (!tensor->IsConst() && tensor->data_type() == kNumberTypeFloat16) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat16) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
  }
}

kernel::SubGraphType Scheduler::PartialSubGraphType(const std::vector<kernel::LiteKernel *> &kernels) {
  if (std::any_of(kernels.begin(), kernels.end(),
                  [](kernel::LiteKernel *item) { return item->desc().data_type == kNumberTypeFloat16; })) {
    return kernel::kCpuFP16SubGraph;
  }
  return kernel::kCpuFP32SubGraph;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
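// Run shape inference for the partial nodes feeding a switch node's true and false branches, skipping partials that
// were already inferred.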
int Scheduler::InferSwitchShape(const lite::LiteGraph::Node *switch_node) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(switch_node != nullptr);
  if (!IsSwitchNode(switch_node->primitive_, schema_version_)) {
    MS_LOG(ERROR) << "Node is not a switch";
    return RET_PARAM_INVALID;
  }
  std::deque<lite::LiteGraph::Node *> partial_cnode_to_infer{};
  auto true_branch_output_index = switch_node->input_indices_.at(kSwitchTrueBranch);
  auto false_branch_output_index = switch_node->input_indices_.at(kSwitchFalseBranch);
  for (auto &node : src_model_->graph_.all_nodes_) {
    if ((IsContain(node->output_indices_, true_branch_output_index) ||
         IsContain(node->output_indices_, false_branch_output_index)) &&
        IsPartialNode(node->primitive_, schema_version_) &&
        partial_cnode_inferred_.find(node) == partial_cnode_inferred_.end()) {
      partial_cnode_inferred_.insert(node);
      partial_cnode_to_infer.push_back(node);
    }
  }

  while (!partial_cnode_to_infer.empty()) {
    auto &node = partial_cnode_to_infer.front();
    partial_cnode_to_infer.pop_front();
    int ret = InferPartialShape(node);
    if (ret != RET_OK) {
      MS_LOG(WARNING) << "partial infer not ok, ret: " << ret;
    }
  }
  return RET_OK;
}

LiteGraph::Node *Scheduler::NodeInputIsSwitch(const lite::LiteGraph::Node *node) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(node != nullptr);
  for (auto &iter : src_model_->graph_.all_nodes_) {
    if (iter->output_indices_ == node->input_indices_) {
      if (IsSwitchNode(iter->primitive_, schema_version_)) {
        return iter;
      } else {
        return nullptr;
      }
    }
  }
  return nullptr;
}

bool Scheduler::SubGraphHasScheduled(const int &index) {
  return scheduled_subgraph_index_.find(index) != scheduled_subgraph_index_.end();
}

void Scheduler::SubGraphMarkScheduled(const int &index) { scheduled_subgraph_index_.insert(index); }

#ifndef CONTROLFLOW_TENSORLIST_CLIP
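// After all subgraphs are scheduled, attach each subgraph kernel to the PartialFusion kernel that references it.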
void Scheduler::SetSubgraphForPartialNode() {
  for (auto &pair : partial_kernel_subgraph_index_map_) {
    auto &partial_kernel = pair.first;
    auto &subgraph_index = pair.second;
    static_cast<kernel::PartialFusionKernel *>(partial_kernel->kernel())
      ->set_subgraph_kernel(subgraph_index_subgraph_kernel_map_.at(subgraph_index));
  }
}
#endif

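// Copy a TensorList's metadata (data type, format, element shape, shape) and duplicate its contained tensors via
// Tensor::CopyTensor (without copying their data).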
void CopyTensorList(TensorList *dst_tensor, TensorList *src_tensor) {
  dst_tensor->set_data_type(src_tensor->data_type());
  dst_tensor->set_format(src_tensor->format());
  dst_tensor->set_element_shape(src_tensor->element_shape());
  dst_tensor->set_shape(src_tensor->shape());
  std::vector<Tensor *> cpy_tensors{};
  for (auto &tensor : src_tensor->tensors()) {
    auto new_tensor = Tensor::CopyTensor(*tensor, false);
    cpy_tensors.push_back(new_tensor);
  }
  dst_tensor->set_tensors(cpy_tensors);
}

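// A kernel list still contains control flow if any kernel's OpParameter is a PartialFusion.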
bool Scheduler::IsControlFlowParttern(const std::vector<kernel::LiteKernel *> &kernels) {
  if (std::any_of(kernels.begin(), kernels.end(), [](kernel::LiteKernel *item) {
        if (item->op_parameter()) {
          return item->op_parameter()->type_ == schema::PrimitiveType_PartialFusion;
        }
        return false;
      })) {
    return true;
  }
  return false;
}

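// Gather every kernel that is not yet part of a subgraph into a new CPU main-graph subgraph kernel (fp16 or fp32,
// per PartialSubGraphType) and place it at the front of the kernel list.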
int Scheduler::ConstructControlFlowMainGraph(std::vector<kernel::LiteKernel *> *kernels) {
  auto back_kernels = *kernels;
  kernels->clear();
  std::vector<kernel::LiteKernel *> main_graph_kernels{};
  for (auto &kernel : back_kernels) {
    if (kernel->subgraph_type() != kernel::kNotSubGraph) {
      kernels->push_back(kernel);
    } else {
      main_graph_kernels.push_back(kernel);
    }
  }
  auto cur_subgraph_type = PartialSubGraphType(main_graph_kernels);
  auto subgraph_kernel =
    CreateSubGraphKernel(main_graph_kernels, nullptr, nullptr, cur_subgraph_type, *context_, schema_version_);
  if (subgraph_kernel == nullptr) {
    MS_LOG(ERROR) << "create main graph for control flow model failed.";
    return RET_ERROR;
  }
  kernels->insert(kernels->begin(), subgraph_kernel);
  return RET_OK;
}
#endif
}  // namespace mindspore::lite