• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <sstream>
24 #include <string>
25 #include <tuple>
26 #include <vector>
27 
28 #include "base/base.h"
29 #include "include/backend/anf_runtime_algorithm.h"
30 #include "include/common/utils/anfalgo.h"
31 #include "include/backend/optimizer/optimizer.h"
32 #include "backend/common/graph_kernel/parallel_cost_model.h"
33 #include "include/backend/kernel_graph.h"
34 #include "utils/ms_context.h"
35 
36 namespace mindspore::graphkernel {
37 class ParallelInfo {
38  public:
39   ParallelInfo() = default;
ParallelInfo(const AnfNodePtrList & nodes,const std::vector<DimInfoPtr> & dims,const FusionInfoPtr & fusion_info)40   ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info)
41       : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {}
42   ~ParallelInfo() = default;
43 
GetSize()44   size_t GetSize() const {
45     if (nodes_.size() != dims_.size()) {
46       MS_LOG(EXCEPTION) << "Internal error in parallel info! nodes' size is different from dims' size: "
47                         << nodes_.size() << " vs " << dims_.size();
48     }
49     return nodes_.size();
50   }
nodes()51   const AnfNodePtrList &nodes() const { return nodes_; }
dims()52   const std::vector<DimInfoPtr> &dims() const { return dims_; }
fusion_info()53   const FusionInfoPtr &fusion_info() const { return fusion_info_; }
54 
55  private:
56   AnfNodePtrList nodes_;
57   std::vector<DimInfoPtr> dims_;
58   FusionInfoPtr fusion_info_;
59 };
60 
// Configuration for parallel fusion. Currently holds a single limit: the maximum number of nodes
// that may be fused together.
class ParallelConfig {
 public:
  // Uses the default limit (10).
  ParallelConfig() = default;

  // Uses `max_n` as the limit. Marked explicit so a bare size_t never converts to a config implicitly.
  explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
  ~ParallelConfig() = default;

  // Returns the configured maximum number of nodes to fuse.
  size_t max_num_for_fuse() const { return max_num_for_fuse_; }

 private:
  size_t max_num_for_fuse_{10};  // Too many nodes to fuse together may produce bad result.
};
71 
72 struct NodeRelation {
73  public:
NodeRelationNodeRelation74   NodeRelation() {}
75   ~NodeRelation() = default;
76   OrderedSet<AnfNodePtr> pres;
77   OrderedSet<AnfNodePtr> nexts;
78 };
79 
80 class ParallelOpFusion : public opt::Pass {
81  public:
ParallelOpFusion(const std::string & target,const ParallelConfig & config)82   ParallelOpFusion(const std::string &target, const ParallelConfig &config)
83       : Pass("parallel_fusion"), target_(target), config_(config) {}
84   ~ParallelOpFusion() override = default;
85   bool Run(const FuncGraphPtr &graph) override;
86 
87  private:
88   std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets,
89                                                                          const std::vector<bool> &used,
90                                                                          const AnfNodePtrList &nodes,
91                                                                          const std::set<int> &excludes) const;
92 
93   std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
94     size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
95     std::map<AnfNodePtr, int> *sorted_indices);
96 
97   std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);
98 
99   void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
100                                       std::vector<ParallelInfo> *parallel_infos);
101 
102   std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);
103 
104   void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info);
105 
106   void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);
107 
108   bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
109                                  const std::shared_ptr<session::KernelGraph> &kernel_graph);
110 
111   OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);
112   std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);
113 
114   std::string target_;
115   ParallelConfig config_;
116   ParallelCostModelPtr cost_model_ptr_;
117   std::set<AnfNodePtr> virtual_noout_nodes_;
118   std::set<AnfNodePtr> ignore_noin_nodes_;
119   unsigned int parallel_level_{0};
120 };
121 using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
122 }  // namespace mindspore::graphkernel
123 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
124