/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/kernel_exec_util.h"
#include <utility>
#include <queue>
#include <unordered_map>
#include <set>
#include "src/executor/sub_graph_kernel.h"
#include "nnacl/call_parameter.h"
#if GPU_OPENCL
#include "src/litert/kernel/opencl/opencl_subgraph.h"
#include "src/litert/kernel/gpu/opencl/opencl_runtime.h"
#endif
#include "src/control_flow/control_subgraph_creator.h"
#include "src/litert/kernel/cpu/base/partial_fusion.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

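// Sorts the kernels in *nodes into topological order with a BFS (Kahn-style) traversal.
// Starts from in_nodes (or from SubgraphInputNodes(*nodes) when in_nodes is empty) and only
// follows edges between kernels that belong to *nodes. Returns an error if a loop is detected
// or if not every kernel could be scheduled.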
int KernelExecUtil::TopologicalSortNodes(std::vector<KernelExec *> *nodes, std::vector<KernelExec *> in_nodes) {
  auto old_nodes = *nodes;
  if (in_nodes.empty()) {
    in_nodes = KernelExecUtil::SubgraphInputNodes(old_nodes);
  }
  nodes->clear();
  nodes->reserve(old_nodes.size());
  std::queue<KernelExec *> kernel_queue;
  for (auto kernel : in_nodes) {
    if (std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(),
                    [&](KernelExec *in_kernel) { return (!lite::IsContain(old_nodes, in_kernel)); })) {
      kernel_queue.push(kernel);
    }
  }

  while (!kernel_queue.empty()) {
    auto cur_kernel = kernel_queue.front();
    (void)nodes->emplace_back(cur_kernel);
    kernel_queue.pop();
    if (cur_kernel == nullptr) {
      MS_LOG(ERROR) << "TopologicalSortKernels failed, nullptr in nodes.";
      return lite::RET_NULL_PTR;
    }
    auto next_kernels = cur_kernel->out_kernels();
    for (auto next_kernel : next_kernels) {
      if (!lite::IsContain(old_nodes, next_kernel)) {
        continue;
      }
      if (lite::IsContain(*nodes, next_kernel)) {
        MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exists.";
        return lite::RET_ERROR;
      }
      auto in_kernels = next_kernel->in_kernels();
      if (std::all_of(in_kernels.begin(), in_kernels.end(), [&](KernelExec *in_kernel) {
            return lite::IsContain(*nodes, in_kernel) || (!lite::IsContain(old_nodes, in_kernel));
          })) {
        kernel_queue.push(next_kernel);
      }
    }
  }
  if (nodes->size() != old_nodes.size()) {
    MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_nodes.size()
                  << ", kernels size after sort: " << nodes->size();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

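// Collects every tensor that is produced as an output by any kernel in the given list.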
std::set<lite::Tensor *> KernelExecUtil::AllOutTensor(const std::vector<KernelExec *> &kernels) {
  std::set<lite::Tensor *> all_out_tensors{};
  for (const auto &kernel_in_subgraph : kernels) {
    for (auto *tensor : kernel_in_subgraph->out_tensors()) {
      (void)all_out_tensors.insert(tensor);
    }
  }
  return all_out_tensors;
}

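// A kernel is a subgraph input node if at least one of its non-const input tensors is not
// produced by any kernel inside the subgraph.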
std::vector<KernelExec *> KernelExecUtil::SubgraphInputNodes(const std::vector<KernelExec *> &kernels) {
  std::vector<KernelExec *> input_nodes;
  std::set<lite::Tensor *> all_out_tensors = AllOutTensor(kernels);
  for (const auto &kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    bool kernel_is_input = false;
    auto all_input_tensors = kernel->in_tensors();
    for (auto input : kernel->in_tensors()) {
      if (input->IsConst()) {
        continue;
      }
      if (all_out_tensors.find(input) != all_out_tensors.end()) {
        continue;
      }
      kernel_is_input = true;
      break;
    }
    if (kernel_is_input && !lite::IsContain(input_nodes, kernel)) {
      input_nodes.push_back(kernel);
    }
  }
  return input_nodes;
}

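// A kernel is a subgraph output node if it is a model output, has no post-kernel while still
// producing output tensors, or feeds at least one kernel outside the subgraph.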
std::vector<KernelExec *> KernelExecUtil::SubgraphOutputNodes(const std::vector<KernelExec *> &kernels) {
  std::set<KernelExec *> all_kernels{};
  for (const auto &kernel : kernels) {
    (void)all_kernels.insert(kernel);
  }
  std::vector<KernelExec *> output_nodes;
  // If a kernel has no post-kernel, it is a graph output and therefore must be a subgraph output.
  for (const auto &kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    if (kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) {
      if (!lite::IsContain(output_nodes, kernel)) {
        output_nodes.push_back(kernel);
      }
      continue;
    }
    if (std::any_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
                    [&all_kernels](KernelExec *tmp) { return all_kernels.find(tmp) == all_kernels.end(); }) &&
        !lite::IsContain(output_nodes, kernel)) {
      output_nodes.push_back(kernel);
    }
  }
  return output_nodes;
}

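// Gathers the tensors the subgraph consumes from the outside: graph inputs (or non-const tensors
// without a producer) of its input nodes, plus tensors produced by kernels outside the subgraph.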
std::vector<lite::Tensor *> KernelExecUtil::SubgraphInputTensors(const std::vector<KernelExec *> &kernels) {
  std::vector<lite::Tensor *> input_tensors;
  std::vector<KernelExec *> input_nodes = SubgraphInputNodes(kernels);
  for (const auto &input_node : input_nodes) {
    auto &in_node_in_kernels = input_node->in_kernels();
    auto &in_node_in_tensors = input_node->in_tensors();
    for (auto &in_node_in_tensor : in_node_in_tensors) {
      if (in_node_in_tensor->IsGraphInput() || (in_node_in_kernels.empty() && !in_node_in_tensor->IsConst())) {
        if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
          input_tensors.push_back(in_node_in_tensor);
        }
      }
    }
    for (auto in_node_in_kernel : in_node_in_kernels) {
      auto iter = std::find(kernels.begin(), kernels.end(), in_node_in_kernel);
      if (iter != kernels.end()) {
        continue;
      }
      auto &outer_in_kernel_out_tensors = in_node_in_kernel->out_tensors();
      for (auto in_node_in_tensor : in_node_in_tensors) {
        auto outer_in_kernel_out_tensors_iter =
          std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_node_in_tensor);
        if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
          if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
            input_tensors.push_back(in_node_in_tensor);
          }
        }
      }
    }
  }
  return input_tensors;
}

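// Gathers the tensors the subgraph exposes to the outside: graph outputs of its output nodes,
// plus tensors consumed by kernels that are not part of the subgraph.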
std::vector<lite::Tensor *> KernelExecUtil::SubgraphOutputTensors(const std::vector<KernelExec *> &kernels) {
  std::vector<lite::Tensor *> output_tensors;
  std::vector<KernelExec *> output_nodes = SubgraphOutputNodes(kernels);
  for (const auto &output_kernel : output_nodes) {
    auto &outer_out_kernels = output_kernel->out_kernels();
    auto &out_kernel_out_tensors = output_kernel->out_tensors();
    for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
      if ((out_kernel_out_tensor->IsGraphOutput() || outer_out_kernels.empty()) &&
          !lite::IsContain(output_tensors, out_kernel_out_tensor)) {
        output_tensors.push_back(out_kernel_out_tensor);
      }
    }
    if (!outer_out_kernels.empty()) {
      for (auto outer_out_kernel : outer_out_kernels) {
        auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
        if (iter != kernels.end()) {
          continue;
        }
        auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
        for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
          auto outer_out_kernel_in_tensors_iter =
            std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
          if ((outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) &&
              !lite::IsContain(output_tensors, out_kernel_out_tensor)) {
            output_tensors.push_back(out_kernel_out_tensor);
          }
        }
      }
    }
  }
  return output_tensors;
}

void KernelExecUtil::InitTensorInitRefCount(const std::vector<KernelExec *> &kernels) {
  for (auto *kernel : kernels) {
    kernel->InitOutTensorInitRefCount(&kernels);
  }
}

KernelExec *KernelExecUtil::GetInputsSpecificNode(const KernelExec *kernel,
                                                  const schema::PrimitiveType &primitive_type) {
  for (auto input : kernel->in_kernels()) {
    if (input->type() == primitive_type) {
      return input;
    }
  }
  return nullptr;
}

bool KernelExecUtil::InputsContainsSpecificNode(const KernelExec *kernel, const schema::PrimitiveType &primitive_type) {
  return GetInputsSpecificNode(kernel, primitive_type) != nullptr;
}

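// Rebuilds the in_kernels/out_kernels links of every kernel from the tensor producer/consumer
// relations: the producer of each input tensor becomes an in-kernel and every consumer of each
// output tensor becomes an out-kernel.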
void KernelExecUtil::FindAllInoutKernels(const std::vector<KernelExec *> &kernels) {
  std::unordered_map<lite::Tensor *, KernelExec *> tensor_pre_kernel;
  std::unordered_map<lite::Tensor *, std::vector<KernelExec *>> tensor_post_kernels;
  for (auto *kernel : kernels) {
    for (auto *tensor : kernel->out_tensors()) {
      tensor_pre_kernel[tensor] = kernel;
    }
    for (auto *tensor : kernel->in_tensors()) {
      (tensor_post_kernels[tensor]).push_back(kernel);
    }
  }

  for (auto *kernel : kernels) {
    kernel->set_in_kernels({});
    for (auto *tensor : kernel->in_tensors()) {
      auto iter = tensor_pre_kernel.find(tensor);
      if (iter != tensor_pre_kernel.end() && kernel != iter->second) {
        kernel->AddInKernel(iter->second);
      }
    }
    kernel->set_out_kernels({});
    for (auto *tensor : kernel->out_tensors()) {
      auto iter = tensor_post_kernels.find(tensor);
      if (iter != tensor_post_kernels.end()) {
        for (auto *find_kernel : iter->second) {
          if (kernel == find_kernel) {
            continue;
          }
          kernel->AddOutKernel(find_kernel);
        }
      }
    }
  }
}

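// Flattens subgraph kernels into their inner nodes (delegate kernels are kept as-is) and then
// rebuilds the in/out kernel links over the flattened node list.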
void KernelExecUtil::FindAllInoutKernelsInSubgraphKernel(const std::vector<KernelExec *> &kernels) {
  std::vector<KernelExec *> all_kernels;
  for (auto kernel : kernels) {
    if (kernel->desc().arch == kDelegate) {
      all_kernels.push_back(kernel);
      continue;
    }
    auto sub_graph = reinterpret_cast<SubGraphKernel *>(kernel);
    MS_ASSERT(sub_graph != nullptr);
    auto kernel_in_subgraph = sub_graph->nodes();
    (void)all_kernels.insert(all_kernels.end(), kernel_in_subgraph.begin(), kernel_in_subgraph.end());
  }

  KernelExecUtil::FindAllInoutKernels(all_kernels);
}

KernelExec *KernelExecUtil::FindInKernelForInTensor(const KernelExec *kernel, lite::Tensor *tensor) {
  for (auto in_kernel : kernel->in_kernels()) {
    if (lite::IsContain(in_kernel->out_tensors(), tensor)) {
      return in_kernel;
    }
  }
  return nullptr;
}

std::vector<KernelExec *> KernelExecUtil::FindOutKernelsForOutTensor(const KernelExec *kernel, lite::Tensor *tensor) {
  MS_CHECK_TRUE_RET(kernel != nullptr, {});
  std::vector<KernelExec *> out_kernels;
  for (auto out_kernel : kernel->out_kernels()) {
    if (lite::IsContain(out_kernel->in_tensors(), tensor)) {
      out_kernels.push_back(out_kernel);
    }
  }
  return out_kernels;
}

KernelExec *KernelExecUtil::FindInKernelForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph) {
  MS_CHECK_TRUE_RET(graph != nullptr, nullptr);
  auto iter = std::find_if(graph->nodes().begin(), graph->nodes().end(),
                           [&tensor](const auto &node) { return lite::IsContain(node->out_tensors(), tensor); });
  if (iter != graph->nodes().end()) {
    return *iter;
  }
  return nullptr;
}

std::vector<KernelExec *> KernelExecUtil::FindOutKernelsForTensorInSubGraph(lite::Tensor *tensor,
                                                                            SubGraphKernel *graph) {
  MS_CHECK_TRUE_RET(graph != nullptr, {});
  std::vector<KernelExec *> out_kernels(graph->nodes().size());
  auto iter = std::copy_if(graph->nodes().begin(), graph->nodes().end(), out_kernels.begin(),
                           [&tensor](const auto &node) { return lite::IsContain(node->in_tensors(), tensor); });
  out_kernels.erase(iter, out_kernels.end());
  return out_kernels;
}

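// Aligns tensor data types with the data type of a CPU kernel: an fp16 kernel switches its fp32
// outputs to fp16, while an fp32 kernel switches its non-const fp16 inputs and (except for Cast)
// its fp16 outputs back to fp32.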
int KernelExecUtil::SetKernelTensorDataType(const kernel::KernelExec *kernel) {
  CHECK_NULL_RETURN(kernel);
  if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return RET_OK;
  }
  if (kernel->desc().data_type == kNumberTypeFloat16) {
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat32) {
        tensor->set_data_type(kNumberTypeFloat16);
      }
    }
  } else if (kernel->desc().data_type == kNumberTypeFloat32) {
    for (auto tensor : kernel->in_tensors()) {
      if (!tensor->IsConst() && tensor->data_type() == kNumberTypeFloat16) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
    for (auto tensor : kernel->out_tensors()) {
      if (tensor->data_type() == kNumberTypeFloat16 && kernel->type() != schema::PrimitiveType_Cast) {
        tensor->set_data_type(kNumberTypeFloat32);
      }
    }
  }
  return RET_OK;
}

bool KernelExecUtil::IsOutputSubGraph(const KernelExec *subgraph_kernel) {
  MS_CHECK_TRUE_RET(subgraph_kernel != nullptr, false);
  return !subgraph_kernel->out_tensors().empty() &&
         std::all_of(subgraph_kernel->out_tensors().begin(), subgraph_kernel->out_tensors().end(),
                     [](lite::Tensor *tensor) { return tensor->IsGraphOutput(); });
}

namespace {
SubGraphKernel *CreateCustomSubGraph(std::vector<KernelExec *> &&input_kernels,
                                     std::vector<KernelExec *> &&output_kernels,
                                     const std::vector<KernelExec *> &kernels, MSKernel *kernel) {
  auto sub_kernel = new (std::nothrow) CustomSubGraph(input_kernels, output_kernels, kernels, kernel);
  if (sub_kernel == nullptr) {
    MS_LOG(ERROR) << "create custom subgraph failed!";
    return nullptr;
  }
  return sub_kernel;
}
}  // namespace

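// Creates a subgraph kernel of the requested type around the given kernels. When in_tensors or
// out_tensors is nullptr, the subgraph boundary tensors are derived from the kernel list itself.
// Returns nullptr when the subgraph type is unsupported or allocation fails.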
SubGraphKernel *KernelExecUtil::CreateSubGraphKernel(const std::vector<KernelExec *> &kernels,
                                                     const std::vector<lite::Tensor *> *in_tensors,
                                                     const std::vector<lite::Tensor *> *out_tensors, SubGraphType type,
                                                     const lite::InnerContext &context, int schema_version) {
  std::vector<lite::Tensor *> input_tensors;
  std::vector<lite::Tensor *> output_tensors;
  if (in_tensors != nullptr) {
    input_tensors = *in_tensors;
  } else {
    input_tensors = SubgraphInputTensors(kernels);
  }
  if (out_tensors != nullptr) {
    output_tensors = *out_tensors;
  } else {
    output_tensors = SubgraphOutputTensors(kernels);
  }
  auto lite_kernel = new (std::nothrow) LiteKernel(nullptr, input_tensors, output_tensors, &context);
  if (lite_kernel == nullptr) {
    MS_LOG(ERROR) << "Create subgraph lite-kernel failed.";
    return nullptr;
  }
  std::vector<KernelExec *> input_kernels = SubgraphInputNodes(kernels);
  std::vector<KernelExec *> output_kernels = SubgraphOutputNodes(kernels);
  SubGraphKernel *sub_graph = nullptr;
  switch (type) {
    case kCpuFP32SubGraph: {
      sub_graph = new (std::nothrow) CpuFp32SubGraph(input_kernels, output_kernels, kernels, lite_kernel);
    } break;
    case kCpuFP16SubGraph: {
#ifdef ENABLE_FP16
      sub_graph = new (std::nothrow) CpuFp16SubGraph(input_kernels, output_kernels, kernels, lite_kernel);
      for (auto out_tensor : output_tensors) {
        if (out_tensor->data_type() == kNumberTypeFloat32) {
          out_tensor->set_data_type(kNumberTypeFloat16);
        }
      }
#endif
    } break;
    case kGpuFp32SubGraph:
    case kGpuFp16SubGraph: {
#if GPU_OPENCL
      sub_graph = new (std::nothrow) OpenCLSubGraph(input_kernels, output_kernels, kernels, lite_kernel);
#endif
    } break;
    case kCustomSubGraph: {
      sub_graph = CreateCustomSubGraph(std::move(input_kernels), std::move(output_kernels), kernels, lite_kernel);
    } break;
    case kEntranceSubGraph:
    case kExitSubGraph: {
      sub_graph = lite::CreateControlSubgraph(type, lite_kernel);
    } break;
    case kAclSubGraph: {
      sub_graph = new (std::nothrow) AclSubGraph(input_kernels, output_kernels, kernels, lite_kernel);
    } break;
    default: {
      MS_LOG(ERROR) << "unsupported subgraph type: " << type;
      delete lite_kernel;
      return nullptr;
    }
  }
  if (sub_graph == nullptr) {
    delete lite_kernel;
    MS_LOG(ERROR) << "create subgraph type " << type << " failed.";
    return nullptr;
  }
  sub_graph->set_context(&context);
  sub_graph->SetSchemaVersion(schema_version);
  return sub_graph;
}

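// Replaces old_tensor with new_tensor in the inputs of the subgraph's input nodes and sets the
// init ref count of new_tensor to the number of replaced references (1 for delegate kernels).
// ReplaceSubGraphNodesOutTensor below does the same for the outputs of the output nodes.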
int KernelExecUtil::ReplaceSubGraphNodesInTensor(KernelExec *kernel, const lite::Tensor *old_tensor,
                                                 lite::Tensor *new_tensor) {
  CHECK_NULL_RETURN(kernel);
  int ref_count = 0;
  /* set op input for calculation */
  if (kernel->desc().arch == kDelegate) {
    ref_count++;
  } else {
    auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to subgraph kernel failed.";
      return RET_ERROR;
    }
    for (auto in_node : subgraph_kernel->in_nodes()) {
      for (size_t node_in_index = 0; node_in_index < in_node->in_tensors().size(); node_in_index++) {
        if (old_tensor == in_node->in_tensors()[node_in_index]) {
          in_node->set_in_tensor(new_tensor, node_in_index);
          ref_count++;
        }
      }
    }
  }
  CHECK_NULL_RETURN(new_tensor);
  new_tensor->set_init_ref_count(ref_count);
  return RET_OK;
}

int KernelExecUtil::ReplaceSubGraphNodesOutTensor(KernelExec *kernel, const lite::Tensor *old_tensor,
                                                  lite::Tensor *new_tensor) {
  CHECK_NULL_RETURN(kernel);
  int ref_count = 0;
  /* set op output for calculation */
  if (kernel->desc().arch == kDelegate) {
    ref_count++;
  } else {
    auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to subgraph kernel failed.";
      return RET_ERROR;
    }
    for (auto out_node : subgraph_kernel->out_nodes()) {
      for (size_t node_out_index = 0; node_out_index < out_node->out_tensors().size(); node_out_index++) {
        if (old_tensor == out_node->out_tensors()[node_out_index]) {
          out_node->set_out_tensor(new_tensor, node_out_index);
          ref_count++;
        }
      }
    }
  }
  CHECK_NULL_RETURN(new_tensor);
  new_tensor->set_init_ref_count(ref_count);
  return RET_OK;
}

SubGraphKernel *KernelExecUtil::BelongToWhichSubGraph(const std::vector<KernelExec *> &subgraphs, KernelExec *kernel) {
  for (auto &item : subgraphs) {
    if (item->subgraph_type() == kernel::kNotSubGraph) {
      continue;
    }
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(item);
    if (subgraph == nullptr) {
      continue;
    }
    if (std::any_of(subgraph->nodes().begin(), subgraph->nodes().end(),
                    [&kernel](const KernelExec *node) { return node == kernel; })) {
      return subgraph;
    }
  }
  return nullptr;
}

#ifndef CONTROLFLOW_TENSORLIST_CLIP
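// Returns true if the subgraph contains a Switch/SwitchLayer node that is fed by a PartialFusion
// node and whose single out-kernel is a Call node.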
bool KernelExecUtil::IsSwitchTypeCall(KernelExec *kernel) {
  if (kernel == nullptr) {
    return false;
  }
  if (kernel->desc().arch == kDelegate) {
    return false;
  }
  auto *subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if ((node->type() == schema::PrimitiveType_Switch || node->type() == schema::PrimitiveType_SwitchLayer) &&
        InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
        node->out_kernels().front()->type() == schema::PrimitiveType_Call) {
      return true;
    }
  }

  return false;
}

bool KernelExecUtil::IsNonTailCall(const KernelExec *node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr";
    return false;
  }
  auto parameter = reinterpret_cast<CallParameter *>(node->op_parameter());
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "Parameter is nullptr";
    return false;
  }
  return node->type() == schema::PrimitiveType_Call && !(parameter->is_tail_call);
}

bool KernelExecUtil::IsTailCall(const KernelExec *node) {
  return node->type() == schema::PrimitiveType_Call &&
         (reinterpret_cast<CallParameter *>(node->op_parameter())->is_tail_call);
}

bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) {
  auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  auto nodes = subgraph_kernel->nodes();
  return std::any_of(nodes.begin(), nodes.end(),
                     [](const KernelExec *node) { return KernelExecUtil::IsNonTailCall(node); });
}

bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) {
  auto subgraph_kernel = reinterpret_cast<SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  if (IsNonTailCallSubGraph(subgraph_kernel)) {
    return false;
  }
  auto output_nodes = subgraph_kernel->out_nodes();
  if (std::any_of(output_nodes.begin(), output_nodes.end(), [](const KernelExec *node) { return IsTailCall(node); })) {
    return true;
  }
  return false;
}

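// Collects the PartialFusion nodes that feed a Call node, either directly or through a
// Switch/SwitchLayer node.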
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) {
  if (call_node == nullptr) {
    return {};
  }
  if (call_node->type() != schema::PrimitiveType_Call) {
    MS_LOG(ERROR) << "input node is not a call node.";
    return {};
  }
  auto call_inputs = call_node->in_kernels();
  if (call_inputs.size() != 1) {
    MS_LOG(ERROR) << "call inputs size is: " << call_inputs.size() << ", expected 1.";
    return {};
  }

  std::vector<KernelExec *> partial_nodes{};
  auto call_input_node = call_inputs.front();
  switch (SchemaType(call_input_node->type())) {
    case schema::PrimitiveType_PartialFusion: {
      partial_nodes.push_back(call_input_node);
      break;
    }
    case schema::PrimitiveType_Switch:
    case schema::PrimitiveType_SwitchLayer: {
      auto switch_type_node = call_input_node;
      for (auto item : switch_type_node->in_kernels()) {
        if (item->type() == schema::PrimitiveType_PartialFusion) {
          partial_nodes.push_back(item);
        }
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "unsupported call input type: " << call_input_node->type();
      return {};
    }
  }
  return partial_nodes;
}

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node) {
  auto partial_nodes = GetCallInputPartials(call_node);
  std::vector<KernelExec *> all_subgraphs{};
  for (auto partial_node : partial_nodes) {
    auto partial_kernel = reinterpret_cast<PartialFusionKernel *>(partial_node->kernel());
    if (partial_kernel == nullptr) {
      MS_LOG(ERROR) << "cast to partial kernel failed.";
      return all_subgraphs;
    }
    // Only take the output subgraph; the last subgraph is the output subgraph.
    auto partial_subgraphs = partial_kernel->subgraph_kernels();
    all_subgraphs.push_back(partial_subgraphs.back());
    // The exit graph's input graph also needs the same output tensor init refcount.
    if (partial_subgraphs.size() > 1 && partial_subgraphs.back()->subgraph_type() == kExitSubGraph) {
      auto last_index = partial_subgraphs.size() - 1;
      all_subgraphs.push_back(partial_subgraphs[last_index - 1]);
    }
  }
  return all_subgraphs;
}

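// Finds the Call node fed by a PartialFusion node, either directly or through a Switch/SwitchLayer
// node with a single output.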
KernelExec *KernelExecUtil::GetPartialOutputCall(const KernelExec *partial_node) {
  if (partial_node == nullptr) {
    return nullptr;
  }
  if (partial_node->type() != schema::PrimitiveType_PartialFusion) {
    MS_LOG(ERROR) << "input node is not a partial node.";
    return nullptr;
  }
  auto partial_outputs = partial_node->out_kernels();
  if (partial_outputs.size() != 1) {
    MS_LOG(ERROR) << "partial outputs size is: " << partial_outputs.size() << ", expected 1.";
    return nullptr;
  }

  KernelExec *call_node = nullptr;
  auto partial_output_node = partial_outputs.front();
  switch (SchemaType(partial_output_node->type())) {
    case schema::PrimitiveType_Call: {
      call_node = partial_output_node;
      break;
    }
    case schema::PrimitiveType_Switch:
    case schema::PrimitiveType_SwitchLayer: {
      auto switch_type_node = partial_output_node;
      auto switch_outputs = switch_type_node->out_kernels();
      if (switch_outputs.size() != 1) {
        MS_LOG(ERROR) << "switch outputs size is: " << switch_outputs.size() << ", expected 1.";
        return nullptr;
      }
      if (switch_outputs.front()->type() == schema::PrimitiveType_Call) {
        call_node = switch_outputs.front();
      } else {
        MS_LOG(ERROR) << "graph is invalid: switch output is not a call node.";
        return nullptr;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "unsupported partial output type: " << partial_output_node->type();
      return nullptr;
    }
  }
  return call_node;
}

#else

bool KernelExecUtil::IsSwitchTypeCall(KernelExec *kernel) { return false; }

bool KernelExecUtil::IsNonTailCall(const KernelExec *node) { return false; }

bool KernelExecUtil::IsTailCall(const KernelExec *node) { return false; }

bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) { return false; }

bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) { return false; }

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) { return {}; }

std::vector<KernelExec *> KernelExecUtil::GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node) {
  return {};
}

KernelExec *KernelExecUtil::GetPartialOutputCall(const KernelExec *partial_node) { return nullptr; }

#endif
}  // namespace mindspore::kernel