1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_ 19 #include <vector> 20 #include <set> 21 #include "src/executor/kernel_exec.h" 22 #include "src/executor/sub_graph_kernel.h" 23 #include "src/litert/inner_context.h" 24 25 namespace mindspore::kernel { 26 27 class MS_API KernelExecUtil { 28 public: 29 static std::vector<KernelExec *> SubgraphInputNodes(const std::vector<KernelExec *> &kernels); 30 static std::vector<KernelExec *> SubgraphOutputNodes(const std::vector<KernelExec *> &kernels); 31 static int TopologicalSortNodes(std::vector<KernelExec *> *nodes, std::vector<KernelExec *> in_nodes = {}); 32 static std::vector<lite::Tensor *> SubgraphInputTensors(const std::vector<KernelExec *> &kernels); 33 static std::vector<lite::Tensor *> SubgraphOutputTensors(const std::vector<KernelExec *> &kernels); 34 static void InitTensorInitRefCount(const std::vector<KernelExec *> &kernels); 35 static bool IsSwitchTypeCall(KernelExec *kernel); 36 static bool IsNonTailCall(const KernelExec *node); 37 static bool IsTailCall(const KernelExec *node); 38 static std::vector<KernelExec *> GetCallInputPartials(const KernelExec *call_node); 39 static KernelExec *GetPartialOutputCall(const KernelExec *partial_node); 40 static bool IsNonTailCallSubGraph(KernelExec *kernel); 41 static bool IsTailCallSubGraph(KernelExec *kernel); 42 static std::vector<KernelExec *> GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node); 43 static KernelExec *GetInputsSpecificNode(const KernelExec *kernel, const schema::PrimitiveType &primitive_type); 44 static bool InputsContainsSpecificNode(const KernelExec *kernel, const schema::PrimitiveType &primitive_type); 45 // find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph 46 static void FindAllInoutKernels(const std::vector<KernelExec *> &kernels); 47 static void FindAllInoutKernelsInSubgraphKernel(const std::vector<KernelExec *> &kernels); 48 static KernelExec *FindInKernelForInTensor(const KernelExec *kernel, lite::Tensor *tensor); 49 static std::vector<KernelExec *> FindOutKernelsForOutTensor(const KernelExec *kernel, lite::Tensor *tensor); 50 static KernelExec *FindInKernelForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph); 51 static std::vector<KernelExec *> FindOutKernelsForTensorInSubGraph(lite::Tensor *tensor, SubGraphKernel *graph); 52 static int SetKernelTensorDataType(const kernel::KernelExec *kernel); 53 static SubGraphKernel *CreateSubGraphKernel(const std::vector<KernelExec *> &kernels, 54 const std::vector<lite::Tensor *> *in_tensors, 55 const std::vector<lite::Tensor *> *out_tensors, SubGraphType type, 56 const lite::InnerContext &context, int schema_version); 57 static int ReplaceSubGraphNodesInTensor(KernelExec *kernel, const lite::Tensor *old_tensor, lite::Tensor *new_tensor); 58 static int ReplaceSubGraphNodesOutTensor(KernelExec *kernel, const lite::Tensor *old_tensor, 59 lite::Tensor *new_tensor); 60 static bool IsOutputSubGraph(const KernelExec *subgraph_kernel); 61 static SubGraphKernel *BelongToWhichSubGraph(const std::vector<KernelExec *> &subgraphs, KernelExec *kernel); 62 63 private: 64 static std::set<lite::Tensor *> AllOutTensor(const std::vector<KernelExec *> &kernels); 65 }; 66 67 } // namespace mindspore::kernel 68 69 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_EXEC_UTIL_H_ 70