• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
18 
#include <cstdint>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
28 
29 namespace mindspore::graphkernel {
/*
 * Recompute some operators to reduce the peak of temporary memory.
 *
 *   (a)  (b)                  (a)   (b)
 *     \  /                     |     |
 *      Gs                     Gs1    |
 *  (c)/ |                   (c)|     |
 *    /  |                      Go    |
 *  Go   |(d)   =========>      |+-depend
 *    \  |                      |     |
 *  (e)\ |                   (e)|    Gs2
 *      \|                      |     |
 *      Gt                      +----(d)
 *                              Gt
 *
 * Here Gs is split into Gs1 and Gs2, and (x) denotes a temporary tensor.
 * For the left graph, the live memory goes (a+b) -> (c+d) -> (d+e).
 * For the right graph, it goes (a+b) -> (b+c) -> (b+e) -> (d+e).
 * If (c+d) reaches the memory threshold while both (b+c) and (b+e) stay below it,
 * the split lowers the peak and may ease the memory burden.
 */
// Lifetime class of an edge considered for recomputation.
enum class EdgeLifeTimeType : char { ShortTerm, LongTerm };

// Streams the bracketed name of `type`, e.g. "[ShortTerm]".
// A switch is used so that no std::map / std::string objects are built on every call.
inline std::ostream &operator<<(std::ostream &os, EdgeLifeTimeType type) {
  switch (type) {
    case EdgeLifeTimeType::ShortTerm:
      return os << "[ShortTerm]";
    case EdgeLifeTimeType::LongTerm:
      return os << "[LongTerm]";
    default:
      // Out-of-range values (possible via casts) print nothing.
      return os;
  }
}
57 using OutPosLinkList = std::vector<std::tuple<AnfNodePtr, std::vector<int>, EdgeLifeTimeType>>;
58 using OutPosLinkMap = std::map<AnfNodePtr, std::vector<int>>;
59 using MemorySize = int64_t;
60 struct Candidate {
61   AnfNodePtr source_graph;
62   AnfNodePtr target_graph;
63   EdgeLifeTimeType type;
64   AnfNodePtrList recompute_edges;  // getitem list for recompute edges.
65 };
66 
67 class AutoRecompute {
68  public:
69   AutoRecompute() = default;
70   virtual ~AutoRecompute() = default;
71 
72   virtual std::vector<Candidate> Run(const FuncGraphPtr &func_graph);
73 
74  protected:
75   using NodeRecomputeCandidates =
76     OrderedMap<AnfNodePtr, OrderedMap<AnfNodePtr, std::pair<EdgeLifeTimeType, AnfNodePtrList>>>;
77   virtual NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node,
78                                                               const OutPosLinkList &target_graphs,
79                                                               const FuncGraphManagerPtr &mng);
80   void FindCandidates(const FuncGraphPtr &func_graph);
81   std::vector<Candidate> candidates_;
82 
83  private:
84   OutPosLinkList JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
85   AnfNodePtrList Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
86                         const FuncGraphManagerPtr &mng);
87   int GetSourceLinkOutPos(const AnfNodePtr &target, int pos) const;
88   std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> GetValidUsers(const AnfNodePtr &node,
89                                                                               const FuncGraphManagerPtr &mng);
90   MemorySize SelectThreshold(EdgeLifeTimeType type) const;
91   bool IsThresholdDefaultValue() const;
92 
93   std::map<AnfNodePtr, MemorySize> topo_indice_;
94   MemorySize lifetime_threshold_{0};
95   MemorySize local_peak_threshold_{0};
96 
97   void RecomputeLinkEdgeLog(const AnfNodePtr &node, const OrderedSet<AnfNodePtr> &direct_users,
98                             const OutPosLinkList &target_link_infos) const;
99   void RecomputeCandidatesLog(const std::vector<Candidate> &candidates) const;
100 };
101 
102 class CSRRecompute : public AutoRecompute {
103  public:
104   std::vector<Candidate> Run(const FuncGraphPtr &func_graph) override;
105 
106  protected:
107   NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node, const OutPosLinkList &target_graphs,
108                                                       const FuncGraphManagerPtr &mng) override;
109 
110  private:
111   bool CheckPrimitiveInput(AnfNodePtr base, const PrimitivePtr &prim_type) const;
112 };
113 
114 class GraphKernelRecompute : public opt::Pass {
115  public:
GraphKernelRecompute()116   GraphKernelRecompute() : Pass("graph_kernel_recompute") {}
117   ~GraphKernelRecompute() override = default;
118   bool Run(const FuncGraphPtr &func_graph) override;
119 
120  private:
121   bool DoRun(const FuncGraphPtr &func_graph, bool use_csr = false);
122   void Process(const Candidate &candidate) const;
123   std::pair<FuncGraphPtr, AnfNodePtrList> CloneGraph(const CNodePtr &source_graph,
124                                                      const AnfNodePtrList &recompute_edges) const;
125   void LinkIntoTargetFuncGraph(
126     const Candidate &candidate, const FuncGraphPtr &cloned_func, const AnfNodePtrList &cloned_inputs,
127     const std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> &edge_match_func) const;
128 
129   std::vector<Candidate> candidates_;
130 };
131 }  // namespace mindspore::graphkernel
132 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
133