• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CLUSTER_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CLUSTER_H_
18 
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "utils/hash_map.h"
#include "ir/anf.h"
#include "include/backend/optimizer/pass.h"
28 
29 namespace mindspore::graphkernel {
30 class Graph;
31 using GraphPtr = std::shared_ptr<Graph>;
32 class Graph {
33   struct Cluster {
34     size_t cluster_id_;        // node_id of the representative.
35     size_t cluster_size_{1};   // size of cluster, composite node is considered as one node.
36     std::set<size_t> inputs_;  // inputs' cluster_id.
37     size_t seed_{0};           // visited flag of dfs.
38 
39     Cluster(size_t node_id, const AnfNodePtr &node, const mindspore::HashMap<AnfNodePtr, size_t> &node_idx_map);
40     ~Cluster() = default;
41 
42     void Merge(Cluster *other_cluster);
43 
44     // clean the info to free memory.
CleanCluster45     void Clean() {
46       inputs_.clear();
47       cluster_size_ = 0;
48     }
49   };  // struct Cluster
50 
51  public:
52   static GraphPtr Build(const FuncGraphPtr &func_graph, AnfNodePtrList *nodes = nullptr,
53                         HashMap<AnfNodePtr, size_t> *node_idx_map = nullptr);
54   ~Graph() = default;
55 
56   // find the representative of the cluster
57   size_t Find(size_t node_id);
58 
59   // merge clusters, the smallest cluster id will be the new cluster id.
60   void Merge(const std::vector<size_t> &candidates);
61 
62   // Collect nodes together that are in the same cluster.
63   std::vector<std::vector<size_t>> CollectClusters();
64 
65   using VisitFunc = std::function<IncludeType(size_t)>;
66   void Dfs(size_t node_id, const VisitFunc &visitor);
67 
68   // Get cluster size
GetSize(size_t cluster_id)69   size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }
70 
71   // Get cluster's inputs
72   const std::set<size_t> &GetInputs(size_t cluster_id);
73 
74   // public constructor for std::make_shared, do not call it manually.
75   Graph(const AnfNodePtrList &nodes, const HashMap<AnfNodePtr, size_t> &node_idx_map);
76 
77  private:
78   void RefreshInputs(size_t i);
79   void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor);
80 
81   std::vector<Cluster> clusters_;
82   size_t seen_{0};
83 };  // Graph
84 
85 class CircleChecker {
86  public:
CircleChecker(const GraphPtr & graph)87   explicit CircleChecker(const GraphPtr &graph) : graph_(graph) {}
88   ~CircleChecker() = default;
89 
90   void RemoveCircle(std::vector<size_t> *candidates);
91 
92  private:
93   bool CheckCircle(size_t basenode);
94 
95   // remove all circle nodes from candidates
96   void RemoveCircleNodesFromCandidates();
97 
98   GraphPtr graph_;               // bind the global graph
99   std::set<size_t> candidates_;  // bind the input candidates
100   std::vector<size_t> circle_nodes_;
101   std::set<size_t> acyclic_nodes_;
102 };  // CircleChecker
103 
104 class GraphKernelCluster : public opt::Pass {
105  public:
Pass(pass_name)106   explicit GraphKernelCluster(const std::string &pass_name = "graph_kernel_cluster") : Pass(pass_name) {}
107   ~GraphKernelCluster() override = default;
108   bool Run(const FuncGraphPtr &func_graph) override;
109 
110  protected:
GetClusterableOpList()111   virtual std::vector<PrimitivePtr> GetClusterableOpList() { return {}; }
112   virtual bool IsClusterableOp(const AnfNodePtr &node) = 0;
113   void Init(const FuncGraphPtr &func_graph);
114   bool Process(const FuncGraphPtr &func_graph);
115   std::vector<size_t> FindCandidates(size_t basenode_id);
116   void RemoveWildGetitem(std::vector<size_t> *candidates);
117   virtual void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id);
118   void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node);
119   void DumpToFile();
Clean()120   void Clean() {
121     nodes_.clear();
122     graph_ = nullptr;
123   }
124 
125   GraphPtr graph_{nullptr};
126   std::vector<AnfNodePtr> nodes_;
127   std::stringstream dump_buf_;
128   std::vector<PrimitivePtr> op_list_;
129 };
130 }  // namespace mindspore::graphkernel
131 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CLUSTER_H_
132