• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_
19 
20 #include <vector>
21 #include <algorithm>
22 #include <utility>
23 #include "frontend/optimizer/irpass.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "frontend/optimizer/optimizer.h"
27 #include "frontend/optimizer/anf_visitor.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "ir/func_graph.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace irpass {
34 constexpr auto kHandledNotRecomputeNodeFlag = "handled_not_recompute_node";
35 constexpr auto kPrimalFgCallerUserDataKey = "primal_fg_caller";
36 bool EnableCellReuse();
37 
38 bool HasBpropGetter(const OptimizerPtr &opt, const AnfNodePtr &k_fg_caller);
39 
40 AnfNodePtr GetBpropCaller(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter);
41 
42 bool AddRecomputeNodes(const FuncGraphPtr &root, const opt::OptimizerPtr &opt);
43 
44 class RemoveNotRecomputeNode : public AnfVisitor {
45  public:
operator()46   AnfNodePtr operator()(const OptimizerPtr &opt, const AnfNodePtr &node) override {
47     if (!EnableCellReuse()) {
48       return nullptr;
49     }
50     Reset();
51     auto k_fg_caller = node->cast<CNodePtr>();
52     MS_EXCEPTION_IF_NULL(k_fg_caller);
53     if (!IsMatch(k_fg_caller)) {
54       return nullptr;
55     }
56 
57     MS_EXCEPTION_IF_NULL(opt);
58     auto manager = opt->manager();
59     MS_EXCEPTION_IF_NULL(manager);
60     auto fg = node->func_graph();
61     MS_EXCEPTION_IF_NULL(fg);
62 
63     bool has_bprop_getter = HasBpropGetter(opt, k_fg_caller);
64     // If the k graph has been handled, the call nodes of k and primal graph should be handled.
65     if (k_fg_->has_flag(kHandledNotRecomputeNodeFlag)) {
66       return CreateNewCallerForHandledKGraph(manager, fg, k_fg_caller, has_bprop_getter);
67     }
68 
69     // The k graph only contains primal should not be recomputed.
70     if (!has_bprop_getter) {
71       return nullptr;
72     }
73 
74     k_fg_->set_flag(kHandledNotRecomputeNodeFlag, true);
75     std::vector<AnfNodePtr> new_primal_fg_outputs{NewValueNode(prim::kPrimMakeTuple), primal_fg_->output()};
76     std::vector<AnfNodePtr> k_fg_nodes = TopoSort(k_fg_->get_return(), SuccDeeperSimple);
77     int64_t not_recompute_count = 0;
78     for (const auto &node_in_k_fg : k_fg_nodes) {
79       auto [cnode_k_fg, primal_cnode] = GetNotRecomputeKGraphAndPrimalCNode(node_in_k_fg);
80       if (cnode_k_fg == nullptr || primal_cnode == nullptr) {
81         continue;
82       }
83       ++not_recompute_count;
84       // Erase the flag to do inline later.
85       cnode_k_fg->erase_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH);
86       // Replace the primal node in k graph with the node in primal graph.
87       (void)new_primal_fg_outputs.emplace_back(primal_cnode);
88       auto para = k_fg_->add_parameter();
89       auto cnode_k_fg_output = cnode_k_fg->output();
90       if (!IsPrimitiveCNode(cnode_k_fg_output, prim::kPrimMakeTuple)) {
91         MS_LOG(INTERNAL_EXCEPTION) << "The output of k graph should be make_tuple, but got "
92                                    << cnode_k_fg_output->DebugString();
93       }
94       (void)manager->Replace(cnode_k_fg_output->cast<CNodePtr>()->input(1), para);
95     }
96     if (not_recompute_count == 0) {
97       return nullptr;
98     }
99 
100     primal_fg_->set_output(primal_fg_->NewCNode(new_primal_fg_outputs));
101     auto primal_fg_caller = k_fg_caller->user_data<CNode>(kPrimalFgCallerUserDataKey);
102     UpdateForwardResult(manager, primal_fg_caller);
103     // Add new arguments to k graph caller.
104     return CreateNewKGraphCaller(fg, k_fg_caller, primal_fg_caller, not_recompute_count);
105   }
106 
CreateNewCallerForHandledKGraph(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const CNodePtr & k_fg_caller,bool has_bprop_getter)107   AnfNodePtr CreateNewCallerForHandledKGraph(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
108                                              const CNodePtr &k_fg_caller, bool has_bprop_getter) {
109     auto not_recompute_count = SizeToLong(k_fg_->parameters().size() - (k_fg_caller->size() - 1));
110     if (not_recompute_count == 0) {
111       return nullptr;
112     }
113     if (!has_bprop_getter) {
114       std::vector<AnfNodePtr> new_primal_caller_inputs{NewValueNode(primal_fg_)};
115       (void)new_primal_caller_inputs.insert(new_primal_caller_inputs.cend(), k_fg_caller->inputs().begin() + 1,
116                                             k_fg_caller->inputs().end());
117       auto new_primal_caller = fg->NewCNodeInOrder(new_primal_caller_inputs);
118       return new_primal_caller;
119     }
120 
121     auto primal_fg_caller = k_fg_caller->user_data<CNode>(kPrimalFgCallerUserDataKey);
122     UpdateForwardResult(manager, primal_fg_caller);
123     return CreateNewKGraphCaller(fg, k_fg_caller, primal_fg_caller, not_recompute_count);
124   }
125 
CreateNewKGraphCaller(const FuncGraphPtr & fg,const CNodePtr & k_fg_caller,const CNodePtr & primal_fg_caller,int64_t not_recompute_count)126   static AnfNodePtr CreateNewKGraphCaller(const FuncGraphPtr &fg, const CNodePtr &k_fg_caller,
127                                           const CNodePtr &primal_fg_caller, int64_t not_recompute_count) {
128     std::vector<AnfNodePtr> new_k_fg_caller_inputs;
129     (void)new_k_fg_caller_inputs.insert(new_k_fg_caller_inputs.cend(), k_fg_caller->inputs().begin(),
130                                         k_fg_caller->inputs().end());
131     auto primal_fg_caller_fg = primal_fg_caller->func_graph();
132     for (int64_t i = 1; i <= not_recompute_count; ++i) {
133       auto extra_forward_result = primal_fg_caller_fg->NewCNodeInOrder(
134         {NewValueNode(prim::kPrimTupleGetItem), primal_fg_caller, NewValueNode(i)});
135       (void)new_k_fg_caller_inputs.emplace_back(extra_forward_result);
136     }
137     auto new_k_fg_caller = fg->NewCNodeInOrder(new_k_fg_caller_inputs);
138     if (k_fg_caller->HasAttr(kAddedRecomputeDependAttr)) {
139       new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
140     }
141     return new_k_fg_caller;
142   }
143 
UpdateForwardResult(const FuncGraphManagerPtr & manager,const AnfNodePtr & primal_fg_caller)144   static void UpdateForwardResult(const FuncGraphManagerPtr &manager, const AnfNodePtr &primal_fg_caller) {
145     MS_EXCEPTION_IF_NULL(primal_fg_caller);
146     auto fg = primal_fg_caller->func_graph();
147     MS_EXCEPTION_IF_NULL(fg);
148     auto forward_result = fg->NewCNodeInOrder(
149       {NewValueNode(prim::kPrimTupleGetItem), primal_fg_caller, NewValueNode(static_cast<int64_t>(0))});
150     (void)manager->Replace(primal_fg_caller, forward_result);
151   }
152 
GetNotRecomputeKGraphAndPrimalCNode(const AnfNodePtr & node)153   static std::pair<FuncGraphPtr, AnfNodePtr> GetNotRecomputeKGraphAndPrimalCNode(const AnfNodePtr &node) {
154     auto cnode = dyn_cast<CNode>(node);
155     if (cnode == nullptr) {
156       return std::make_pair(nullptr, nullptr);
157     }
158     // call (k_fg, ...)
159     auto cnode_k_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
160     if (cnode_k_fg == nullptr || !cnode_k_fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH)) {
161       return std::make_pair(nullptr, nullptr);
162     }
163     // k_fg -> primal
164     auto primal_cnode = GetPrimalCNode(cnode_k_fg);
165     if (primal_cnode == nullptr) {
166       MS_LOG(DEBUG) << "The cnode k_fg " << cnode_k_fg->ToString() << " should have corresponding primal_cnode.";
167       return std::make_pair(nullptr, nullptr);
168     }
169     MS_LOG(DEBUG) << "primal_cnode: " << primal_cnode->DebugString();
170     return std::make_pair(cnode_k_fg, primal_cnode);
171   }
172 
GetPrimalCNode(const FuncGraphPtr & cnode_k_fg)173   static AnfNodePtr GetPrimalCNode(const FuncGraphPtr &cnode_k_fg) {
174     auto primal_cnode_iter = cnode_k_fg->transforms().find("primal_cnode");
175     if (primal_cnode_iter == cnode_k_fg->transforms().end()) {
176       MS_LOG(DEBUG) << "Not found the primal cnode of k graph " << cnode_k_fg->ToString();
177       return nullptr;
178     }
179     auto primal_cnode = primal_cnode_iter->second.primal_cnode();
180     return primal_cnode;
181   }
182 
IsMatch(const CNodePtr & k_fg_caller)183   bool IsMatch(const CNodePtr &k_fg_caller) {
184     auto k_fg = GetValueNode<FuncGraphPtr>(k_fg_caller->input(0));
185     if (k_fg == nullptr || !k_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
186       return false;
187     }
188 
189     auto primal_iter = k_fg->transforms().find("primal");
190     if (primal_iter == k_fg->transforms().end()) {
191       return false;
192     }
193     auto primal_fg = primal_iter->second.func_graph();
194     if (primal_fg == nullptr || !primal_fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
195       return false;
196     }
197 
198     k_fg_ = k_fg;
199     primal_fg_ = primal_fg;
200     return true;
201   }
202 
Reset()203   void Reset() {
204     k_fg_ = nullptr;
205     primal_fg_ = nullptr;
206   }
207 
208  private:
209   FuncGraphPtr k_fg_{nullptr};
210   FuncGraphPtr primal_fg_{nullptr};
211 };
212 }  // namespace irpass
213 }  // namespace opt
214 }  // namespace mindspore
215 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_
216