1 2 /** 3 * Copyright 2021 Huawei Technologies Co., Ltd 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <sstream> 24 #include <string> 25 #include <tuple> 26 #include <vector> 27 28 #include "base/base.h" 29 #include "backend/session/anf_runtime_algorithm.h" 30 #include "backend/optimizer/common/optimizer.h" 31 #include "backend/optimizer/graph_kernel/parallel_cost_model.h" 32 #include "backend/session/kernel_graph.h" 33 #include "utils/ms_context.h" 34 35 namespace mindspore { 36 namespace opt { 37 class ParallelInfo { 38 public: 39 ParallelInfo() = default; ParallelInfo(const AnfNodePtrList & nodes,const std::vector<DimInfoPtr> & dims,const FusionInfoPtr & fusion_info)40 ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info) 41 : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {} ParallelInfo(const ParallelInfo & obj)42 ParallelInfo(const ParallelInfo &obj) { 43 nodes_ = obj.nodes_; 44 dims_ = obj.dims_; 45 fusion_info_ = obj.fusion_info_; 46 } 47 ~ParallelInfo() = default; 48 GetSize()49 size_t GetSize() const { 50 if (nodes_.size() != dims_.size()) { 51 MS_LOG(EXCEPTION) << "Internal error in parallel info!"; 52 } 53 return nodes_.size(); 54 } nodes()55 const AnfNodePtrList &nodes() const { return nodes_; } dims()56 const std::vector<DimInfoPtr> &dims() const { return dims_; } fusion_info()57 const FusionInfoPtr &fusion_info() const { return fusion_info_; } 58 59 private: 60 AnfNodePtrList nodes_; 61 std::vector<DimInfoPtr> dims_; 62 FusionInfoPtr fusion_info_; 63 }; 64 65 class ParallelConfig { 66 public: 67 ParallelConfig() = default; ParallelConfig(size_t max_n)68 explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {} ParallelConfig(const ParallelConfig & obj)69 explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; } 70 ~ParallelConfig() = default; max_num_for_fuse()71 size_t max_num_for_fuse() const { return max_num_for_fuse_; } 72 73 private: 74 size_t max_num_for_fuse_{10}; // Too many nodes to fuse together may produce bad result. 75 }; 76 77 struct NodeRelation { 78 public: NodeRelationNodeRelation79 NodeRelation() {} 80 ~NodeRelation() = default; 81 OrderedSet<AnfNodePtr> pres; 82 OrderedSet<AnfNodePtr> nexts; 83 }; 84 85 class ParallelOpFusion : public Pass { 86 public: ParallelOpFusion(const std::string & target,const ParallelConfig & config)87 ParallelOpFusion(const std::string &target, const ParallelConfig &config) 88 : Pass("parallel_fusion"), target_(target), config_(config) {} 89 ~ParallelOpFusion() override = default; 90 bool Run(const FuncGraphPtr &graph) override; 91 92 private: 93 std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets, 94 const std::vector<bool> &used, 95 const AnfNodePtrList &nodes, 96 const std::set<int> &excludes); 97 98 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates( 99 size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices, 100 std::map<AnfNodePtr, int> *sorted_indices); 101 102 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs); 103 104 void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group, 105 std::vector<ParallelInfo> *parallel_infos); 106 107 std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups); 108 109 void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info); 110 111 void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); 112 113 bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, 114 const std::shared_ptr<session::KernelGraph> &kernel_graph); 115 116 OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes); 117 std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels); 118 119 std::string target_; 120 ParallelConfig config_; 121 ParallelCostModelPtr cost_model_ptr_; 122 std::set<AnfNodePtr> virtual_noout_nodes_; 123 std::set<AnfNodePtr> ignore_noin_nodes_; 124 }; 125 using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>; 126 } // namespace opt 127 } // namespace mindspore 128 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ 129