/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/jit/jit_grad.h"

#include <algorithm>
#include <climits>
#include <numeric>
#include <utility>
#include "frontend/optimizer/ad/grad.h"
#include "ops/structure_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/sequence_ops.h"
#include "pipeline/pynative/pynative_utils.h"
#include "pipeline/pynative/grad/jit/jit_dfunctor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/expander/bprop/bprop.h"

namespace mindspore {
namespace pynative {
namespace {
constexpr char kAddedValue[] = "added_value";

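// Ops exempted from the bprop-expander support check in GetJitForwardGraphCNodeInfo below.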
const mindspore::HashSet<std::string> kExpanderWhiteList{
  kVmapStackAssignOpName,
  kVmapUnstackAssignOpName,
  kPyExecuteOpName,
  kPrintOpName,
};

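// Build the FrontendOpRunInfo for a jit call: parse the Python inputs, take the input abstracts from the forward
// graph parameters, and, when the graph output was modified, split the (real output, added output) tuple.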
FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, const std::string &graph_phase,
                                  bool modify_output, const FuncGraphPtr &jit_forward_graph, ValuePtr *added_out_v) {
  auto op_run_info = std::make_shared<FrontendOpRunInfo>();
  op_run_info->requires_grad = true;
  op_run_info->is_jit_input = true;
  op_run_info->base_op_run_info.op_name = graph_phase;
  PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
  // Set input abstracts from the forward graph parameters.
  const auto &original_params = jit_forward_graph->parameters();
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    op_run_info->op_grad_info->input_abs[i] = original_params[i]->abstract();
  }
  if (modify_output) {
    if (!py::isinstance<py::tuple>(out)) {
      MS_LOG(EXCEPTION) << "The output value of the jit func graph should be a tuple.";
    }
    auto tuple_out = py::cast<py::tuple>(out);
    constexpr size_t tuple_out_size = 2;
    if (tuple_out.size() != tuple_out_size) {
      MS_LOG(EXCEPTION) << "The tuple size of the output value of the jit func graph should be 2.";
    }
    MS_EXCEPTION_IF_NULL(added_out_v);
    // Forward outputs of ops in the jit graph.
    *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
    op_run_info->real_out = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[0]);
  } else {
    op_run_info->real_out = PyNativeAlgo::DataConvert::PyObjToValue(out);
  }
  return op_run_info;
}

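// Count how many flat tensor outputs an abstract expands to: 1 for a tensor, the recursive sum for a sequence,
// 5 for a CSRTensor and 4 for a COOTensor; 0 for anything else.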
size_t GetTensorNumFromAbstract(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    // Is a tensor
    constexpr size_t kTensorOutputNum = 1;
    return kTensorOutputNum;
  } else if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>()->elements();
    return std::accumulate(abs_seq.begin(), abs_seq.end(), 0, [](size_t out_num, const abstract::AbstractBasePtr &abs) {
      return out_num + GetTensorNumFromAbstract(abs);
    });
  } else if (abs->isa<abstract::AbstractCSRTensor>()) {
    // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
    constexpr size_t kCSRTensorOutputNum = 5;
    return kCSRTensorOutputNum;
  } else if (abs->isa<abstract::AbstractCOOTensor>()) {
    // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
    constexpr size_t kCOOTensorOutputNum = 4;
    return kCOOTensorOutputNum;
  }
  return 0;
}

// Modify the output node of func_graph to add forward nodes used in the bprop graph.
void ModifyOutputNode(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &used_forward_nodes = func_graph->used_forward_nodes();
  if (used_forward_nodes.empty()) {
    return;
  }

  // Create a new make_tuple node to hold all forward used nodes.
  abstract::AbstractBasePtrList added_abs_list;
  AnfNodePtrList added_node_list{NewValueNode(prim::kPrimMakeTuple)};
  for (const auto &node : used_forward_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    (void)added_node_list.emplace_back(node);
    (void)added_abs_list.emplace_back(node->abstract());
  }
  AnfNodePtr added_output_node = func_graph->NewCNode(std::move(added_node_list));
  AbstractBasePtr added_output_abs = std::make_shared<abstract::AbstractTuple>(added_abs_list);
  added_output_node->set_abstract(added_output_abs);

  // Get the original output node and abstract, then merge the original output node and the used forward nodes
  // into the return node.
  auto original_output_node = func_graph->output();
  MS_EXCEPTION_IF_NULL(original_output_node);
  auto original_output_abs = original_output_node->abstract();
  MS_EXCEPTION_IF_NULL(original_output_abs);
  AnfNodePtrList new_output_nodes{NewValueNode(prim::kPrimMakeTuple), original_output_node, added_output_node};
  auto merge_node = func_graph->NewCNode(std::move(new_output_nodes));
  abstract::AbstractBasePtrList new_output_abs{original_output_abs, added_output_abs};
  merge_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_output_abs));
  func_graph->set_output(merge_node);

  // Mark the output as modified and clear the recorded forward nodes.
  func_graph->set_modify_output(true);
  func_graph->ClearUsedForwardNodes();
}

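// Return the "added output" branch (the third input of the merged make_tuple created by ModifyOutputNode),
// or nullptr if the graph output was not modified.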
CNodePtr GetAddedNode(const FuncGraphPtr &jit_forward_graph) {
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  if (!jit_forward_graph->modify_output()) {
    return nullptr;
  }
  // Get added forward nodes.
  auto merge_node = jit_forward_graph->output();
  MS_EXCEPTION_IF_NULL(merge_node);
  auto merge_make_tuple = merge_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(merge_make_tuple);
  constexpr size_t merge_output_size = 3;
  // The first input is make_tuple, the second is the actual output, the third is the added output.
  if (merge_make_tuple->size() != merge_output_size) {
    MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: " << merge_make_tuple->size();
  }
  constexpr size_t added_output_index = 2;
  return merge_make_tuple->input(added_output_index)->cast<CNodePtr>();
}

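// A graph is treated as dynamic if any non-default parameter or its output has a dynamic shape.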
bool IsGraphDynamic(const FuncGraphPtr &func_graph) {
  for (const auto &param : func_graph->parameters()) {
    if (param->isa<Parameter>() && !param->cast<ParameterPtr>()->has_default()) {
      const auto &abs = param->abstract();
      if (abs != nullptr && abs->BuildShape()->IsDynamic()) {
        return true;
      }
    }
  }
  MS_EXCEPTION_IF_NULL(func_graph->output());
  if (auto abs = func_graph->output()->abstract(); abs != nullptr && abs->BuildShape()->IsDynamic()) {
    return true;
  }
  return false;
}

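// Recursively check whether the output abstract contains a dictionary.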
bool JitOutputHasDict(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractDictionary>()) {
    return true;
  } else if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
    return std::any_of(abs_sequence->elements().begin(), abs_sequence->elements().end(),
                       [](const abstract::AbstractBasePtr &item) { return JitOutputHasDict(item); });
  }
  return false;
}
}  // namespace

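// Walk the added make_tuple node and overwrite each forward node's placeholder value node with the actual
// output tensors produced by this run.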
void Jit::RunReplace(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const {
  MS_EXCEPTION_IF_NULL(added_node);
  size_t index = 0;
  for (size_t i = 1; i < added_node->size(); ++i) {
    const auto &input_i = added_node->input(i);
    MS_EXCEPTION_IF_NULL(input_i);
    auto cnode = input_i->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(DEBUG) << "Replace output tensors for cnode: " << cnode->DebugString();
    const auto &output_vnode = cnode->forward().first;
    MS_EXCEPTION_IF_NULL(output_vnode);
    MS_LOG(DEBUG) << "Old output value node: " << output_vnode->ToString();
    MS_EXCEPTION_IF_NULL(output_vnode->abstract());
    bool is_tuple_out = output_vnode->abstract()->isa<abstract::AbstractSequence>();
    size_t output_num = GetTensorNumFromAbstract(cnode->abstract());
    if (output_num == 0) {
      MS_LOG(DEBUG) << "The output value does not include any tensor";
      continue;
    }
    if (index + output_num > total_output_tensors.size()) {
      MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
                        << ", but the current index: " << index << ", output num: " << output_num;
    }
    // Get new tensors.
    std::vector<ValuePtr> new_values;
    for (size_t j = index; j < index + output_num; ++j) {
      // If the jit graph is reused with dynamic shape, the added output tensor's address should be updated in the
      // run actor.
      auto tensor = total_output_tensors[j]->cast<tensor::BaseTensorPtr>();
      if (tensor != nullptr) {
        tensor->set_is_forward_output(true);
      }
      (void)new_values.emplace_back(total_output_tensors[j]);
    }
    index = index + output_num;
    // Replace with new tensors.
    // Cannot use output_num > 1 here, because the output can be (a): a tuple with only one element.
    if (is_tuple_out) {
      output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
    } else {
      output_vnode->set_value(new_values[0]);
    }
    MS_LOG(DEBUG) << "New output value node: " << output_vnode->ToString();
  }
  // All new tensors should have been consumed for the currently running jit func graph.
  if (index != total_output_tensors.size()) {
    MS_LOG(EXCEPTION) << "The index: " << index
                      << " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
  }
}

void Jit::ReplaceAddedCnodeActualOutput(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const {
  MS_EXCEPTION_IF_NULL(added_node);
  // Replace the forward nodes with the new output tensors; this also takes effect in the grad graph, which shares
  // the same value nodes.
  MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_node->DebugString();
  // The forward node in the jit graph is created during compilation and is a placeholder;
  // after running the jit graph it needs to be updated with the real value.
  RunReplace(added_node, total_output_tensors);
}

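// Collect the grad-graph nodes that correspond to the input values of the jit call.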
void Jit::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                           AnfNodePtrList *input_nodes) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(input_nodes);
  MS_EXCEPTION_IF_NULL(grad_executor);
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    const auto &input_i_value = op_run_info->op_grad_info->input_value[i];
    const auto &id = PyNativeAlgo::Common::GetIdByValue(input_i_value);
    const auto &input_i_node = grad_executor->GetInput(input_i_value, id);
    MS_EXCEPTION_IF_NULL(input_i_node);
    MS_LOG(DEBUG) << "The input " << i << " id " << id << ", node is: " << input_i_node->DebugString();
    (void)input_nodes->emplace_back(input_i_node);
  }
}

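// Append the weight parameters of the jit forward graph to input_nodes, reusing parameters already recorded in the
// top cell and registering new ones as free parameters of the top graph.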
void Jit::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                         const FuncGraphPtr &jit_forward_graph, AnfNodePtrList *input_nodes) const {
  MS_EXCEPTION_IF_NULL(grad_executor);
  MS_EXCEPTION_IF_NULL(input_nodes);
  const auto &top_cell = grad_executor->top_cell();
  const auto &graph_info = top_cell->graph_info_map().at(top_cell->fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  // Get weights info of the jit graph.
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  const auto &original_params = jit_forward_graph->parameters();
  size_t params_size = original_params.size();
  MS_EXCEPTION_IF_NULL(op_run_info);
  for (size_t i = 0; i < params_size; ++i) {
    if (i < op_run_info->input_size) {  // Non-weight node.
      continue;
    }
    // Must be a weight parameter.
    auto param = original_params[i]->cast<ParameterPtr>();
    const auto tensor_value = PyNativeAlgo::Common::GetTensorFromParam(original_params[i]);
    MS_EXCEPTION_IF_NULL(tensor_value);
    const auto it = graph_info->weight_params.find(tensor_value->id());
    if (it != graph_info->weight_params.end()) {
      param = it->second;
    } else {
      top_cell->fg()->add_parameter(param);
      param->debug_info()->set_name(param->name());
      top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
    }
    (void)input_nodes->emplace_back(param);
    MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                  << tensor_value->ToString() << ". Its name is: " << param->name();
  }
}

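// Build a CNode that calls the jit forward graph with its input and weight nodes; used only when the save_graphs
// flag is set (see RecordForwardGraphForJit).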
void Jit::MakeCNodeForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                          const FuncGraphPtr &jit_forward_graph, CNodePtr *jit_cnode) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  // Get input node info of the jit graph.
  AnfNodePtrList input_nodes{NewValueNode(jit_forward_graph)};
  MS_EXCEPTION_IF_NULL(grad_executor);
  GetInputArgsNode(op_run_info, grad_executor, &input_nodes);
  // Get weights node info of the jit graph.
  GetWeightsNode(op_run_info, grad_executor, jit_forward_graph, &input_nodes);
  // Make a CNode which includes the jit fprop graph and its input nodes.
  MS_EXCEPTION_IF_NULL(jit_cnode);
  *jit_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
  (*jit_cnode)->set_abstract(jit_forward_graph->output()->abstract());
  MS_LOG(DEBUG) << "Make jit forward CNode: " << (*jit_cnode)->DebugString();
}

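// Record grad info for the jit call and hand the forward and grad graphs over to the auto-grad cell of the top cell.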
void Jit::MakeAdjointForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                            const FuncGraphPtr &jit_forward_graph, const FuncGraphPtr &jit_grad_graph,
                            bool has_added_v) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(grad_executor);

  const auto &top_cell = grad_executor->top_cell();
  PyNativeAlgo::Common::SetGraphInputAndWeightsInfo(op_run_info, jit_forward_graph, top_cell);
  RecordForwardGraphForJit(op_run_info, grad_executor, jit_forward_graph);
  // Connect the grad graph of the jit call to the context.
  (void)PyNativeAlgo::Common::SetValueGradInfo(op_run_info->real_out, top_cell, InputType::kOpOutput);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  MS_EXCEPTION_IF_NULL(jit_forward_graph->output()->abstract());
  if (grad_executor->dynamic_shape()->enable_unknown_shape() &&
      jit_forward_graph->output()->abstract()->BuildShape()->IsDynamic()) {
    MS_LOG(DEBUG) << "Set jit unknown shape out to abs cache";
    grad_executor->dynamic_shape()->SaveUnknownShapeAbsFromJit(op_run_info->real_out,
                                                               jit_forward_graph->output()->abstract(), 0);
  }
  auto op_grad_info = std::make_shared<OpGradInfo>();
  op_grad_info->input_value = op_run_info->op_grad_info->input_value;
  op_grad_info->input_abs = op_run_info->op_grad_info->input_abs;
  op_grad_info->out_value = op_run_info->real_out;
  op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_grad_info->out_value);
  op_grad_info->input_value_grad_type = op_run_info->op_grad_info->input_value_grad_type;
  if (jit_forward_graph->output()->abstract()->isa<abstract::AbstractAny>()) {
    op_grad_info->out_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_grad_info->out_value->ToAbstract());
  } else {
    op_grad_info->out_abs = jit_forward_graph->output()->abstract();
  }
  auto grad_param = std::make_shared<GradParam>(op_grad_info, grad_executor->use_dynamic_shape_process());
  grad_param->is_control_flow = compile_info_.is_control_flow_;

  grad_param->has_added_v = has_added_v;
  grad_param->is_jit_graph = true;
  // As long as the jit graph itself has dynamic shape, let it run through actor execution to avoid backend passes.
  grad_param->is_jit_self_dynamic_shape = compile_info_.is_dynamic_shape_;

  grad_param->fg = jit_grad_graph;
  grad_param->source_fg = jit_forward_graph;
  grad_param->graph_cache_key = graph_phase_;
  grad_param->jit_out_has_dict = JitOutputHasDict(op_grad_info->out_abs);
  auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
  KPynativeWithFProp(grad_executor, auto_grad_cell_ptr, grad_param);
  top_cell->set_need_do_final_opt(true);
  top_cell->set_has_call_graph(grad_executor->use_dynamic_shape_process());
  top_cell->set_has_control_flow(compile_info_.is_control_flow_);
  top_cell->set_jit_out_has_dict(grad_param->jit_out_has_dict);
}

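// Wait for pending bprop tasks, then register the jit grad graph with the auto-grad cell.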
void Jit::KPynativeWithFProp(const GradExecutor *grad_executor, const autograd::AutoGradPtr &auto_grad_cell_ptr,
                             const GradParamPtr &grad_param) const {
  grad_executor->WaitBpropTask();
  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
  if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for jit cnode";
  }
}

void Jit::RecordForwardGraphForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                   const FuncGraphPtr &jit_forward_graph) const {
  int save_graphs = MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    CNodePtr jit_cnode = nullptr;
    MakeCNodeForJit(op_run_info, grad_executor, jit_forward_graph, &jit_cnode);
    MS_EXCEPTION_IF_NULL(jit_cnode);
    const auto &out_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
    const auto &top_cell = grad_executor->top_cell();
    top_cell->SetNodeMapInGraphInfoMap(out_id, jit_cnode);
  }
}

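// Core grad path for a jit call: replace the placeholder forward values, run dynamic-structure detection, refresh
// the forward tensor info used by the bprop graph, and finally build the adjoint.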
void Jit::GradJitInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                       const FuncGraphPtr &primal_func_graph, const FuncGraphPtr &jit_grad_graph,
                       const CNodePtr &added_node, const ValuePtr &added_out_v) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(grad_executor);
  // Step 1: Replace the added cnode's forward placeholders with the actual output.
  ValuePtr flatten_v = added_out_v;
  bool added_v_is_empty = true;
  if (added_out_v != nullptr) {
    ValuePtrList total_output_tensors;
    PyNativeAlgo::DataConvert::FlattenValueSeqArg(added_out_v, false, true, &total_output_tensors);
    flatten_v = std::make_shared<ValueTuple>(total_output_tensors);
    added_v_is_empty = total_output_tensors.empty();
    ReplaceAddedCnodeActualOutput(added_node, total_output_tensors);
  }

  // Step 2: Check or set the use_dynamic_shape_process flag.
  auto node_info = std::make_shared<DynamicDetectNodeInfo>(nullptr, op_run_info->op_grad_info->input_abs,
                                                           op_run_info->base_op_run_info.abstract);
  node_info->is_graph_node = true;
  node_info->graph_phase = graph_phase_;
  grad_executor->dynamic_shape()->CheckNodeDynamic(grad_executor->top_cell(), op_run_info->op_grad_info->input_value,
                                                   node_info);

  // Step 3: Update the actual output tensors used in the grad graph.
  MS_LOG(DEBUG) << "jit actual output value: " << op_run_info->real_out->ToString();
  grad_executor->top_cell()->GetOpInfo(op_run_info, true);
  grad_executor->UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
                                                            op_run_info->base_op_run_info.stream_id);

  // Step 4: Update the output tensors of the added forward nodes, which were appended to the return node of the
  // jit func graph.
  if (!added_v_is_empty) {
    if (grad_executor->use_dynamic_shape_process()) {
      // If the jit graph is not control flow, it is executed by actors under dynamic shape, and the value nodes
      // will be updated.
      if (!compile_info_.is_control_flow_) {
        UpdateJitForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, flatten_v,
                                               op_run_info->base_op_run_info.stream_id);
      }
    } else {
      // With static shape, the values are updated by replacement.
      grad_executor->UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, flatten_v,
                                                                op_run_info->base_op_run_info.stream_id);
    }
  }

  // Make the adjoint for the grad graph.
  MakeAdjointForJit(op_run_info, grad_executor, primal_func_graph, jit_grad_graph, !added_v_is_empty);
}

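// On the first run of a graph phase, record the value ids for later replacement; on subsequent runs, refresh the
// stored forward output tensor info.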
void Jit::UpdateJitForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v,
                                                 const size_t &stream_id) {
  const auto it = graph_phase_with_replace_info_.find(graph_phase_);
  if (it == graph_phase_with_replace_info_.end()) {
    MS_LOG(DEBUG) << "Jit " << graph_phase_ << " runs for the first time";
    auto &replace_info = graph_phase_with_replace_info_[graph_phase_];
    SetIdWithOpInfo(v, op_info, kIndex0, &(replace_info.id_with_op_info));
    return;
  }
  // Not the first run.
  MS_LOG(DEBUG) << "Update jit forward output tensor info " << op_info;
  UpdateForwardOutputTensorInfo(op_info, v, it->second);
}

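// Save the forward output tensor info of the given bprop func graph for later replacement; the graph phase must
// already have an entry in graph_phase_with_replace_info_.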
void Jit::SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph) {
  const auto it = graph_phase_with_replace_info_.find(graph_phase_);
  if (it == graph_phase_with_replace_info_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find graph phase " << graph_phase_ << " in graph_phase_with_replace_info";
  }
  MS_LOG(DEBUG) << "Save jit forward output tensor info";
  auto manager = MakeManager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(func_graph);
  SaveForwardOutputTensorInfo(func_graph, true, &(it->second));
}

void Jit::ProcessCnodeFromAdGrad(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
  // Run the grad process for func_graph and replace forward nodes with their output tensors.
  if (eliminate_forward_) {
    ReplaceEquivOut(k_app, cnode_morph);
  }
}

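// Build and cache the grad graph for the current compile phase. Skipped when exporting a graph or when grad is not
// required; control-flow graphs and graphs whose output contains a dict do not eliminate forward nodes.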
bool Jit::GetJitGradGraph(const pipeline::ResourcePtr &resource) {
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  graph_phase_ = graph_executor->phase();
  MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << graph_phase_;
  // When exporting a graph in PyNative mode or only running the forward process, there is no need to do this.
  const auto &pynative_grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  if (graph_phase_.find("export") == 0 || !pynative_grad_executor->RequiresGrad()) {
    MS_LOG(DEBUG) << "Exporting graph or only running forward process";
    return true;
  }

  MS_EXCEPTION_IF_NULL(resource);
  auto jit_forward_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  graph_executor->SetJitPrimalFuncGraph(BasicClone(jit_forward_graph), graph_phase_);
  auto clone_graph = GetJitForwardGraphCNodeInfo(jit_forward_graph);
  if (clone_graph != nullptr) {
    graph_executor->SetJitGradGraph(clone_graph, graph_phase_);
    return true;
  }

  // Control flow does not eliminate forward nodes.
  auto is_control_flow = PyNativeAlgo::Common::IsControlFlowGraph(jit_forward_graph);
  auto jit_output_has_dict = JitOutputHasDict(jit_forward_graph->output()->abstract());
  set_eliminate_forward(!is_control_flow && !jit_output_has_dict);
  MS_LOG(DEBUG) << "Run ad grad eliminate_forward " << eliminate_forward_;
  auto grad_graph = ad::Grad(is_control_flow ? BasicClone(jit_forward_graph) : jit_forward_graph,
                             opt::Optimizer::MakeEmptyOptimizer(resource));
  MS_EXCEPTION_IF_NULL(grad_graph);
  graph_executor->SetJitGradGraph(grad_graph, graph_phase_);
  ModifyOutputNode(jit_forward_graph);

  // Keep roots so that only the forward func graph is kept in the resource.
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({jit_forward_graph});
  eliminate_forward_ = true;
  return true;
}

void Jit::Reset() { graph_phase_.clear(); }

void Jit::Clear() {
  for (auto &t : graph_phase_with_replace_info_) {
    t.second.clear();
  }
}

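// Inspect the forward graph: mark control-flow or dynamic-shape graphs in jit_compile_info_ and return nullptr for
// them; otherwise record the forward nodes needed by the bprop graph and return an unmodified clone of the graph.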
FuncGraphPtr Jit::GetJitForwardGraphCNodeInfo(const FuncGraphPtr &jit_forward_graph) {
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  PyNativeAlgo::Common::DumpGraphIR("jit_modify_before_forward_graph.ir", jit_forward_graph);
  if (PyNativeAlgo::Common::IsControlFlowGraph(jit_forward_graph)) {
    MS_LOG(DEBUG) << "Get control flow";
    jit_compile_info_[graph_phase_].is_control_flow_ = true;
    return nullptr;
  }
  if (IsGraphDynamic(jit_forward_graph)) {
    MS_LOG(DEBUG) << "Get dynamic shape";
    jit_compile_info_[graph_phase_].is_dynamic_shape_ = true;
    return nullptr;
  }
  jit_compile_info_[graph_phase_] = JitCompileInfo();
  AnfNodePtrList node_list{};
  const auto &order = TopoSort(jit_forward_graph->output());
  for (const auto &node : order) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &prim = GetCNodePrimitive(cnode);
    if (prim == nullptr) {
      MS_LOG(EXCEPTION) << "Should be primitive, but: " << node->DebugString();
    }
    if (!PyNativeAlgo::GradCommon::IsRealOp(cnode)) {
      continue;
    }
    MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString();
    const auto &unused_inputs = BpropExpander::GetUnusedInputs(prim->name());
    if (!unused_inputs.empty() && unused_inputs.find(INT_MAX) != unused_inputs.end() &&
        kExpanderWhiteList.find(prim->name()) == kExpanderWhiteList.end()) {
      MS_LOG(DEBUG) << "Prim " << prim->name() << " is not supported by expander";
      jit_compile_info_[graph_phase_].is_control_flow_ = true;
      return nullptr;
    }
    pynative::PyNativeAlgo::GradCommon::GetUsedCNodeInBpropGraph(cnode, unused_inputs, &node_list);
  }
  if (node_list.empty()) {
    MS_LOG(DEBUG) << "No need to do replace";
    // Make sure the forward graph does not change.
    return BasicClone(jit_forward_graph);
  }
  pynative::PyNativeAlgo::GradCommon::SetForward(node_list);
  // The output of jit_forward_graph will be changed below, so clone it first.
  auto clone_graph = BasicClone(jit_forward_graph);
  jit_forward_graph->set_used_forward_nodes(node_list);
  ModifyOutputNode(jit_forward_graph);
  PyNativeAlgo::Common::DumpGraphIR("jit_modify_after_forward_graph.ir", jit_forward_graph);
  return clone_graph;
}

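// Entry point called after a jit forward run: recover the real Python output, and if grad is required, record the
// jit call so that its grad graph is hooked into the current bprop graph.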
py::object Jit::GradJit(const py::object &out, const py::args &args) {
  if (graph_phase_.empty()) {
    MS_LOG(EXCEPTION) << "The graph phase is empty, cannot obtain the jit func graph.";
  }
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->WaitForwardTask();
  // Get the forward graph.
  MS_LOG(DEBUG) << "jit func graph phase: " << graph_phase_;
  auto executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  FuncGraphPtr jit_forward_graph = executor->GetFuncGraph(graph_phase_);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  // Get the actual forward output object.
  py::object ret = out;
  if (jit_forward_graph->modify_output()) {
    auto tuple_out = py::cast<py::tuple>(out);
    ret = tuple_out[0];
  }
  // Save dynamic shape info if output tensors of the forward graph have dynamic shapes.
  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  // Make the adjoint for the grad graph of the jit call.
  if (!grad_executor->RequiresGrad()) {
    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
    graph_phase_.clear();
    return ret;
  }
  compile_info_ = jit_compile_info_.at(graph_phase_);
  ValuePtr added_out_v = nullptr;
  const auto &op_run_info =
    GetOpRunInfo(out, args, graph_phase_, jit_forward_graph->modify_output(), jit_forward_graph, &added_out_v);
  PyNativeAlgo::Common::DumpGraphIR("jit_forward_graph.ir", jit_forward_graph);
  auto jit_grad_graph = executor->GetJitGradGraph(graph_phase_);
  if (compile_info_.is_dynamic_shape_) {
    grad_executor->set_use_dynamic_shape_process(true);
  }
  GradJitInner(op_run_info, grad_executor.get(), executor->GetJitPrimalFuncGraph(graph_phase_), jit_grad_graph,
               GetAddedNode(jit_forward_graph), added_out_v);
  Reset();
  return ret;
}

}  // namespace pynative
}  // namespace mindspore