1 2 /** 3 * Copyright 2021-2022 Huawei Technologies Co., Ltd 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <sstream> 24 #include <string> 25 #include <tuple> 26 #include <vector> 27 28 #include "base/base.h" 29 #include "include/backend/anf_runtime_algorithm.h" 30 #include "include/common/utils/anfalgo.h" 31 #include "include/backend/optimizer/optimizer.h" 32 #include "backend/common/graph_kernel/parallel_cost_model.h" 33 #include "include/backend/kernel_graph.h" 34 #include "utils/ms_context.h" 35 36 namespace mindspore::graphkernel { 37 class ParallelInfo { 38 public: 39 ParallelInfo() = default; ParallelInfo(const AnfNodePtrList & nodes,const std::vector<DimInfoPtr> & dims,const FusionInfoPtr & fusion_info)40 ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info) 41 : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {} 42 ~ParallelInfo() = default; 43 GetSize()44 size_t GetSize() const { 45 if (nodes_.size() != dims_.size()) { 46 MS_LOG(EXCEPTION) << "Internal error in parallel info! nodes' size is different from dims' size: " 47 << nodes_.size() << " vs " << dims_.size(); 48 } 49 return nodes_.size(); 50 } nodes()51 const AnfNodePtrList &nodes() const { return nodes_; } dims()52 const std::vector<DimInfoPtr> &dims() const { return dims_; } fusion_info()53 const FusionInfoPtr &fusion_info() const { return fusion_info_; } 54 55 private: 56 AnfNodePtrList nodes_; 57 std::vector<DimInfoPtr> dims_; 58 FusionInfoPtr fusion_info_; 59 }; 60 61 class ParallelConfig { 62 public: 63 ParallelConfig() = default; ParallelConfig(size_t max_n)64 explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {} 65 ~ParallelConfig() = default; max_num_for_fuse()66 size_t max_num_for_fuse() const { return max_num_for_fuse_; } 67 68 private: 69 size_t max_num_for_fuse_{10}; // Too many nodes to fuse together may produce bad result. 70 }; 71 72 struct NodeRelation { 73 public: NodeRelationNodeRelation74 NodeRelation() {} 75 ~NodeRelation() = default; 76 OrderedSet<AnfNodePtr> pres; 77 OrderedSet<AnfNodePtr> nexts; 78 }; 79 80 class ParallelOpFusion : public opt::Pass { 81 public: ParallelOpFusion(const std::string & target,const ParallelConfig & config)82 ParallelOpFusion(const std::string &target, const ParallelConfig &config) 83 : Pass("parallel_fusion"), target_(target), config_(config) {} 84 ~ParallelOpFusion() override = default; 85 bool Run(const FuncGraphPtr &graph) override; 86 87 private: 88 std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets, 89 const std::vector<bool> &used, 90 const AnfNodePtrList &nodes, 91 const std::set<int> &excludes) const; 92 93 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates( 94 size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices, 95 std::map<AnfNodePtr, int> *sorted_indices); 96 97 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs); 98 99 void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group, 100 std::vector<ParallelInfo> *parallel_infos); 101 102 std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups); 103 104 void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info); 105 106 void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); 107 108 bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, 109 const std::shared_ptr<session::KernelGraph> &kernel_graph); 110 111 OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes); 112 std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels); 113 114 std::string target_; 115 ParallelConfig config_; 116 ParallelCostModelPtr cost_model_ptr_; 117 std::set<AnfNodePtr> virtual_noout_nodes_; 118 std::set<AnfNodePtr> ignore_noin_nodes_; 119 unsigned int parallel_level_{0}; 120 }; 121 using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>; 122 } // namespace mindspore::graphkernel 123 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 124