/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/kernel_exec_util.h"
#include <utility>
#include <queue>
#include <unordered_map>
#include <set>
#include "src/executor/sub_graph_kernel.h"
#include "nnacl/call_parameter.h"
#if GPU_OPENCL
#include "src/litert/kernel/opencl/opencl_subgraph.h"
#include "src/litert/kernel/gpu/opencl/opencl_runtime.h"
#endif
#include "src/control_flow/control_subgraph_creator.h"
#include "src/litert/kernel/cpu/base/partial_fusion.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

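// Re-order *nodes into topological (execution) order, Kahn-style: seed a queue with nodes that
// have no predecessor inside the set (or with in_nodes when given), then release a successor once
// all of its in-set predecessors are sorted. Fails on nullptr entries or cycles.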
int KernelExecUtil::TopologicalSortNodes(std::vector<KernelExec *> *nodes, std::vector<KernelExec *> in_nodes) {
  auto old_nodes = *nodes;
  if (in_nodes.empty()) {
    in_nodes = KernelExecUtil::SubgraphInputNodes(old_nodes);
  }
  nodes->clear();
  nodes->reserve(old_nodes.size());
  std::queue<KernelExec *> kernel_queue;
  for (auto kernel : in_nodes) {
    if (std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(),
                    [&](KernelExec *in_kernel) { return (!lite::IsContain(old_nodes, in_kernel)); })) {
      kernel_queue.push(kernel);
    }
  }

  while (!kernel_queue.empty()) {
    auto cur_kernel = kernel_queue.front();
    (void)nodes->emplace_back(cur_kernel);
    kernel_queue.pop();
    if (cur_kernel == nullptr) {
      MS_LOG(ERROR) << "TopologicalSortKernels failed, nullptr in nodes.";
      return lite::RET_NULL_PTR;
    }
    auto next_kernels = cur_kernel->out_kernels();
    for (auto next_kernel : next_kernels) {
      if (!lite::IsContain(old_nodes, next_kernel)) {
        continue;
      }
      if (lite::IsContain(*nodes, next_kernel)) {
        MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exists.";
        return lite::RET_ERROR;
      }
      auto in_kernels = next_kernel->in_kernels();
      if (std::all_of(in_kernels.begin(), in_kernels.end(), [&](KernelExec *in_kernel) {
            return lite::IsContain(*nodes, in_kernel) || (!lite::IsContain(old_nodes, in_kernel));
          })) {
        kernel_queue.push(next_kernel);
      }
    }
  }
  if (nodes->size() != old_nodes.size()) {
    MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_nodes.size()
                  << ", kernels size after sort: " << nodes->size();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

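// Collect every tensor produced by any kernel in the given set.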
std::set<lite::Tensor *> KernelExecUtil::AllOutTensor(const std::vector<KernelExec *> &kernels) {
  std::set<lite::Tensor *> all_out_tensors{};
  for (const auto &kernel_in_subgraph : kernels) {
    for (auto *tensor : kernel_in_subgraph->out_tensors()) {
      (void)all_out_tensors.insert(tensor);
    }
  }
  return all_out_tensors;
}

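// A kernel is a subgraph input node if at least one of its non-const input tensors is not
// produced by any kernel in the set.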
std::vector<KernelExec *> KernelExecUtil::SubgraphInputNodes(const std::vector<KernelExec *> &kernels) {
  std::vector<KernelExec *> input_nodes;
  std::set<lite::Tensor *> all_out_tensors = AllOutTensor(kernels);
  for (const auto &kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    bool kernel_is_input = false;
    for (auto input : kernel->in_tensors()) {
      if (input->IsConst()) {
        continue;
      }
      if (all_out_tensors.find(input) != all_out_tensors.end()) {
        continue;
      }
      kernel_is_input = true;
      break;
    }
    if (kernel_is_input && !lite::IsContain(input_nodes, kernel)) {
      input_nodes.push_back(kernel);
    }
  }
  return input_nodes;
}

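// A kernel is a subgraph output node if it is a model output, has no post kernel at all, or
// feeds at least one kernel outside the set.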
std::vector<KernelExec *> KernelExecUtil::SubgraphOutputNodes(const std::vector<KernelExec *> &kernels) {
  std::set<KernelExec *> all_kernels{};
  for (const auto &kernel : kernels) {
    (void)all_kernels.insert(kernel);
  }
  std::vector<KernelExec *> output_nodes;
  // if a kernel has no post kernel, it is a graph output and must be a subgraph output
  for (const auto &kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    if (kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) {
      if (!lite::IsContain(output_nodes, kernel)) {
        output_nodes.push_back(kernel);
      }
      continue;
    }
    if (std::any_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
                    [&all_kernels](KernelExec *tmp) { return all_kernels.find(tmp) == all_kernels.end(); }) &&
        !lite::IsContain(output_nodes, kernel)) {
      output_nodes.push_back(kernel);
    }
  }
  return output_nodes;
}

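// Gather the input tensors of a subgraph: graph inputs, producer-less non-const tensors, and
// tensors produced by kernels outside the set.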
std::vector<lite::Tensor *> KernelExecUtil::SubgraphInputTensors(const std::vector<KernelExec *> &kernels) {
  std::vector<lite::Tensor *> input_tensors;
  std::vector<KernelExec *> input_nodes = SubgraphInputNodes(kernels);
  for (const auto &input_node : input_nodes) {
    auto &in_node_in_kernels = input_node->in_kernels();
    auto &in_node_in_tensors = input_node->in_tensors();
    for (auto &in_node_in_tensor : in_node_in_tensors) {
      if (in_node_in_tensor->IsGraphInput() || (in_node_in_kernels.empty() && !in_node_in_tensor->IsConst())) {
        if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
          input_tensors.push_back(in_node_in_tensor);
        }
      }
    }
    for (auto in_node_in_kernel : in_node_in_kernels) {
      auto iter = std::find(kernels.begin(), kernels.end(), in_node_in_kernel);
      if (iter != kernels.end()) {
        continue;
      }
      auto &outer_in_kernel_out_tensors = in_node_in_kernel->out_tensors();
      for (auto in_node_in_tensor : in_node_in_tensors) {
        auto outer_in_kernel_out_tensors_iter =
          std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_node_in_tensor);
        if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
          if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
            input_tensors.push_back(in_node_in_tensor);
          }
        }
      }
    }
  }
  return input_tensors;
}

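// Gather the output tensors of a subgraph: graph outputs, tensors of nodes without consumers,
// and tensors consumed by kernels outside the set.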
std::vector<lite::Tensor *> KernelExecUtil::SubgraphOutputTensors(const std::vector<KernelExec *> &kernels) {
  std::vector<lite::Tensor *> output_tensors;
  std::vector<KernelExec *> output_nodes = SubgraphOutputNodes(kernels);
  for (const auto &output_kernel : output_nodes) {
    auto &outer_out_kernels = output_kernel->out_kernels();
    auto &out_kernel_out_tensors = output_kernel->out_tensors();
    for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
      if ((out_kernel_out_tensor->IsGraphOutput() || outer_out_kernels.empty()) &&
          !lite::IsContain(output_tensors, out_kernel_out_tensor)) {
        output_tensors.push_back(out_kernel_out_tensor);
      }
    }
    if (!outer_out_kernels.empty()) {
      for (auto outer_out_kernel : outer_out_kernels) {
        auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
        if (iter != kernels.end()) {
          continue;
        }
        auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
        for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
          auto outer_out_kernel_in_tensors_iter =
            std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
          if ((outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) &&
              !lite::IsContain(output_tensors, out_kernel_out_tensor)) {
            output_tensors.push_back(out_kernel_out_tensor);
          }
        }
      }
    }
  }
  return output_tensors;
}

void KernelExecUtil::InitTensorInitRefCount(const std::vector<KernelExec *> &kernels) {
  for (auto *kernel : kernels) {
    kernel->InitOutTensorInitRefCount(&kernels);
  }
}

KernelExec *KernelExecUtil::GetInputsSpecificNode(const KernelExec *kernel,
                                                  const schema::PrimitiveType &primitive_type) {
  for (auto input : kernel->in_kernels()) {
    if (input->type() == primitive_type) {
      return input;
    }
  }
  return nullptr;
}

bool KernelExecUtil::InputsContainsSpecificNode(const KernelExec *kernel,
                                                const schema::PrimitiveType &primitive_type) {
  return GetInputsSpecificNode(kernel, primitive_type) != nullptr;
}

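// Rebuild every kernel's in_kernels/out_kernels links from the tensor producer/consumer
// relations of the given kernel set.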
void KernelExecUtil::FindAllInoutKernels(const std::vector<KernelExec *> &kernels) {
  std::unordered_map<lite::Tensor *, KernelExec *> tensor_pre_kernel;
  std::unordered_map<lite::Tensor *, std::vector<KernelExec *>> tensor_post_kernels;
  for (auto *kernel : kernels) {
    for (auto *tensor : kernel->out_tensors()) {
      tensor_pre_kernel[tensor] = kernel;
    }
    for (auto *tensor : kernel->in_tensors()) {
      (tensor_post_kernels[tensor]).push_back(kernel);
    }
  }

  for (auto *kernel : kernels) {
    kernel->set_in_kernels({});
    for (auto *tensor : kernel->in_tensors()) {
      auto iter = tensor_pre_kernel.find(tensor);
      if (iter != tensor_pre_kernel.end() && kernel != iter->second) {
        kernel->AddInKernel(iter->second);
      }
    }
    kernel->set_out_kernels({});
    for (auto *tensor : kernel->out_tensors()) {
      auto iter = tensor_post_kernels.find(tensor);
      if (iter != tensor_post_kernels.end()) {
        for (auto *find_kernel : iter->second) {
          if (kernel == find_kernel) {
            continue;
          }
          kernel->AddOutKernel(find_kernel);
        }
      }
    }
  }
}

void KernelExecUtil::FindAllInoutKernelsInSubgraphKernel(const std::vector<KernelExec *> &kernels) {
  std::vector<KernelExec *> all_kernels;
  for (auto kernel : kernels) {
    if (kernel->desc().arch == kDelegate) {
      all_kernels.push_back(kernel);
      continue;
    }
    auto sub_graph = reinterpret_cast<SubGraphKernel *>(kernel);
    MS_ASSERT(sub_graph != nullptr);
    auto kernel_in_subgraph = sub_graph->nodes();
    (void)all_kernels.insert(all_kernels.end(), kernel_in_subgraph.begin(), kernel_in_subgraph.end());
  }

  KernelExecUtil::FindAllInoutKernels(all_kernels);
}

KernelExec *KernelExecUtil::FindInKernelForInTensor(const KernelExec *kernel, lite::Tensor *tensor) {
  for (auto in_kernel : kernel->in_kernels()) {
    if (lite::IsContain(in_kernel->out_tensors(), tensor)) {
      return in_kernel;
    }
  }
  return nullptr;
}

std::vector<KernelExec *> KernelExecUtil::FindOutKernelsForOutTensor(const KernelExec *kernel, lite::Tensor *tensor) {
  MS_CHECK_TRUE_RET(kernel != nullptr, {});
  std::vector<KernelExec *> out_kernels;
  for (auto out_kernel : kernel->out_kernels()) {
    if (lite::IsContain(out_kernel->in_tensors(), tensor)) {
      out_kernels.push_back(out_kernel);
    }
  }
  return out_kernels;
}

KernelExec *KernelExecUtil::FindInKernelForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph) {
  MS_CHECK_TRUE_RET(graph != nullptr, nullptr);
  auto iter = std::find_if(graph->nodes().begin(), graph->nodes().end(),
                           [&tensor](const auto &node) { return lite::IsContain(node->out_tensors(), tensor); });
  if (iter != graph->nodes().end()) {
    return *iter;
  }
  return nullptr;
}

std::vector<KernelExec *> KernelExecUtil::FindOutKernelsForTensorInSubGraph(lite::Tensor *tensor,
                                                                            SubGraphKernel *graph) {
  MS_CHECK_TRUE_RET(graph != nullptr, {});
  std::vector<KernelExec *> out_kernels(graph->nodes().size());
  auto iter = std::copy_if(graph->nodes().begin(), graph->nodes().end(), out_kernels.begin(),
                           [&tensor](const auto &node) { return lite::IsContain(node->in_tensors(), tensor); });
  out_kernels.erase(iter, out_kernels.end());
  return out_kernels;
}

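// For CPU kernels, align tensor data types with the kernel's data type: an fp16 kernel marks
// fp32 outputs as fp16; an fp32 kernel marks fp16 non-const inputs and fp16 outputs (except the
// outputs of Cast) as fp32.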
int KernelExecUtil::SetKernelTensorDataType(const kernel::KernelExec *kernel) {
  CHECK_NULL_RETURN(kernel);
  if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return RET_OK;
  }
  if (kernel->desc().data_type == kNumberTypeFloat16) {
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat32) {
        tensor->set_data_type(kNumberTypeFloat16);
      }
    }
  } else if (kernel->desc().data_type == kNumberTypeFloat32) {
    for (auto tensor : kernel->in_tensors()) {
      if (!tensor->IsConst() && tensor->data_type() == kNumberTypeFloat16) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat16 && kernel->type() != schema::PrimitiveType_Cast) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
  }
  return RET_OK;
}

bool KernelExecUtil::IsOutputSubGraph(const KernelExec *subgraph_kernel) {
  MS_CHECK_TRUE_RET(subgraph_kernel != nullptr, false);
  return !subgraph_kernel->out_tensors().empty() &&
         std::all_of(subgraph_kernel->out_tensors().begin(), subgraph_kernel->out_tensors().end(),
                     [](lite::Tensor *tensor) { return tensor->IsGraphOutput(); });
}

namespace {
SubGraphKernel *CreateCustomSubGraph(std::vector<KernelExec *> &&input_kernels,
                                     std::vector<KernelExec *> &&output_kernels,
                                     const std::vector<KernelExec *> &kernels, MSKernel *kernel) {
  auto sub_kernel = new (std::nothrow) CustomSubGraph(input_kernels, output_kernels, kernels, kernel);
  if (sub_kernel == nullptr) {
    MS_LOG(ERROR) << "create custom subgraph failed!";
    return nullptr;
  }
  return sub_kernel;
}
}  // namespace

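// Wrap a set of kernels into a SubGraphKernel of the requested type. Input/output tensors are
// derived from the kernel set when the caller does not supply them.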
SubGraphKernel *KernelExecUtil::CreateSubGraphKernel(const std::vector<KernelExec *> &kernels,
                                                     const std::vector<lite::Tensor *> *in_tensors,
                                                     const std::vector<lite::Tensor *> *out_tensors, SubGraphType type,
                                                     const lite::InnerContext &context, int schema_version) {
  std::vector<lite::Tensor *> input_tensors;
  std::vector<lite::Tensor *> output_tensors;
  if (in_tensors != nullptr) {
    input_tensors = *in_tensors;
  } else {
    input_tensors = SubgraphInputTensors(kernels);
  }
  if (out_tensors != nullptr) {
    output_tensors = *out_tensors;
  } else {
    output_tensors = SubgraphOutputTensors(kernels);
  }
  auto lite_kernel = new (std::nothrow) LiteKernel(nullptr, input_tensors, output_tensors, &context);
  if (lite_kernel == nullptr) {
    MS_LOG(ERROR) << "Create subgraph lite-kernel failed.";
    return nullptr;
  }
  std::vector<KernelExec *> input_kernels = SubgraphInputNodes(kernels);
  std::vector<KernelExec *> output_kernels = SubgraphOutputNodes(kernels);
  SubGraphKernel *sub_graph = nullptr;
  switch (type) {
    case kCpuFP32SubGraph: {
      sub_graph = new (std::nothrow) CpuFp32SubGraph(input_kernels, output_kernels, kernels, lite_kernel);
    } break;
    case kCpuFP16SubGraph: {
#ifdef ENABLE_FP16
      sub_graph = new (std::nothrow) CpuFp16SubGraph(input_kernels, output_kernels, kernels, lite_kernel);
      for (auto out_tensor : output_tensors) {
        if (out_tensor->data_type() == kNumberTypeFloat32) {
          out_tensor->set_data_type(kNumberTypeFloat16);
        }
      }
#endif
    } break;
    case kGpuFp32SubGraph:
    case kGpuFp16SubGraph: {
#if GPU_OPENCL
      sub_graph = new (std::nothrow) OpenCLSubGraph(input_kernels, output_kernels, kernels, lite_kernel);
#endif
    } break;
    case kCustomSubGraph: {
      sub_graph = CreateCustomSubGraph(std::move(input_kernels), std::move(output_kernels), kernels, lite_kernel);
    } break;
    case kEntranceSubGraph:
    case kExitSubGraph: {
      sub_graph = lite::CreateControlSubgraph(type, lite_kernel);
    } break;
    case kAclSubGraph: {
      sub_graph = new (std::nothrow) AclSubGraph(input_kernels, output_kernels, kernels, lite_kernel);
    } break;
    default: {
      MS_LOG(ERROR) << "unsupported subgraph type: " << type;
      delete lite_kernel;
      return nullptr;
    }
  }
  if (sub_graph == nullptr) {
    delete lite_kernel;
    MS_LOG(ERROR) << "create subgraph type " << type << " failed.";
    return nullptr;
  }
  sub_graph->set_context(&context);
  sub_graph->SetSchemaVersion(schema_version);
  return sub_graph;
}

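// Replace old_tensor with new_tensor on the inputs of the subgraph's input nodes, and set
// new_tensor's initial ref count to the number of replaced uses (a delegate kernel counts as
// one use).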
int KernelExecUtil::ReplaceSubGraphNodesInTensor(KernelExec *kernel, const lite::Tensor *old_tensor,
                                                 lite::Tensor *new_tensor) {
  CHECK_NULL_RETURN(kernel);
  int ref_count = 0;
  /* set op input for calculation */
  if (kernel->desc().arch == kDelegate) {
    ref_count++;
  } else {
    auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to subgraph kernel failed.";
      return RET_ERROR;
    }
    for (auto in_node : subgraph_kernel->in_nodes()) {
      for (size_t node_in_index = 0; node_in_index < in_node->in_tensors().size(); node_in_index++) {
        if (old_tensor == in_node->in_tensors()[node_in_index]) {
          in_node->set_in_tensor(new_tensor, node_in_index);
          ref_count++;
        }
      }
    }
  }
  CHECK_NULL_RETURN(new_tensor);
  new_tensor->set_init_ref_count(ref_count);
  return RET_OK;
}

int KernelExecUtil::ReplaceSubGraphNodesOutTensor(KernelExec *kernel, const lite::Tensor *old_tensor,
                                                  lite::Tensor *new_tensor) {
  CHECK_NULL_RETURN(kernel);
  int ref_count = 0;
  /* set op output for calculation */
  if (kernel->desc().arch == kDelegate) {
    ref_count++;
  } else {
    auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to subgraph kernel failed.";
      return RET_ERROR;
    }
    for (auto out_node : subgraph_kernel->out_nodes()) {
      for (size_t node_out_index = 0; node_out_index < out_node->out_tensors().size(); node_out_index++) {
        if (old_tensor == out_node->out_tensors()[node_out_index]) {
          out_node->set_out_tensor(new_tensor, node_out_index);
          ref_count++;
        }
      }
    }
  }
  CHECK_NULL_RETURN(new_tensor);
  new_tensor->set_init_ref_count(ref_count);
  return RET_OK;
}

SubGraphKernel *KernelExecUtil::BelongToWhichSubGraph(const std::vector<KernelExec *> &subgraphs, KernelExec *kernel) {
  for (auto &item : subgraphs) {
    if (item->subgraph_type() == kernel::kNotSubGraph) {
      continue;
    }
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(item);
    if (subgraph == nullptr) {
      continue;
    }
    if (std::any_of(subgraph->nodes().begin(), subgraph->nodes().end(),
                    [&kernel](const KernelExec *node) { return node == kernel; })) {
      return subgraph;
    }
  }
  return nullptr;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
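// Returns true if the subgraph contains a Switch/SwitchLayer node that is fed by a
// PartialFusion node and whose single consumer is a Call node.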
bool KernelExecUtil::IsSwitchTypeCall(KernelExec *kernel) {
  if (kernel == nullptr) {
    return false;
  }
  if (kernel->desc().arch == kDelegate) {
    return false;
  }
  auto *subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if ((node->type() == schema::PrimitiveType_Switch || node->type() == schema::PrimitiveType_SwitchLayer) &&
        InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
        node->out_kernels().front()->type() == schema::PrimitiveType_Call) {
      return true;
    }
  }

  return false;
}

bool KernelExecUtil::IsNonTailCall(const KernelExec *node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr";
    return false;
  }
  auto parameter = reinterpret_cast<CallParameter *>(node->op_parameter());
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "Parameter is nullptr";
    return false;
  }
  return node->type() == schema::PrimitiveType_Call && !(parameter->is_tail_call);
}

bool KernelExecUtil::IsTailCall(const KernelExec *node) {
  return node->type() == schema::PrimitiveType_Call &&
         (reinterpret_cast<CallParameter *>(node->op_parameter())->is_tail_call);
}

bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) {
  auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  auto nodes = subgraph_kernel->nodes();
  return std::any_of(nodes.begin(), nodes.end(),
                     [](const KernelExec *node) { return KernelExecUtil::IsNonTailCall(node); });
}

bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) {
  auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  if (IsNonTailCallSubGraph(subgraph_kernel)) {
    return false;
  }
  auto output_nodes = subgraph_kernel->out_nodes();
  if (std::any_of(output_nodes.begin(), output_nodes.end(), [](const KernelExec *node) { return IsTailCall(node); })) {
    return true;
  }
  return false;
}

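// Resolve the PartialFusion nodes feeding a Call node, either directly or through a
// Switch/SwitchLayer node.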
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) {
  if (call_node == nullptr) {
    return {};
  }
  if (call_node->type() != schema::PrimitiveType_Call) {
    MS_LOG(ERROR) << "input node is not a call node.";
    return {};
  }
  auto call_inputs = call_node->in_kernels();
  if (call_inputs.size() != 1) {
    MS_LOG(ERROR) << "call inputs size is: " << call_inputs.size() << ", expected 1.";
    return {};
  }

  std::vector<KernelExec *> partial_nodes{};
  auto call_input_node = call_inputs.front();
  switch (SchemaType(call_input_node->type())) {
    case schema::PrimitiveType_PartialFusion: {
      partial_nodes.push_back(call_input_node);
      break;
    }
    case schema::PrimitiveType_Switch:
    case schema::PrimitiveType_SwitchLayer: {
      auto switch_type_node = call_input_node;
      for (auto item : switch_type_node->in_kernels()) {
        if (item->type() == schema::PrimitiveType_PartialFusion) {
          partial_nodes.push_back(item);
        }
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "unsupported call input type: " << call_input_node->type();
      return {};
    }
  }
  return partial_nodes;
}

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node) {
  auto partial_nodes = GetCallInputPartials(call_node);
  std::vector<KernelExec *> all_subgraphs{};
  for (auto partial_node : partial_nodes) {
    auto partial_kernel = reinterpret_cast<PartialFusionKernel *>(partial_node->kernel());
    if (partial_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to partial kernel failed.";
      return all_subgraphs;
    }
    // only take the output subgraph; the last subgraph is the output subgraph.
    auto partial_subgraphs = partial_kernel->subgraph_kernels();
    all_subgraphs.push_back(partial_subgraphs.back());
    // the exit graph's input graph also needs the same output-tensor init ref count.
    if (partial_subgraphs.size() > 1 && partial_subgraphs.back()->subgraph_type() == kExitSubGraph) {
      auto last_index = partial_subgraphs.size() - 1;
      all_subgraphs.push_back(partial_subgraphs[last_index - 1]);
    }
  }
  return all_subgraphs;
}

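// Find the Call node consuming a PartialFusion node's output, either directly or through a
// Switch/SwitchLayer node.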
KernelExec *KernelExecUtil::GetPartialOutputCall(const KernelExec *partial_node) {
  if (partial_node == nullptr) {
    return nullptr;
  }
  if (partial_node->type() != schema::PrimitiveType_PartialFusion) {
    MS_LOG(ERROR) << "input node is not a partial node.";
    return nullptr;
  }
  auto partial_outputs = partial_node->out_kernels();
  if (partial_outputs.size() != 1) {
    MS_LOG(ERROR) << "partial outputs size is: " << partial_outputs.size() << ", expected 1.";
    return nullptr;
  }

  KernelExec *call_node = nullptr;
  auto partial_output_node = partial_outputs.front();
  switch (SchemaType(partial_output_node->type())) {
    case schema::PrimitiveType_Call: {
      call_node = partial_output_node;
      break;
    }
    case schema::PrimitiveType_Switch:
    case schema::PrimitiveType_SwitchLayer: {
      auto switch_type_node = partial_output_node;
      auto switch_outputs = switch_type_node->out_kernels();
      if (switch_outputs.size() != 1) {
        MS_LOG(ERROR) << "switch outputs size is: " << switch_outputs.size() << ", expected 1.";
        return nullptr;
      }
      if (switch_outputs.front()->type() == schema::PrimitiveType_Call) {
        call_node = switch_outputs.front();
      } else {
        MS_LOG(ERROR) << "graph is invalid, switch output is not a call node.";
        return nullptr;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "unsupported partial output type: " << partial_output_node->type();
      return nullptr;
    }
  }
  return call_node;
}

#else

bool KernelExecUtil::IsSwitchTypeCall(KernelExec *kernel) { return false; }

bool KernelExecUtil::IsNonTailCall(const KernelExec *node) { return false; }

bool KernelExecUtil::IsTailCall(const KernelExec *node) { return false; }

bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) { return false; }

bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) { return false; }

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) { return {}; }

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node) {
  return {};
}

KernelExec *KernelExecUtil::GetPartialOutputCall(const KernelExec *partial_node) { return nullptr; }

#endif
}  // namespace mindspore::kernel