1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_GRAPH_KERNEL_BUILD_H_ 17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_GRAPH_KERNEL_BUILD_H_ 18 19 #include <vector> 20 #include <string> 21 #include <memory> 22 #include <sstream> 23 #include <unordered_map> 24 #include <map> 25 #include "ir/anf.h" 26 #include "include/backend/optimizer/optimizer.h" 27 #include "kernel/framework_utils.h" 28 #include "kernel/kernel.h" 29 #include "kernel/kash/kernel_pack.h" 30 #include "kernel/graph_kernel/graph_kernel_builder.h" 31 #include "backend/common/graph_kernel/core/graph_kernel_splitter.h" 32 #include "backend/common/graph_kernel/adapter/graph_kernel_splitter_with_py.h" 33 34 namespace mindspore { 35 namespace graphkernel { 36 class SafeSplitSchemer : public CommonSplitSchemer { 37 public: 38 SafeSplitSchemer() = default; 39 ~SafeSplitSchemer() = default; 40 bool Split(const FuncGraphPtr &func_graph) override; 41 42 protected: 43 void Run(const FuncGraphPtr &func_graph); 44 void SplitNodes(const FuncGraphPtr &func_graph); 45 }; 46 47 class SafeGraphKernelSplitter : public GraphKernelSplitter { 48 public: 49 SafeGraphKernelSplitter() = default; 50 ~SafeGraphKernelSplitter() = default; GetSplitSchema(const std::string &)51 std::shared_ptr<SplitSchemer> GetSplitSchema(const std::string &) override { 52 return std::make_shared<SafeSplitSchemer>(); 53 } 54 }; 55 56 class KernelCompilerGraphKernelSplitter : public GraphKernelSplitter { 57 public: 58 KernelCompilerGraphKernelSplitter() = default; 59 ~KernelCompilerGraphKernelSplitter() = default; GetSplitSchema(const std::string &)60 std::shared_ptr<SplitSchemer> GetSplitSchema(const std::string &) override { 61 return std::make_shared<SplitByJsonSchemer>(address_node_map_, json_desc_str_); 62 } 63 SetAddressNodeMap(const std::map<std::string,AnfNodePtr> & address_node_map)64 void SetAddressNodeMap(const std::map<std::string, AnfNodePtr> &address_node_map) { 65 address_node_map_ = address_node_map; 66 } 67 SetJson(const std::string & json_desc_str)68 void SetJson(const std::string &json_desc_str) { json_desc_str_ = json_desc_str; } 69 std::map<std::string, AnfNodePtr> address_node_map_; 70 std::string json_desc_str_; 71 }; 72 73 class GraphKernelBuild : public opt::Pass { 74 public: GraphKernelBuild()75 GraphKernelBuild() : Pass("graph_kernel_build") {} 76 ~GraphKernelBuild() override = default; 77 bool Run(const FuncGraphPtr &func_graph) override; 78 79 private: 80 void Init(); 81 bool Process(const FuncGraphPtr &func_graph, int iter); 82 kernel::JsonNodePair CollectNode(const AnfNodePtr &node) const; 83 // Collect graph kernel nodes in main graph. 84 void CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) const; 85 // Collect graph kernel nodes that do not have compile cache, which means these nodes need to be compiled. 86 std::vector<kernel::JsonNodePair> CollectNotCachedNodes(const std::vector<kernel::JsonNodePair> &nodes); 87 // Parallel compiling. 88 void ParallelBuild(const std::vector<kernel::JsonNodePair> &nodes); 89 // Split nodes that compiled failed. 90 bool SplitNodes(const std::vector<kernel::JsonNodePair> &nodes); 91 // Split nodes that compiled failed. 92 bool SplitNodesByKernelCompiler(const std::vector<kernel::JsonNodePair> &nodes); 93 94 SafeGraphKernelSplitter splitter_; // used to split nodes that compile failed 95 kernel::KernelMeta *bin_map_{nullptr}; 96 std::shared_ptr<kernel::GraphKernelBuilder> kernel_builder_{nullptr}; 97 std::unordered_map<std::string, kernel::KernelPackPtr> kernel_pack_; // compile cache 98 }; 99 } // namespace graphkernel 100 } // namespace mindspore 101 #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_GRAPH_KERNEL_BUILD_H_ 102