• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_
19 #include <vector>
20 #include <set>
21 #include "src/executor/kernel_exec.h"
22 #include "src/executor/sub_graph_kernel.h"
23 #include "src/litert/inner_context.h"
24 
25 namespace mindspore::kernel {
26 
27 class MS_API KernelExecUtil {
28  public:
29   static std::vector<KernelExec *> SubgraphInputNodes(const std::vector<KernelExec *> &kernels);
30   static std::vector<KernelExec *> SubgraphOutputNodes(const std::vector<KernelExec *> &kernels);
31   static int TopologicalSortNodes(std::vector<KernelExec *> *nodes, std::vector<KernelExec *> in_nodes = {});
32   static std::vector<lite::Tensor *> SubgraphInputTensors(const std::vector<KernelExec *> &kernels);
33   static std::vector<lite::Tensor *> SubgraphOutputTensors(const std::vector<KernelExec *> &kernels);
34   static void InitTensorInitRefCount(const std::vector<KernelExec *> &kernels);
35   static bool IsSwitchTypeCall(KernelExec *kernel);
36   static bool IsNonTailCall(const KernelExec *node);
37   static bool IsTailCall(const KernelExec *node);
38   static std::vector<KernelExec *> GetCallInputPartials(const KernelExec *call_node);
39   static KernelExec *GetPartialOutputCall(const KernelExec *partial_node);
40   static bool IsNonTailCallSubGraph(KernelExec *kernel);
41   static bool IsTailCallSubGraph(KernelExec *kernel);
42   static std::vector<KernelExec *> GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node);
43   static KernelExec *GetInputsSpecificNode(const KernelExec *kernel, const schema::PrimitiveType &primitive_type);
44   static bool InputsContainsSpecificNode(const KernelExec *kernel, const schema::PrimitiveType &primitive_type);
45   // find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph
46   static void FindAllInoutKernels(const std::vector<KernelExec *> &kernels);
47   static void FindAllInoutKernelsInSubgraphKernel(const std::vector<KernelExec *> &kernels);
48   static KernelExec *FindInKernelForInTensor(const KernelExec *kernel, lite::Tensor *tensor);
49   static std::vector<KernelExec *> FindOutKernelsForOutTensor(const KernelExec *kernel, lite::Tensor *tensor);
50   static KernelExec *FindInKernelForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph);
51   static std::vector<KernelExec *> FindOutKernelsForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph);
52   static int SetKernelTensorDataType(const kernel::KernelExec *kernel);
53   static SubGraphKernel *CreateSubGraphKernel(const std::vector<KernelExec *> &kernels,
54                                               const std::vector<lite::Tensor *> *in_tensors,
55                                               const std::vector<lite::Tensor *> *out_tensors, SubGraphType type,
56                                               const lite::InnerContext &context, int schema_version);
57   static int ReplaceSubGraphNodesInTensor(KernelExec *kernel, const lite::Tensor *old_tensor, lite::Tensor *new_tensor);
58   static int ReplaceSubGraphNodesOutTensor(KernelExec *kernel, const lite::Tensor *old_tensor,
59                                            lite::Tensor *new_tensor);
60   static bool IsOutputSubGraph(const KernelExec *subgraph_kernel);
61   static SubGraphKernel *BelongToWhichSubGraph(const std::vector<KernelExec *> &subgraphs, KernelExec *kernel);
62 
63  private:
64   static std::set<lite::Tensor *> AllOutTensor(const std::vector<KernelExec *> &kernels);
65 };
66 
67 }  // namespace mindspore::kernel
68 
69 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_
70