• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * Copyright 2021 Huawei Technologies Co., Ltd
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <sstream>
24 #include <string>
25 #include <tuple>
26 #include <vector>
27 
28 #include "base/base.h"
29 #include "backend/session/anf_runtime_algorithm.h"
30 #include "backend/optimizer/common/optimizer.h"
31 #include "backend/optimizer/graph_kernel/parallel_cost_model.h"
32 #include "backend/session/kernel_graph.h"
33 #include "utils/ms_context.h"
34 
35 namespace mindspore {
36 namespace opt {
37 class ParallelInfo {
38  public:
39   ParallelInfo() = default;
ParallelInfo(const AnfNodePtrList & nodes,const std::vector<DimInfoPtr> & dims,const FusionInfoPtr & fusion_info)40   ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info)
41       : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {}
ParallelInfo(const ParallelInfo & obj)42   ParallelInfo(const ParallelInfo &obj) {
43     nodes_ = obj.nodes_;
44     dims_ = obj.dims_;
45     fusion_info_ = obj.fusion_info_;
46   }
47   ~ParallelInfo() = default;
48 
GetSize()49   size_t GetSize() const {
50     if (nodes_.size() != dims_.size()) {
51       MS_LOG(EXCEPTION) << "Internal error in parallel info!";
52     }
53     return nodes_.size();
54   }
nodes()55   const AnfNodePtrList &nodes() const { return nodes_; }
dims()56   const std::vector<DimInfoPtr> &dims() const { return dims_; }
fusion_info()57   const FusionInfoPtr &fusion_info() const { return fusion_info_; }
58 
59  private:
60   AnfNodePtrList nodes_;
61   std::vector<DimInfoPtr> dims_;
62   FusionInfoPtr fusion_info_;
63 };
64 
// Tunable knobs for the parallel fusion pass.
class ParallelConfig {
 public:
  ParallelConfig() = default;
  explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
  // NOTE: the previous copy constructor was marked `explicit`, which broke
  // copy-initialization (`ParallelConfig b = a;`) and pass/return by value.
  // A copy constructor must not be explicit; default it instead.
  ParallelConfig(const ParallelConfig &obj) = default;
  ParallelConfig &operator=(const ParallelConfig &obj) = default;
  ~ParallelConfig() = default;
  // Upper bound on how many nodes may be fused into one parallel kernel.
  size_t max_num_for_fuse() const { return max_num_for_fuse_; }

 private:
  size_t max_num_for_fuse_{10};  // Too many nodes to fuse together may produce bad result.
};
76 
77 struct NodeRelation {
78  public:
NodeRelationNodeRelation79   NodeRelation() {}
80   ~NodeRelation() = default;
81   OrderedSet<AnfNodePtr> pres;
82   OrderedSet<AnfNodePtr> nexts;
83 };
84 
// Optimizer pass that searches a kernel graph for independent (parallel)
// ops and fuses them into combined sub-graphs, guided by a cost model.
// Behavioral details below are inferred from signatures only — method
// bodies are defined elsewhere; verify against the .cc file.
class ParallelOpFusion : public Pass {
 public:
  // target: backend target name; config: fusion tuning knobs (e.g. max fuse count).
  ParallelOpFusion(const std::string &target, const ParallelConfig &config)
      : Pass("parallel_fusion"), target_(target), config_(config) {}
  ~ParallelOpFusion() override = default;
  // Entry point of the pass; returns whether the graph was changed (presumably — confirm in .cc).
  bool Run(const FuncGraphPtr &graph) override;

 private:
  // Collect nodes reachable at the given offsets from `start`, skipping
  // already-`used` nodes and indices in `excludes`; returns the nodes and
  // their selected indices. (Name keeps the original "Avaliable" spelling
  // used at call sites.)
  std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets,
                                                                         const std::vector<bool> &used,
                                                                         const AnfNodePtrList &nodes,
                                                                         const std::set<int> &excludes);

  // Search fusable subsets among sorted candidates; returns a used-mask of
  // size `origin_size` plus the discovered parallel groups. The index maps
  // translate between original and sorted node order.
  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
    size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
    std::map<AnfNodePtr, int> *sorted_indices);

  // Same search over an unsorted candidate list.
  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);

  // Scan one parallel group (lists of mutually independent nodes) and append
  // any fusable combinations to *parallel_infos.
  void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                      std::vector<ParallelInfo> *parallel_infos);

  // Run the search over all groups and collect every fusable combination.
  std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);

  // Attach fusion metadata (dims / fusion info) as attributes on `node`.
  void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info);

  // Propagate the fused-parallel-op attribute to the sub-graph's Return node.
  void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);

  // Materialize one fused sub-graph per ParallelInfo inside `kernel_graph`;
  // returns whether anything was created (presumably — confirm in .cc).
  bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                 const std::shared_ptr<session::KernelGraph> &kernel_graph);

  // Build the predecessor/successor relation map used by the group search.
  OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);
  // Partition the relation graph into groups of mutually independent nodes.
  std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);

  std::string target_;                        // backend target this pass runs for
  ParallelConfig config_;                     // tuning knobs (max fuse count, ...)
  ParallelCostModelPtr cost_model_ptr_;       // cost model deciding profitability
  std::set<AnfNodePtr> virtual_noout_nodes_;  // nodes treated as having no output
  std::set<AnfNodePtr> ignore_noin_nodes_;    // nodes with no input to be ignored
};
125 using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
126 }  // namespace opt
127 }  // namespace mindspore
128 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
129