/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <algorithm>
#include "ir/anf.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/ops.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/trace.h"

namespace mindspore {
namespace ad {
using CacheKey = std::pair<std::string, size_t>;

static KPrim g_k_prims_pynative;
static ValuePtr add_ops;
static ValuePtr ones_like_ops;
static ValuePtr zeros_like_ops;
static std::shared_ptr<const opt::irpass::OptimizeIRPassLib> irpass;
static std::map<CacheKey, FuncGraphPtr> bprop_func_graph_cache;
static std::unordered_map<abstract::AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
                          abstract::AbstractBasePtrListEqual>
  zeros_like_funcgraph_cache;
static std::unordered_map<abstract::AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
                          abstract::AbstractBasePtrListEqual>
  ones_like_funcgraph_cache;
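
// Note: these file-level caches hold func graphs keyed either by primitive name and arity
// (bprop_func_graph_cache) or by the abstract signature of the arguments (the zeros_like /
// ones_like caches). They persist across cells and are released only by
// ClearKPynativeCellStaticRes() at the bottom of this file.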

namespace {
FuncGraphPtr ZerosLikePrimOptPass(const pipeline::ResourcePtr &res) {
  if (irpass == nullptr) {
    irpass = std::make_shared<opt::irpass::OptimizeIRPassLib>();
  }
  opt::OptPassConfig eliminate_zeros_like_prim_pass = opt::OptPassConfig({
    irpass->zero_like_fill_zero_,
  });

  opt::OptPassGroupMap map({{"eliminate_zeros_like_prim_", eliminate_zeros_like_prim_pass}});

  auto eliminate_zeros_like_prim = opt::Optimizer::MakeOptimizer("eliminate_zeros_like_prim", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("eliminate_zeros_like_prim"))[&eliminate_zeros_like_prim, &func_graph]() {
    func_graph = eliminate_zeros_like_prim->step(func_graph, true);
  };
  return func_graph;
}

FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
  if (zeros_like_ops == nullptr) {
    zeros_like_ops = prim::GetPythonOps("zeros_like");
  }
  auto iter = zeros_like_funcgraph_cache.find(args_spec);
  if (iter != zeros_like_funcgraph_cache.end()) {
    MS_LOG(DEBUG) << "Cache hit for zeros_like: " << mindspore::ToString(args_spec);
    return BasicClone(iter->second);
  }
  if (!zeros_like_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "zeros_like is not a MetaFuncGraph";
  }
  auto zeros_like = zeros_like_ops->cast<MetaFuncGraphPtr>();
  auto zeros_like_fg = zeros_like->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(zeros_like_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_zeros_like_fg = pipeline::Renormalize(resource, zeros_like_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
  auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
  MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
  }
  return opted_zeros_like_fg;
}

FuncGraphPtr GetHyperAdd(const abstract::AbstractBasePtrList &args_spec) {
  if (add_ops == nullptr) {
    add_ops = prim::GetPythonOps("hyper_add");
  }
  if (!add_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "add is not a MetaFuncGraph";
  }
  auto add = add_ops->cast<MetaFuncGraphPtr>();
  auto add_fg = add->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(add_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_add_fg = pipeline::Renormalize(resource, add_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_add_fg);
  return specialized_add_fg;
}

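// The two builders below emit a zeros_like(...) call on the tape. This serves as the default
// sensitivity for a node whose dout has not been accumulated yet (see PynativeAdjoint::RealDout).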
AnfNodePtr BuildZerosLikeNode(const FuncGraphPtr &tape, const AnfNodePtr &node) {
  // Build zeros_like(node) as dout
  abstract::AbstractBasePtrList args_spec{node->abstract()->Broaden()};
  auto zeros_like_fg = GetZerosLike(args_spec);
  auto zeros_like_node = tape->NewCNode({NewValueNode(zeros_like_fg), node});
  zeros_like_node->set_abstract(zeros_like_fg->output()->abstract());
  return zeros_like_node;
}

AnfNodePtr BuildZerosLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
  // Build zeros_like(out) as dout
  abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()};
  auto zeros_like_fg = GetZerosLike(args_spec);
  auto zeros_like_value = tape->NewCNode({NewValueNode(zeros_like_fg), NewValueNode(out)});
  zeros_like_value->set_abstract(zeros_like_fg->output()->abstract());
  return zeros_like_value;
}

FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
  if (ones_like_ops == nullptr) {
    ones_like_ops = prim::GetPythonOps("ones_like");
  }
  auto iter = ones_like_funcgraph_cache.find(args_spec);
  if (iter != ones_like_funcgraph_cache.end()) {
    MS_LOG(DEBUG) << "Cache hit for ones_like: " << mindspore::ToString(args_spec);
    return BasicClone(iter->second);
  }
  if (!ones_like_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "ones_like is not a MetaFuncGraph";
  }
  auto ones_like = ones_like_ops->cast<MetaFuncGraphPtr>();
  auto ones_like_fg = ones_like->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(ones_like_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
  }
  return specialized_ones_like_fg;
}

AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
  // Build ones_like(out) as dout
  abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()};
  auto ones_like_fg = GetOnesLike(args_spec);
  auto ones_like_value = tape->NewCNode({NewValueNode(ones_like_fg), NewValueNode(out)});
  ones_like_value->set_abstract(ones_like_fg->output()->abstract());
  return ones_like_value;
}

// This fake bprop func_graph should not be present in the final top bprop func_graph.
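// It stands in for a primitive whose bprop is undefined: every input sensitivity is a
// "fake_bprop" node carrying an "info" attribute with the primitive's name, so a node that
// survives into an executed graph can be reported together with the offending primitive.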
FuncGraphPtr BuildFakeBProp(const PrimitivePtr &prim, size_t inputs_num) {
  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  outputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  auto fake_input_sens = func_graph->NewCNode({NewValueNode(fake_bprop), NewValueNode(true)});

  for (size_t i = 0; i < inputs_num; ++i) {
    // Mock params for inputs
    auto param = func_graph->add_parameter();
    MS_EXCEPTION_IF_NULL(param);
    // Mock derivatives for each input
    outputs.push_back(fake_input_sens);
  }
  // Mock params for out and dout
  (void)func_graph->add_parameter();
  (void)func_graph->add_parameter();
  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}
}  // namespace

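// PynativeAdjoint is the per-node record kept during PyNative auto-differentiation: it holds
// the node's forward input/output values, the bprop (or fprop) func graph used to propagate
// sensitivities, the dout accumulated on the tape, and the CNodes that consume the node.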
class PynativeAdjoint {
 public:
  enum FuncGraphType { kForwardPropagate, kBackwardPropagate };
  PynativeAdjoint(const FuncGraphPtr &tape, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &fg,
                  FuncGraphType fg_type = kBackwardPropagate)
      : tape_(tape), op_args_(op_args), out_(out), fg_(fg), fg_type_(fg_type) {}

  ~PynativeAdjoint() = default;
  AnfNodePtrList &users() { return users_; }
  const ValuePtrList &op_args() const { return op_args_; }
  const ValuePtr &out() const { return out_; }
  const FuncGraphPtr &fg() const { return fg_; }
  const FuncGraphType &fg_type() const { return fg_type_; }
  AnfNodePtr RealDout() {
    if (dout_ != nullptr) {
      return dout_;
    }
    return BuildZerosLikeValue(tape_, out_);
  }

  void AccumulateDout(const AnfNodePtr &dout_factor) {
    if (dout_factor->abstract() == nullptr) {
      MS_LOG(EXCEPTION) << "Abstract of dout_factor should not be null: " << dout_factor->ToString();
    }
    if (dout_ != nullptr) {
      MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
      auto arg = out_->ToAbstract()->Broaden();
      abstract::AbstractBasePtrList args_spec{arg, arg};
      auto add_fg = GetHyperAdd(args_spec);
      MS_EXCEPTION_IF_NULL(add_fg);
      dout_ = tape_->NewCNode({NewValueNode(add_fg), dout_, dout_factor});
      dout_->set_abstract(add_fg->output()->abstract());
      MS_LOG(DEBUG) << "New dout_ " << dout_->DebugString();
      return;
    }
    dout_ = dout_factor;
  }

  AnfNodePtr k_node() const { return k_node_; }
  void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }

 private:
  const FuncGraphPtr tape_;
  AnfNodePtr dout_{nullptr};
  // CNodes that use this node as an input.
  AnfNodePtrList users_;
  // Cached arguments from the ad caller.
  const ValuePtrList op_args_;
  // For a CNode, out_ is the output of that cnode; for a Parameter or ValueNode, it is its value.
  const ValuePtr out_;
  // fg_ is either a bprop_fg generated from a Primitive or a fprop_fg passed in by the caller;
  // the FuncGraph is applied on tape_.
  const FuncGraphPtr fg_;
  const FuncGraphType fg_type_;
  // k mapped cnode for the primal CNode; the primal CNode is owned by the primal funcgraph, this one by tape_.
  AnfNodePtr k_node_;
};
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;

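// KPynativeCellImpl records each executed op of a top cell and, on Finish(), emits the tape_
// func graph that computes the gradients. A minimal driver sequence, using the public helpers
// defined further below (the surrounding argument values are illustrative):
//
//   auto k_cell = GradPynativeCellBegin(cell_inputs, input_param_values);
//   GradPynativeOp(k_cell, cnode, op_args, out);        // once per executed primitive
//   k_cell->UpdateOutputNodeOfTopCell(output_node);     // mark the forward output
//   auto grad_fg = GradPynativeCellEnd(k_cell, weights, /*grad_inputs=*/true,
//                                      /*grad_weights=*/true, /*has_sens_arg=*/false,
//                                      /*build_formal_param=*/false);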
class KPynativeCellImpl : public KPynativeCell {
 public:
  KPynativeCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
      : tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
    tape_->debug_info()->set_name("grad_top");
    for (size_t i = 0; i < cell_inputs.size(); ++i) {
      TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
      (void)tape_->add_parameter();
      // Build adjoint for every input parameter
      auto input_adjoint =
        std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, input_param_values[i], FuncGraphPtr(nullptr));
      (void)anfnode_to_adjoin_.insert(std::make_pair(cell_inputs[i], input_adjoint));
    }
  }
  ~KPynativeCellImpl() override = default;
  bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
  bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                          const FuncGraphPtr &bprop_fg);
  bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
                          const FuncGraphPtr &fprop_fg) override;
  void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) override;
  // Build a back propagate funcgraph; each cnode in the primal funcgraph is replaced by a value node or a formal
  // cnode, so it can be grad again.
  FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg,
                      bool build_formal_param);

 private:
  bool need_propagate_stop_gradient_{false};
  // Last cnode of this Cell; may be a primitive op or a cell with user-defined bprop.
  AnfNodePtr last_node_{nullptr};
  FuncGraphPtr tape_;
  AnfNodePtrList cell_inputs_;
  // These weights need to calculate gradient.
  std::unordered_set<AnfNodePtr> need_grad_weights_;
  OrderedMap<AnfNodePtr, PynativeAdjointPtr> anfnode_to_adjoin_;

  // CNodes like TupleGetItem, ListGetItem, MakeTuple and MakeList are bypassed by the caller, so
  // no KPynativeOp is called for them. Here we forge an Adjoint for these CNodes.
  PynativeAdjointPtr ForgeCNodeAdjoint(const CNodePtr &cnode);
  PynativeAdjointPtr ForgeGetItemAdjoint(const CNodePtr &cnode);
  PynativeAdjointPtr ForgeMakeSequenceAdjoint(const CNodePtr &cnode);
  bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                    const FuncGraphPtr &bprop_fg,
                    PynativeAdjoint::FuncGraphType fg_type = PynativeAdjoint::kBackwardPropagate);
  void BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args);
  void PropagateStopGradient();
  bool AllReferencesStopped(const CNodePtr &curr_cnode);
  OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
  // Back propagate for all nodes;
  // if by_value is true, every input of the bprop_app cnode is a value node;
  // if by_value is false, the inputs are the k mapped nodes, so the result can be grad again.
  bool BackPropagate(bool by_value);
  bool BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint,
                                               const FuncGraphPtr &bprop_fg, bool by_value);
  bool BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint,
                                               const FuncGraphPtr &fprop_fg, bool by_value);
  bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
  AnfNodePtr BuildKNodeForCNodeInput(const PynativeAdjointPtr &cnode_adjoint, const AnfNodePtr &input_node,
                                     size_t input_index);
  const AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint);
  FuncGraphPtr BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
  // The back propagate graph for MakeList or MakeTuple is generated from a MetaFuncGraph.
  FuncGraphPtr BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode);
  // Replace input and weight parameters of the primal funcgraph with parameters of tape_.
  void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
  // Set sens and weights parameter nodes by user input info
  void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
  // Set return node according to grad flag
  void SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);

  // For higher order gradient:
  // build the k mapped node owned by tape_ for each cnode in the primal funcgraph, so these nodes can be
  // used in tape_ to keep tracking the cnode dependencies.
  bool BuildKNode();
  CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args);
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;

KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
                                       const std::vector<ValuePtr> &input_param_values) {
  auto abstract_are_set = std::all_of(cell_inputs.cbegin(), cell_inputs.cend(),
                                      [](const AnfNodePtr &node) { return node->abstract() != nullptr; });
  if (!abstract_are_set) {
    MS_LOG(EXCEPTION) << "Not all abstract_value in cell_inputs are set";
  }
  if (cell_inputs.size() != input_param_values.size()) {
    MS_LOG(EXCEPTION) << "The size of cell inputs " << cell_inputs.size()
                      << " is not equal to the size of input parameter values " << input_param_values.size();
  }
  return std::make_shared<KPynativeCellImpl>(cell_inputs, input_param_values);
}

FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
                                 bool grad_weights, bool has_sens_arg, bool build_formal_param) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  return k_cell_impl->Finish(weights, grad_inputs, grad_weights, has_sens_arg, build_formal_param);
}

FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights,
                                       bool has_sens_arg, bool build_formal_param) {
  // Propagate the stop_gradient flag to cnodes before back propagation;
  PropagateStopGradient();
  // Set sens node and weights node
  SetSensAndWeights(weights, has_sens_arg);
  // Build forward CNode;
  if (build_formal_param) {
    (void)BuildKNode();
  }
  // BackPropagate sensitivity, except when the last node is a valuenode which may be obtained by constant folding;
  if (!last_node_->isa<ValueNode>()) {
    (void)BackPropagate(!build_formal_param);
  }
  // Return the gradient;
  SetOutput(weights, grad_inputs, grad_weights);
  // Replace Parameters of the primal funcgraph with parameters of tape_;
  ReplacePrimalParameter(weights, has_sens_arg);
#ifdef ENABLE_DUMP_IR
  auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs_flg) {
    DumpIR("before_final_opt.ir", tape_);
  }
#endif
  return tape_;
}

bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                    const ValuePtr &out) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  return k_cell_impl->KPynativeOp(cnode, op_args, out);
}

bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(EXCEPTION) << "Should be a primitive, but: " << cnode->DebugString();
  }
  if (IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
    need_propagate_stop_gradient_ = true;
  }

  FuncGraphPtr bprop_fg = nullptr;
  if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) {
    bprop_fg = BuildBPropCutFuncGraph(prim, cnode);
  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    bprop_fg = BuildMakeSequenceBprop(prim, cnode);
  } else {
    bprop_fg = g_k_prims_pynative.GetPossibleBprop(prim);
    if (bprop_fg == nullptr) {
      MS_LOG(DEBUG) << "Cannot find defined bprop for cnode prim: " << cnode->DebugString();
      bprop_fg = BuildFakeBProp(prim, cnode->size() - 1);
    }
  }
  MS_EXCEPTION_IF_NULL(bprop_fg);
  (void)BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                           const ValuePtr &out, const FuncGraphPtr &bprop_fg) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  return k_cell_impl->KPynativeWithBProp(cnode, op_args, out, bprop_fg);
}

bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                           const FuncGraphPtr &bprop_fg) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto primal_fg = GetCNodeFuncGraph(cnode);
  if (primal_fg == nullptr) {
    MS_LOG(EXCEPTION) << "Should be a func graph, but: " << cnode->DebugString();
  }
  MS_EXCEPTION_IF_NULL(bprop_fg);
  (void)BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

bool KPynativeCellImpl::KPynativeWithFProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                           const FuncGraphPtr &fprop_fg) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(fprop_fg);

  (void)BuildAdjoint(cnode, op_args, out, fprop_fg, PynativeAdjoint::kForwardPropagate);

  return true;
}

void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
  last_node_ = output_node;

  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    if (IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem) ||
        IsPrimitiveCNode(output_node, prim::kPrimListGetItem)) {
      MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << output_node->DebugString();
      auto cnode = output_node->cast<CNodePtr>();
      (void)ForgeGetItemAdjoint(cnode);
      return;
    } else if (output_node->isa<ValueNode>()) {
      auto v_node = output_node->cast<ValueNodePtr>();
      MS_LOG(DEBUG) << "Build adjoint for valuenode: " << v_node->ToString();
      auto v_node_pynative_adjoint =
        std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, v_node->value(), FuncGraphPtr(nullptr));
      (void)anfnode_to_adjoin_.insert(std::make_pair(output_node, v_node_pynative_adjoint));
      return;
    }
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
  }
}

namespace {
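// Shallow-copy a value so the adjoint owns distinct Tensor objects: per the note in
// BuildAdjoint below, this lets the tensor data address be reset to nullptr when the value
// is not used by the bprop, without disturbing the caller's tensor.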
ValuePtr ShallowCopyValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<mindspore::tensor::Tensor>()) {
    auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
    return std::make_shared<mindspore::tensor::Tensor>(*tensor_value);
  } else if (value->isa<ValueTuple>()) {
    std::vector<ValuePtr> values;
    auto value_tuple = value->cast<ValueTuplePtr>();
    (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
                         [](const ValuePtr &elem) { return ShallowCopyValue(elem); });
    return std::make_shared<ValueTuple>(values);
  } else {
    return value;
  }
}
}  // namespace

PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) {
  if (cnode->size() != 3) {
    MS_LOG(EXCEPTION) << "TupleGetItem/ListGetItem CNode should have 3 inputs, but CNode: " << cnode->DebugString();
  }
  // Input 1 of CNode;
  PynativeAdjointPtr input_1_adjoint = nullptr;
  auto input_1 = cnode->input(1);
  auto input_1_adjoint_iter = anfnode_to_adjoin_.find(input_1);
  if (input_1_adjoint_iter == anfnode_to_adjoin_.end()) {
    if (!input_1->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Input 1 of CNode should be a CNode, CNode: " << cnode->DebugString();
    }
    input_1_adjoint = ForgeCNodeAdjoint(input_1->cast<CNodePtr>());
    if (input_1_adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "Build adjoint for input 1 of CNode failed, CNode: " << cnode->DebugString();
    }
    input_1_adjoint->users().push_back(cnode);
  } else {
    input_1_adjoint = input_1_adjoint_iter->second;
  }
  if (!input_1_adjoint->out()->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "Input of CNode should be evaluated to a ValueSequence. CNode: " << cnode->DebugString()
                      << ", out of input1: " << input_1_adjoint->out()->ToString();
  }
  auto input_1_out = input_1_adjoint->out()->cast<ValueSequeuePtr>();

  // Input 2 of CNode;
  auto index_value = GetValueNode<Int64ImmPtr>(cnode->input(2));
  if (index_value == nullptr) {
    MS_LOG(EXCEPTION) << "CNode input 2 should be an Int64Imm, CNode: " << cnode->DebugString();
  }
  if (index_value->value() < 0) {
    MS_LOG(EXCEPTION) << "CNode input 2 should not be less than 0, CNode: " << cnode->DebugString();
  }
  size_t index_value_imm = LongToSize(index_value->value());
  if (index_value_imm >= input_1_out->size()) {
    MS_LOG(EXCEPTION) << "CNode input 2 should be an index in [0, " << input_1_out->size()
                      << "), but: " << index_value->ToString();
  }
  auto cnode_out = (*input_1_out)[index_value_imm];
  ValuePtrList op_args{input_1_out, index_value};
  auto built = KPynativeOp(cnode, op_args, cnode_out);
  if (!built) {
    MS_LOG(EXCEPTION) << "Build Adjoint for GetItem node failed, CNode: " << cnode->DebugString();
  }
  auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Build Adjoint for GetItem node failed, CNode: " << cnode->DebugString();
  }
  return cnode_adjoint_iter->second;
}

PynativeAdjointPtr KPynativeCellImpl::ForgeMakeSequenceAdjoint(const CNodePtr &cnode) {
  // () or [] is not supported yet.
  if (cnode->size() <= 1) {
    MS_LOG(DEBUG) << "MakeTuple/MakeList CNode is an empty Tuple/List, CNode: " << cnode->DebugString();
    auto empty_tuple = MakeValue(std::vector<ValuePtr>{});
    auto dummy_adjoint =
      std::make_shared<PynativeAdjoint>(FuncGraphPtr(nullptr), ValuePtrList{}, empty_tuple, FuncGraphPtr(nullptr));
    anfnode_to_adjoin_[cnode] = dummy_adjoint;
    cnode->set_stop_gradient(true);
    return dummy_adjoint;
  }
  ValuePtrList op_args;
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->input(i);
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(DEBUG) << "Item in CNode cannot be found in cache. Input is: " << input->DebugString();
      if (input->isa<CNode>()) {
        const auto input_cnode = input->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(input_cnode);
        auto forged_input_adjoint = ForgeCNodeAdjoint(input->cast<CNodePtr>());
        op_args.push_back(forged_input_adjoint->out());
      } else if (input->isa<ValueNode>()) {
        const auto &input_value = GetValueNode(input);
        op_args.push_back(input_value);
      } else {
        MS_LOG(EXCEPTION) << "Input of MakeTuple/MakeList is not a CNode or ValueNode, but: " << input->DebugString();
      }
    } else {
      op_args.push_back(input_adjoint_iter->second->out());
    }
  }
  ValuePtr cnode_out = nullptr;
  if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
    cnode_out = MakeValue(op_args);
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    cnode_out = std::make_shared<ValueList>(op_args);
  }
  // op_args are the real inputs, found from the outputs of previous cnodes
  auto built = KPynativeOp(cnode, op_args, cnode_out);
  if (!built) {
    MS_LOG(EXCEPTION) << "Build Adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString();
  }
  auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Build Adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString();
  }
  return cnode_adjoint_iter->second;
}

PynativeAdjointPtr KPynativeCellImpl::ForgeCNodeAdjoint(const CNodePtr &cnode) {
  if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimListGetItem)) {
    MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString();
    return ForgeGetItemAdjoint(cnode);
  }

  if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString();
    return ForgeMakeSequenceAdjoint(cnode);
  }
  MS_LOG(EXCEPTION) << "Unknown cnode: " << cnode->DebugString();
}

void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args) {
  auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString();
  }
  // Book-keeping of the last cnode, as the dout of this node will be given from outside;
  last_node_ = cnode;

  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    auto input = cnode->input(i);
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      if (input->isa<CNode>()) {
        auto cnode_input = input->cast<CNodePtr>();
        auto forged_adjoint = ForgeCNodeAdjoint(cnode_input);
        if (forged_adjoint == nullptr) {
          MS_LOG(EXCEPTION) << "Cannot forge adjoint for anfnode: " << input->DebugString();
        }
        forged_adjoint->users().push_back(cnode);
      } else {
        MS_EXCEPTION_IF_NULL(op_args[i - 1]);
        auto input_adjoint =
          std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, op_args[i - 1], FuncGraphPtr(nullptr));
        (void)anfnode_to_adjoin_.insert(std::make_pair(input, input_adjoint));
        input_adjoint->users().push_back(cnode);
      }
    } else {
      input_adjoint_iter->second->users().push_back(cnode);
    }
  }
}

bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                     const FuncGraphPtr &fg, const PynativeAdjoint::FuncGraphType fg_type) {
  // Optimize the bprop_fg based on value.
  // Clone op_args and out, so the address of tensor data can be reset to nullptr if the value of the tensor
  // is not used in bprop_fg;
  ValuePtrList cloned_op_args;
  (void)std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args),
                       [](const ValuePtr &value) { return ShallowCopyValue(value); });
  ValuePtr cloned_out = ShallowCopyValue(out);
  PynativeAdjointPtr cnode_adjoint;
  if (fg_type == PynativeAdjoint::kBackwardPropagate) {
    auto optimized_bprop_fg = OptimizeBPropFuncGraph(fg, cnode, cloned_op_args, cloned_out);
    cnode_adjoint = std::make_shared<PynativeAdjoint>(tape_, cloned_op_args, cloned_out, optimized_bprop_fg);
  } else {
    cnode_adjoint = std::make_shared<PynativeAdjoint>(tape_, cloned_op_args, cloned_out, fg, fg_type);
  }

  BuildAdjointForInput(cnode, op_args);

  (void)anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_adjoint));

  return true;
}

FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
                                    const ValuePtr &out) {
  auto optimized_bprop_fg =
    PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
  return optimized_bprop_fg;
}

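// Distribute the results of one bprop application onto the inputs of the primal cnode.
// When bprop_app carries an abstract tuple, element i-1 is the sensitivity of input i;
// otherwise element 0 is the env (see the inline comment below) and the index shifts by one.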
bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app) {
  abstract::AbstractTuplePtr abstract_tuple = nullptr;
  auto bprop_app_abstract = bprop_app->abstract();
  // If input 0 of bprop_app is a CNode rather than a FuncGraph ValueNode, bprop_app_abstract is nullptr;
  // after tape_ is returned, the caller should renormalize tape_ to set the abstract of each AnfNode.
  if (bprop_app_abstract != nullptr) {
    abstract_tuple = bprop_app_abstract->cast<abstract::AbstractTuplePtr>();
    if (abstract_tuple->size() != (cnode_primal->size() - 1)) {
      MS_LOG(EXCEPTION) << "AbstractTuple size: " << abstract_tuple->ToString()
                        << " does not match primal cnode input size: " << cnode_primal->DebugString();
    }
  }
  for (size_t i = 1; i < cnode_primal->size(); i++) {
    auto input = cnode_primal->input(i);
    // Useless to accumulate sens for a ValueNode; the sens for a ValueNode should be zeros_like;
    if (input->isa<ValueNode>()) {
      continue;
    }
    auto cnode_input = input->cast<CNodePtr>();
    if (cnode_input != nullptr && cnode_input->stop_gradient()) {
      MS_LOG(DEBUG) << "Bypass accumulate dout to cnode with stop_gradient flag, cnode: " << input->DebugString();
      continue;
    }
    // Backprop sens wrt inputs.
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input[" << i << "] " << input->DebugString();
    }
    AnfNodePtr din;
    if (abstract_tuple != nullptr) {
      din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i - 1))});
      din->set_abstract((*abstract_tuple)[i - 1]);
    } else {
      // bprop_app[0] is env;
      din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
      din->set_abstract(input_adjoint_iter->second->out()->ToAbstract()->Broaden());
    }
    input_adjoint_iter->second->AccumulateDout(din);
  }
  return true;
}

AnfNodePtr KPynativeCellImpl::BuildKNodeForCNodeInput(const PynativeAdjointPtr &cnode_adjoint,
                                                      const AnfNodePtr &input_node, size_t input_index) {
  MS_EXCEPTION_IF_NULL(cnode_adjoint);
  MS_EXCEPTION_IF_NULL(input_node);
  if (input_node->isa<CNode>()) {
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input_node);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, input: " << input_node->DebugString();
    }
    return input_adjoint_iter->second->k_node();
  } else {
    if (input_node->isa<Parameter>()) {
      bool is_weight = input_node->cast<ParameterPtr>()->has_default();
      // If the weight does not need a gradient, it will be converted to a value node.
      if (is_weight && need_grad_weights_.find(input_node) == need_grad_weights_.end()) {
        return NewValueNode(cnode_adjoint->op_args()[input_index - 1]);
      }
    }
    return input_node;
  }
}

const AnfNodePtrList KPynativeCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode,
                                                                      const PynativeAdjointPtr &adjoint) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(adjoint);
  AnfNodePtrList node_list;
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    (void)node_list.emplace_back(BuildKNodeForCNodeInput(adjoint, cnode->input(i), i));
  }
  return node_list;
}

bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode,
                                                                const PynativeAdjointPtr &adjoint,
                                                                const FuncGraphPtr &bprop_fg, bool by_value) {
  AnfNodePtrList node_list;
  abstract::AbstractBasePtr bprop_output_abs;

  bprop_output_abs = bprop_fg->output()->abstract();
  if (bprop_output_abs == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract of bprop_output_abs is not AbstractTuple, but nullptr";
  }
  if (!bprop_output_abs->isa<abstract::AbstractTuple>()) {
    MS_LOG(EXCEPTION) << "Abstract of bprop_output_abs is not AbstractTuple, but: " << bprop_output_abs->ToString();
  }
  node_list.push_back(NewValueNode(bprop_fg));

  if (by_value) {
    for (size_t i = 0; i < adjoint->op_args().size(); ++i) {
      auto input_node = cnode->input(i + 1);
      if (input_node->isa<Parameter>()) {
        bool is_weight = input_node->cast<ParameterPtr>()->has_default();
        if (!is_weight || need_grad_weights_.find(input_node) != need_grad_weights_.end()) {
          node_list.push_back(input_node);
          continue;
        }
      }
      auto v_node = NewValueNode(adjoint->op_args()[i]);
      v_node->set_abstract(adjoint->op_args()[i]->ToAbstract()->Broaden());
      node_list.push_back(v_node);
    }
    auto out_node = NewValueNode(adjoint->out());
    out_node->set_abstract(adjoint->out()->ToAbstract()->Broaden());
    node_list.push_back(out_node);
    node_list.push_back(adjoint->RealDout());
  } else {
    const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint);
    (void)node_list.insert(node_list.end(), k_node_list.begin(), k_node_list.end());
    // out;
    node_list.push_back(adjoint->k_node());
    // dout
    node_list.push_back(adjoint->RealDout());
  }
  // Back propagate process
  auto bprop_app = tape_->NewCNode(node_list);
  bprop_app->set_abstract(bprop_output_abs);
  (void)BackPropagate(cnode, bprop_app);
  return true;
}

bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &cnode,
                                                                const PynativeAdjointPtr &adjoint,
                                                                const FuncGraphPtr &fprop_fg, bool by_value) {
  MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();

  AnfNodePtrList node_list;
  CNodePtr bprop_cnode;
  if (by_value) {
    AnfNodePtrList args_node_list;
    for (size_t i = 0; i < adjoint->op_args().size(); ++i) {
      auto input_node = cnode->input(i + 1);
      if (input_node->isa<Parameter>()) {
        bool is_weight = input_node->cast<ParameterPtr>()->has_default();
        if (!is_weight || need_grad_weights_.find(input_node) != need_grad_weights_.end()) {
          args_node_list.push_back(input_node);
          continue;
        }
      }
      auto v_node = NewValueNode(adjoint->op_args()[i]);
      v_node->set_abstract(adjoint->op_args()[i]->ToAbstract()->Broaden());
      args_node_list.push_back(v_node);
    }
    bprop_cnode = GetBPropFromFProp(fprop_fg, args_node_list);
  } else {
    const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint);
    bprop_cnode = GetBPropFromFProp(fprop_fg, k_node_list);
  }
  node_list.push_back(bprop_cnode);
  // dout;
  node_list.push_back(adjoint->RealDout());
  // Back propagate process
  auto bprop_app = tape_->NewCNode(node_list);
  (void)BackPropagate(cnode, bprop_app);
  return true;
}

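// Locate last_node_ in the ordered adjoint map so that BackPropagate(by_value) below can
// start from the cell's output and walk the recorded nodes in reverse execution order;
// anything recorded after the output node is skipped entirely.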
OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator KPynativeCellImpl::GetLastNodeReverseIter() {
  for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }
    if (iter->first->cast<CNodePtr>() == last_node_) {
      return iter;
    }
  }
  return anfnode_to_adjoin_.rend();
}

bool KPynativeCellImpl::BackPropagate(bool by_value) {
  auto last_node_reverse_iter = GetLastNodeReverseIter();
  for (auto iter = last_node_reverse_iter; iter != anfnode_to_adjoin_.rend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }
    auto cnode = iter->first->cast<CNodePtr>();
    if (cnode->stop_gradient()) {
      MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->DebugString();
      continue;
    }
    MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();
    auto fg = iter->second->fg();
    auto fg_type = iter->second->fg_type();
    if (fg_type == PynativeAdjoint::kBackwardPropagate) {
      (void)BackPropagateOneCNodeWithBPropFuncGraph(cnode, iter->second, fg, by_value);
    } else {
      (void)BackPropagateOneCNodeWithFPropFuncGraph(cnode, iter->second, fg, by_value);
    }
  }
  return true;
}

bool KPynativeCellImpl::AllReferencesStopped(const CNodePtr &curr_cnode) {
  // If every CNode that uses curr_cnode has the stop_gradient flag set, then curr_cnode can set that flag as well.
  auto iter = anfnode_to_adjoin_.find(curr_cnode);
  if (iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find adjoint for cnode: " << curr_cnode->DebugString();
  }
  auto users = iter->second->users();
  if (users.empty()) {
    return false;
  }
  auto all_users_have_stopped = std::all_of(users.cbegin(), users.cend(), [](const AnfNodePtr &user) {
    if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
      return false;
    }
    return true;
  });
  return all_users_have_stopped;
}

void KPynativeCellImpl::PropagateStopGradient() {
  // Propagate the stop_gradient flag to cnodes before back propagation;
  if (need_propagate_stop_gradient_) {
    for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
      const auto &node = iter->first;
      if (node->isa<CNode>()) {
        auto cnode = node->cast<CNodePtr>();
        if (!cnode->stop_gradient()) {
          // Cut off the cnode only when it is not referred to any more
          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
              AllReferencesStopped(cnode)) {
            MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->DebugString();
            cnode->set_stop_gradient(true);
          }
        }
      }
    }
  }
}

FuncGraphPtr KPynativeCellImpl::BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
  auto inputs_num = cnode->size() - 1;

  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;

  auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
  bprop_cut->CopyHookFunction(prim);

  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  if (cell_id != "") {
    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  }

  outputs.push_back(NewValueNode(bprop_cut));
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = func_graph->add_parameter();
    outputs.push_back(param);
  }
  // out, dout
  auto p1 = func_graph->add_parameter();
  auto p2 = func_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);

  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}

FuncGraphPtr KPynativeCellImpl::BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode) {
  auto inputs_num = cnode->size() - 1;
  CacheKey key{prim->name(), inputs_num};
  auto bprop_func_graph_iter = bprop_func_graph_cache.find(key);
  if (bprop_func_graph_iter != bprop_func_graph_cache.end()) {
    return bprop_func_graph_iter->second;
  }

  FuncGraphPtr b = std::make_shared<FuncGraph>();

  std::ostringstream ss;
  ss << "◀" << prim->ToString() << inputs_num;
  b->debug_info()->set_name(ss.str());
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = b->add_parameter();
    MS_EXCEPTION_IF_NULL(param);
  }
  // out, dout
  auto p1 = b->add_parameter();
  MS_EXCEPTION_IF_NULL(p1);
  AnfNodePtr dout = b->add_parameter();

  std::vector<AnfNodePtr> grads;
  PrimitivePtr getitem_prim;

  if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
    getitem_prim = prim::kPrimTupleGetItem;
  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    getitem_prim = prim::kPrimListGetItem;
  } else {
    MS_LOG(EXCEPTION) << "Prim should be MakeTuple or MakeList, invalid prim: " << prim->ToString();
  }

  grads.push_back(NewValueNode(prim));
  for (size_t i = 0; i < inputs_num; ++i) {
    grads.push_back(b->NewCNode({NewValueNode(getitem_prim), dout, NewValueNode(SizeToLong(i))}));
  }

  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  b->set_output(b->NewCNode(grads));

  bprop_func_graph_cache[key] = b;
  return b;
}

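// After this runs, the parameters of tape_ are laid out as (inputs..., [sens,] weights...),
// the same layout that ReplacePrimalParameter at the bottom of this file relies on; the
// initial dout of last_node_ is either the explicit sens parameter or ones_like(out).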
void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
  MS_EXCEPTION_IF_NULL(last_node_);
  MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
  }
  // Add sens parameter
  if (has_sens_arg) {
    auto sens_param = tape_->add_parameter();
    sens_param->debug_info()->set_name("sens");
    sens_param->set_abstract(last_node_adjoint_iter->second->out()->ToAbstract()->Broaden());
    // Set dout of the last node to sens;
    last_node_adjoint_iter->second->AccumulateDout(sens_param);
  } else {
    auto sens_node = BuildOnesLikeValue(tape_, last_node_adjoint_iter->second->out());
    last_node_adjoint_iter->second->AccumulateDout(sens_node);
  }
  // Add weights parameter
  need_grad_weights_.clear();
  for (const auto &weight : weights) {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
    auto p = tape_->add_parameter();
    (void)need_grad_weights_.emplace(weight);
    auto input_w = weight->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(input_w);
    // Use the name to match the weight parameter in higher order
    p->set_name(input_w->name());
    p->set_default_param(input_w->default_param());
  }
}

void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
  AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtr grad_inputs_spec;
  if (grad_inputs) {
    AbstractBasePtrList grad_inputs_abs_list;
    for (const auto &input : cell_inputs_) {
      MS_EXCEPTION_IF_NULL(input);
      auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        // If input is not used in the network, just return zeros_like() as dout;
        MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
        auto dout = BuildZerosLikeNode(tape_, input);
        grad_inputs_list.push_back(dout);
      } else {
        grad_inputs_list.push_back(input_adjoint_iter->second->RealDout());
      }
      grad_inputs_abs_list.push_back(grad_inputs_list.back()->abstract());
    }
    grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
  }

  AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtr grad_weights_spec;
  if (grad_weights) {
    AbstractBasePtrList grad_weights_abs_list;
    for (const auto &weight : weights) {
      MS_EXCEPTION_IF_NULL(weight);
      auto input_adjoint_iter = anfnode_to_adjoin_.find(weight);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        // If weight is not used in the network, just return zeros_like() as dout;
        MS_LOG(WARNING) << "Weight is not used in network, weight: " << weight->ToString();
        auto input_w = weight->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(input_w);
        auto default_param = input_w->default_param();
        MS_EXCEPTION_IF_NULL(default_param);
        auto dout = BuildZerosLikeValue(tape_, default_param);
        grad_weights_list.push_back(dout);
      } else {
        grad_weights_list.push_back(input_adjoint_iter->second->RealDout());
      }
      grad_weights_abs_list.push_back(grad_weights_list.back()->abstract());
    }
    grad_weights_spec = std::make_shared<abstract::AbstractTuple>(grad_weights_abs_list);
  }

  AnfNodePtr tape_output;
  if (grad_inputs && grad_weights) {
    tape_output = tape_->NewCNode(
      {NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
    tape_output->set_abstract(
      std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
  } else if (grad_inputs) {
    tape_output = tape_->NewCNode(grad_inputs_list);
    tape_output->set_abstract(grad_inputs_spec);
  } else if (grad_weights) {
    tape_output = tape_->NewCNode(grad_weights_list);
    tape_output->set_abstract(grad_weights_spec);
  } else if (cell_inputs_.empty()) {
    tape_output = tape_->NewCNode(grad_inputs_list);
    tape_output->set_abstract(grad_inputs_spec);
  } else {
    auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[0]);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      // If input is not used in the network, just return zeros_like() as dout;
      MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[0]->ToString();
      tape_output = BuildZerosLikeNode(tape_, cell_inputs_[0]);
    } else {
      tape_output = input_adjoint_iter->second->RealDout();
    }
  }
  tape_->set_output(tape_output);
}

bool KPynativeCellImpl::BuildKNode() {
  for (auto iter = anfnode_to_adjoin_.cbegin(); iter != anfnode_to_adjoin_.cend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }

    AnfNodePtrList node_list;
    auto cnode = iter->first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    for (size_t i = 0; i < cnode->inputs().size(); ++i) {
      (void)node_list.emplace_back(BuildKNodeForCNodeInput(iter->second, cnode->input(i), i));
    }
    auto k_node = tape_->NewCNode(node_list);
    k_node->set_abstract(iter->second->out()->ToAbstract()->Broaden());
    iter->second->set_k_node(k_node);
  }
  return true;
}

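// An fprop graph follows the K-transform convention fprop(args...) -> (out, bprop), so the
// bprop closure is element 1 of the application result. Conceptually, the wrapper built here is
//
//   bprop_builder(args...) = tuple_getitem(fprop_fg(args...), 1)
//
// which is then applied on tape_ to the concrete args.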
CNodePtr KPynativeCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args) {
  // Wrap tuple_getitem(fprop_app, 1) in a FuncGraph and optimize it;
  auto bprop_builder = std::make_shared<FuncGraph>();
  bprop_builder->debug_info()->set_name("bprop_builder");

  AnfNodePtrList fprop_app_inputs{NewValueNode(fprop_fg)};
  AnfNodePtrList bprop_builder_inputs;
  for (const auto &arg : args) {
    auto param = bprop_builder->add_parameter();
    fprop_app_inputs.push_back(param);
    bprop_builder_inputs.push_back(arg);
  }
  auto fprop_app = bprop_builder->NewCNode(fprop_app_inputs);
  auto get_bprop =
    bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});
  bprop_builder->set_output(get_bprop);
  (void)bprop_builder_inputs.insert(bprop_builder_inputs.begin(), NewValueNode(bprop_builder));
  get_bprop = tape_->NewCNode(bprop_builder_inputs);

  return get_bprop;
}

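// Rewire the tape so it is self-contained: every reference to a primal input or weight node
// is redirected to the corresponding tape_ parameter, whose order was fixed by the
// constructor and SetSensAndWeights as (inputs..., [sens,] weights...).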
void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
  auto mng = MakeManager({tape_}, false);
  auto tr = mng->Transact();
  const auto &parameters = tape_->parameters();
  auto cell_inputs_size = cell_inputs_.size();
  for (size_t i = 0; i < cell_inputs_size; ++i) {
    (void)tr.Replace(cell_inputs_[i], parameters[i]);
  }
  // (Inputs, sens, weights) or (Inputs, weights)
  size_t weight_offset = cell_inputs_size;
  if (has_sens_arg) {
    weight_offset = weight_offset + 1;
  }
  for (size_t i = 0; i < weights.size(); ++i) {
    (void)tr.Replace(weights[i], parameters[weight_offset + i]);
  }
  tr.Commit();
}

void ClearKPynativeCellStaticRes() {
  irpass = nullptr;
  add_ops = nullptr;
  ones_like_ops = nullptr;
  zeros_like_ops = nullptr;
  g_k_prims_pynative.clear();
  bprop_func_graph_cache.clear();
  zeros_like_funcgraph_cache.clear();
  ones_like_funcgraph_cache.clear();
}
}  // namespace ad
}  // namespace mindspore