/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/jit/jit_grad.h"

#include <climits>
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "frontend/optimizer/ad/grad.h"
#include "ops/structure_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/sequence_ops.h"
#include "pipeline/pynative/pynative_utils.h"
#include "pipeline/pynative/grad/jit/jit_dfunctor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/expander/bprop/bprop.h"

namespace mindspore {
namespace pynative {
namespace {
constexpr char kAddedValue[] = "added_value";

const mindspore::HashSet<std::string> kExpanderWhiteList{
  kVmapStackAssignOpName,
  kVmapUnstackAssignOpName,
  kPyExecuteOpName,
  kPrintOpName,
};

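// Build the FrontendOpRunInfo for a jit call: parse the python inputs, take the input abstracts from the
// parameters of the jit forward graph, and, when the graph output was modified, split the real output from
// the added forward outputs.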
FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, const std::string &graph_phase,
                                  bool modify_output, const FuncGraphPtr &jit_forward_graph, ValuePtr *added_out_v) {
  auto op_run_info = std::make_shared<FrontendOpRunInfo>();
  op_run_info->requires_grad = true;
  op_run_info->is_jit_input = true;
  op_run_info->base_op_run_info.op_name = graph_phase;
  PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
  // Set input abs
  const auto &original_params = jit_forward_graph->parameters();
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    op_run_info->op_grad_info->input_abs[i] = original_params[i]->abstract();
  }
  if (modify_output) {
    if (!py::isinstance<py::tuple>(out)) {
      MS_LOG(EXCEPTION) << "The output value of jit func graph should be a tuple.";
    }
    auto tuple_out = py::cast<py::tuple>(out);
    constexpr size_t tuple_out_size = 2;
    if (tuple_out.size() != tuple_out_size) {
      MS_LOG(EXCEPTION) << "The tuple size of output value of jit func graph should be 2.";
    }
    MS_EXCEPTION_IF_NULL(added_out_v);
    // Forward output of op in jit graph
    *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
    op_run_info->real_out = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[0]);
  } else {
    op_run_info->real_out = PyNativeAlgo::DataConvert::PyObjToValue(out);
  }
  return op_run_info;
}

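// Count how many tensor outputs an abstract expands to. Sequences are counted recursively; CSR/COO tensors
// expand to their member tensors plus shape values; other abstracts count as zero.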
size_t GetTensorNumFromAbstract(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    // Is a tensor
    constexpr size_t kTensorOutputNum = 1;
    return kTensorOutputNum;
  } else if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>()->elements();
    return std::accumulate(abs_seq.begin(), abs_seq.end(), 0, [](size_t out_num, const abstract::AbstractBasePtr &abs) {
      return out_num + GetTensorNumFromAbstract(abs);
    });
  } else if (abs->isa<abstract::AbstractCSRTensor>()) {
    // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
    constexpr size_t kCSRTensorOutputNum = 5;
    return kCSRTensorOutputNum;
  } else if (abs->isa<abstract::AbstractCOOTensor>()) {
    // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
    constexpr size_t kCOOTensorOutputNum = 4;
    return kCOOTensorOutputNum;
  }
  return 0;
}

// Modify the output node of func_graph to add forward nodes used in bprop graph.
void ModifyOutputNode(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &used_forward_nodes = func_graph->used_forward_nodes();
  if (used_forward_nodes.empty()) {
    return;
  }

  // Create a new make tuple node to hold all forward used nodes.
  abstract::AbstractBasePtrList added_abs_list;
  AnfNodePtrList added_node_list{NewValueNode(prim::kPrimMakeTuple)};
  for (const auto &node : used_forward_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    (void)added_node_list.emplace_back(node);
    (void)added_abs_list.emplace_back(node->abstract());
  }
  AnfNodePtr added_output_node = func_graph->NewCNode(std::move(added_node_list));
  AbstractBasePtr added_output_abs = std::make_shared<abstract::AbstractTuple>(added_abs_list);
  added_output_node->set_abstract(added_output_abs);

  // Get original output node and abstract, and merge original output node and used forward nodes to return node.
  auto original_output_node = func_graph->output();
  MS_EXCEPTION_IF_NULL(original_output_node);
  auto original_output_abs = original_output_node->abstract();
  MS_EXCEPTION_IF_NULL(original_output_abs);
  AnfNodePtrList new_output_nodes{NewValueNode(prim::kPrimMakeTuple), original_output_node, added_output_node};
  auto merge_node = func_graph->NewCNode(std::move(new_output_nodes));
  abstract::AbstractBasePtrList new_output_abs{original_output_abs, added_output_abs};
  merge_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_output_abs));
  func_graph->set_output(merge_node);

  // Mark the output as modified and clear the recorded forward nodes.
  func_graph->set_modify_output(true);
  func_graph->ClearUsedForwardNodes();
}

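// Fetch the make_tuple node of added forward outputs from a graph whose output was modified by ModifyOutputNode;
// returns nullptr if the output was not modified.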
CNodePtr GetAddedNode(const FuncGraphPtr &jit_forward_graph) {
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  if (!jit_forward_graph->modify_output()) {
    return nullptr;
  }
  // Get added forward nodes.
  auto merge_node = jit_forward_graph->output();
  MS_EXCEPTION_IF_NULL(merge_node);
  auto merge_make_tuple = merge_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(merge_make_tuple);
  constexpr size_t merge_output_size = 3;
  // First is make_tuple, second is actual output, third is added output
  if (merge_make_tuple->size() != merge_output_size) {
    MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: " << merge_make_tuple->size();
  }
  constexpr size_t added_output_index = 2;
  return merge_make_tuple->input(added_output_index)->cast<CNodePtr>();
}

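// A graph is treated as dynamic if any non-default parameter or the graph output has a dynamic shape.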
bool IsGraphDynamic(const FuncGraphPtr &func_graph) {
  for (const auto &param : func_graph->parameters()) {
    if (param->isa<Parameter>() && !param->cast<ParameterPtr>()->has_default()) {
      const auto &abs = param->abstract();
      if (abs != nullptr && abs->BuildShape()->IsDynamic()) {
        return true;
      }
    }
  }
  MS_EXCEPTION_IF_NULL(func_graph->output());
  if (auto abs = func_graph->output()->abstract(); abs != nullptr && abs->BuildShape()->IsDynamic()) {
    return true;
  }
  return false;
}

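// Recursively check whether the jit output abstract contains a dictionary.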
bool JitOutputHasDict(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractDictionary>()) {
    return true;
  } else if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
    return std::any_of(abs_sequence->elements().begin(), abs_sequence->elements().end(),
                       [](const abstract::AbstractBasePtr &item) { return JitOutputHasDict(item); });
  }
  return false;
}
}  // namespace

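// Replace the placeholder value nodes recorded on the added forward cnodes with the tensors actually produced by
// this run, consuming total_output_tensors in order.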
void Jit::RunReplace(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const {
  MS_EXCEPTION_IF_NULL(added_node);
  size_t index = 0;
  for (size_t i = 1; i < added_node->size(); ++i) {
    const auto &input_i = added_node->input(i);
    MS_EXCEPTION_IF_NULL(input_i);
    auto cnode = input_i->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(DEBUG) << "Replace output tensors for cnode: " << cnode->DebugString();
    const auto &output_vnode = cnode->forward().first;
    MS_EXCEPTION_IF_NULL(output_vnode);
    MS_LOG(DEBUG) << "Old output value node: " << output_vnode->ToString();
    MS_EXCEPTION_IF_NULL(output_vnode->abstract());
    bool is_tuple_out = output_vnode->abstract()->isa<abstract::AbstractSequence>();
    size_t output_num = GetTensorNumFromAbstract(cnode->abstract());
    if (output_num == 0) {
      MS_LOG(DEBUG) << "The output value does not include any tensor";
      continue;
    }
    if (index + output_num > total_output_tensors.size()) {
      MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
                        << ", but the current index: " << index << ", output num: " << output_num;
    }
    // Get new tensors.
    std::vector<ValuePtr> new_values;
    for (size_t j = index; j < index + output_num; ++j) {
      // If the jit graph is reused with dynamic shape, the tensor address of the added output tensor is updated in
      // the run actor.
      auto tensor = total_output_tensors[j]->cast<tensor::BaseTensorPtr>();
      if (tensor != nullptr) {
        tensor->set_is_forward_output(true);
      }
      (void)new_values.emplace_back(total_output_tensors[j]);
    }
    index = index + output_num;
    // Replace with new tensors.
    // Cannot check output_num > 1 here, because the output can be (a): a tuple with only one element.
    if (is_tuple_out) {
      output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
    } else {
      output_vnode->set_value(new_values[0]);
    }
    MS_LOG(DEBUG) << "New output value node: " << output_vnode->ToString();
  }
  // All tensors in total_output_tensors should have been consumed by the replacement above.
  if (index != total_output_tensors.size()) {
    MS_LOG(EXCEPTION) << "The index: " << index
                      << " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
  }
}

void Jit::ReplaceAddedCnodeActualOutput(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const {
  MS_EXCEPTION_IF_NULL(added_node);
  // Replace forward nodes with the new output tensors; the same value nodes are shared by the grad graph, so the
  // replacement takes effect there as well.
  MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_node->DebugString();
  // The forward nodes in the jit graph are placeholders created during compilation; after the jit graph runs, they
  // need to be updated with the real values.
  RunReplace(added_node, total_output_tensors);
}

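// Collect the anf nodes corresponding to the python input arguments of the jit call from the grad executor.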
void Jit::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                           AnfNodePtrList *input_nodes) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(input_nodes);
  MS_EXCEPTION_IF_NULL(grad_executor);
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    const auto &input_i_value = op_run_info->op_grad_info->input_value[i];
    const auto &id = PyNativeAlgo::Common::GetIdByValue(input_i_value);
    const auto &input_i_node = grad_executor->GetInput(input_i_value, id);
    MS_EXCEPTION_IF_NULL(input_i_node);
    MS_LOG(DEBUG) << "The input " << i << " id " << id << ", node is: " << input_i_node->DebugString();
    (void)input_nodes->emplace_back(input_i_node);
  }
}

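// Collect the weight parameters of the jit forward graph. Weights already known to the top cell are reused;
// otherwise they are added as parameters of the top cell graph and recorded in its graph info map.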
void Jit::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                         const FuncGraphPtr &jit_forward_graph, AnfNodePtrList *input_nodes) const {
  MS_EXCEPTION_IF_NULL(grad_executor);
  MS_EXCEPTION_IF_NULL(input_nodes);
  const auto &top_cell = grad_executor->top_cell();
  const auto &graph_info = top_cell->graph_info_map().at(top_cell->fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  // Get weights info of jit
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  const auto &original_params = jit_forward_graph->parameters();
  size_t params_size = original_params.size();
  MS_EXCEPTION_IF_NULL(op_run_info);
  for (size_t i = 0; i < params_size; ++i) {
    if (i < op_run_info->input_size) {  // Non-weight input node.
      continue;
    }
    // Must be a weight parameter.
    auto param = original_params[i]->cast<ParameterPtr>();
    const auto tensor_value = PyNativeAlgo::Common::GetTensorFromParam(original_params[i]);
    MS_EXCEPTION_IF_NULL(tensor_value);
    const auto it = graph_info->weight_params.find(tensor_value->id());
    if (it != graph_info->weight_params.end()) {
      param = it->second;
    } else {
      top_cell->fg()->add_parameter(param);
      param->debug_info()->set_name(param->name());
      top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
    }
    (void)input_nodes->emplace_back(param);
    MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                  << tensor_value->ToString() << ". Its name is: " << param->name();
  }
}

void Jit::MakeCNodeForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                          const FuncGraphPtr &jit_forward_graph, CNodePtr *jit_cnode) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  // Get input node info of jit
  AnfNodePtrList input_nodes{NewValueNode(jit_forward_graph)};
  MS_EXCEPTION_IF_NULL(grad_executor);
  GetInputArgsNode(op_run_info, grad_executor, &input_nodes);
  // Get weights node info of jit.
  GetWeightsNode(op_run_info, grad_executor, jit_forward_graph, &input_nodes);
  // Make a CNode that calls the jit forward graph with the collected input nodes.
  MS_EXCEPTION_IF_NULL(jit_cnode);
  *jit_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
  (*jit_cnode)->set_abstract(jit_forward_graph->output()->abstract());
  MS_LOG(DEBUG) << "Make jit forward CNode: " << (*jit_cnode)->DebugString();
}

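// Build the adjoint for the jit call: record grad info for the inputs, weights and output, then feed the jit
// forward and grad graphs into the auto-grad cell of the current top cell.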
void Jit::MakeAdjointForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                            const FuncGraphPtr &jit_forward_graph, const FuncGraphPtr &jit_grad_graph,
                            bool has_added_v) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(grad_executor);

  const auto &top_cell = grad_executor->top_cell();
  PyNativeAlgo::Common::SetGraphInputAndWeightsInfo(op_run_info, jit_forward_graph, top_cell);
  RecordForwardGraphForJit(op_run_info, grad_executor, jit_forward_graph);
  // Connect grad graph of jit to context.
  (void)PyNativeAlgo::Common::SetValueGradInfo(op_run_info->real_out, top_cell, InputType::kOpOutput);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  MS_EXCEPTION_IF_NULL(jit_forward_graph->output()->abstract());
  if (grad_executor->dynamic_shape()->enable_unknown_shape() &&
      jit_forward_graph->output()->abstract()->BuildShape()->IsDynamic()) {
    MS_LOG(DEBUG) << "Set jit unknown shape out to abs cache";
    grad_executor->dynamic_shape()->SaveUnknownShapeAbsFromJit(op_run_info->real_out,
                                                               jit_forward_graph->output()->abstract(), 0);
  }
  auto op_grad_info = std::make_shared<OpGradInfo>();
  op_grad_info->input_value = op_run_info->op_grad_info->input_value;
  op_grad_info->input_abs = op_run_info->op_grad_info->input_abs;
  op_grad_info->out_value = op_run_info->real_out;
  op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_grad_info->out_value);
  op_grad_info->input_value_grad_type = op_run_info->op_grad_info->input_value_grad_type;
  if (jit_forward_graph->output()->abstract()->isa<abstract::AbstractAny>()) {
    op_grad_info->out_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_grad_info->out_value->ToAbstract());
  } else {
    op_grad_info->out_abs = jit_forward_graph->output()->abstract();
  }
  auto grad_param = std::make_shared<GradParam>(op_grad_info, grad_executor->use_dynamic_shape_process());
  grad_param->is_control_flow = compile_info_.is_control_flow_;

  grad_param->has_added_v = has_added_v;
  grad_param->is_jit_graph = true;
  // As long as the jit graph itself is dynamic shape, let it run through actor execution to avoid backend passes.
  grad_param->is_jit_self_dynamic_shape = compile_info_.is_dynamic_shape_;

  grad_param->fg = jit_grad_graph;
  grad_param->source_fg = jit_forward_graph;
  grad_param->graph_cache_key = graph_phase_;
  grad_param->jit_out_has_dict = JitOutputHasDict(op_grad_info->out_abs);
  auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
  KPynativeWithFProp(grad_executor, auto_grad_cell_ptr, grad_param);
  top_cell->set_need_do_final_opt(true);
  top_cell->set_has_call_graph(grad_executor->use_dynamic_shape_process());
  top_cell->set_has_control_flow(compile_info_.is_control_flow_);
  top_cell->set_jit_out_has_dict(grad_param->jit_out_has_dict);
}

void Jit::KPynativeWithFProp(const GradExecutor *grad_executor, const autograd::AutoGradPtr &auto_grad_cell_ptr,
                             const GradParamPtr &grad_param) const {
  grad_executor->WaitBpropTask();
  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
  if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for jit cnode";
  }
}

void Jit::RecordForwardGraphForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                   const FuncGraphPtr &jit_forward_graph) const {
  int save_graphs = MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    CNodePtr jit_cnode = nullptr;
    MakeCNodeForJit(op_run_info, grad_executor, jit_forward_graph, &jit_cnode);
    MS_EXCEPTION_IF_NULL(jit_cnode);
    const auto &out_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
    const auto &top_cell = grad_executor->top_cell();
    top_cell->SetNodeMapInGraphInfoMap(out_id, jit_cnode);
  }
}

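// Core grad processing for one jit call: replace the added forward placeholders with real outputs, run the dynamic
// detection check, refresh the forward tensor info used by the bprop graph, and finally build the adjoint.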
void Jit::GradJitInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                       const FuncGraphPtr &primal_func_graph, const FuncGraphPtr &jit_grad_graph,
                       const CNodePtr &added_node, const ValuePtr &added_out_v) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(grad_executor);
  // Step 1: Replace added cnode forward with actual output
  ValuePtr flatten_v = added_out_v;
  bool added_v_is_empty = true;
  if (added_out_v != nullptr) {
    ValuePtrList total_output_tensors;
    PyNativeAlgo::DataConvert::FlattenValueSeqArg(added_out_v, false, true, &total_output_tensors);
    flatten_v = std::make_shared<ValueTuple>(total_output_tensors);
    added_v_is_empty = total_output_tensors.empty();
    ReplaceAddedCnodeActualOutput(added_node, total_output_tensors);
  }

  // Step 2: Check the node and set the use_dynamic_shape_process flag if needed
  auto node_info = std::make_shared<DynamicDetectNodeInfo>(nullptr, op_run_info->op_grad_info->input_abs,
                                                           op_run_info->base_op_run_info.abstract);
  node_info->is_graph_node = true;
  node_info->graph_phase = graph_phase_;
  grad_executor->dynamic_shape()->CheckNodeDynamic(grad_executor->top_cell(), op_run_info->op_grad_info->input_value,
                                                   node_info);

  // Step 3: Update actual output tensors used in grad graph.
  MS_LOG(DEBUG) << "jit actual output value: " << op_run_info->real_out->ToString();
  grad_executor->top_cell()->GetOpInfo(op_run_info, true);
  grad_executor->UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
                                                            op_run_info->base_op_run_info.stream_id);

  // Step 4: Update output tensors of added forward nodes, which are added to return node of jit func graph.
  if (!added_v_is_empty) {
    if (grad_executor->use_dynamic_shape_process()) {
      // If the jit graph has no control flow, it is executed by actors under dynamic shape, and the value nodes
      // will be updated there.
      if (!compile_info_.is_control_flow_) {
        UpdateJitForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, flatten_v,
                                               op_run_info->base_op_run_info.stream_id);
      }
    } else {
      // Static shape runs by replacing the value nodes directly.
      grad_executor->UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info + kAddedValue, flatten_v,
                                                                op_run_info->base_op_run_info.stream_id);
    }
  }

  // Make adjoint for the grad graph.
  MakeAdjointForJit(op_run_info, grad_executor, primal_func_graph, jit_grad_graph, !added_v_is_empty);
}

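// On the first run of a jit phase, record the id-to-op-info mapping for the added forward value; on later runs,
// refresh the stored forward output tensor info with the latest value.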
void Jit::UpdateJitForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v,
                                                 const size_t &stream_id) {
  const auto it = graph_phase_with_replace_info_.find(graph_phase_);
  if (it == graph_phase_with_replace_info_.end()) {
    MS_LOG(DEBUG) << "Jit " << graph_phase_ << " runs for the first time";
    auto &replace_info = graph_phase_with_replace_info_[graph_phase_];
    SetIdWithOpInfo(v, op_info, kIndex0, &(replace_info.id_with_op_info));
    return;
  }
  // Not the first run
  MS_LOG(DEBUG) << "Update jit forward output tensor info " << op_info;
  UpdateForwardOutputTensorInfo(op_info, v, it->second);
}

void Jit::SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph) {
  const auto it = graph_phase_with_replace_info_.find(graph_phase_);
  if (it == graph_phase_with_replace_info_.end()) {
    MS_LOG(EXCEPTION) << "Can not find graph phase " << graph_phase_ << " in graph_phase_with_replace_info";
  }
  MS_LOG(DEBUG) << "Save jit forward output tensor info";
  auto manager = MakeManager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(func_graph);
  SaveForwardOutputTensorInfo(func_graph, true, &(it->second));
}

void Jit::ProcessCnodeFromAdGrad(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
  // Run grad process for func_graph and replace forward nodes with its output tensors.
  if (eliminate_forward_) {
    ReplaceEquivOut(k_app, cnode_morph);
  }
}

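// Called after the jit forward graph is compiled: saves the primal graph, then either stores a cloned forward graph
// for the expander-based bprop path or builds the grad graph with ad::Grad, and modifies the forward graph output to
// also return the forward nodes needed by bprop.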
bool Jit::GetJitGradGraph(const pipeline::ResourcePtr &resource) {
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  graph_phase_ = graph_executor->phase();
  MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << graph_phase_;
  // No need to do this when exporting a graph in PyNative mode or when only running the forward process.
  const auto &pynative_grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  if (graph_phase_.find("export") == 0 || !pynative_grad_executor->RequiresGrad()) {
    MS_LOG(DEBUG) << "Exporting graph or only running forward process, skip building grad graph";
    return true;
  }

  MS_EXCEPTION_IF_NULL(resource);
  auto jit_forward_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  graph_executor->SetJitPrimalFuncGraph(BasicClone(jit_forward_graph), graph_phase_);
  auto clone_graph = GetJitForwardGraphCNodeInfo(jit_forward_graph);
  if (clone_graph != nullptr) {
    graph_executor->SetJitGradGraph(clone_graph, graph_phase_);
    return true;
  }

  // Control flow graphs do not eliminate forward nodes.
  auto is_control_flow = PyNativeAlgo::Common::IsControlFlowGraph(jit_forward_graph);
  auto jit_output_has_dict = JitOutputHasDict(jit_forward_graph->output()->abstract());
  set_eliminate_forward(!is_control_flow && !jit_output_has_dict);
  MS_LOG(DEBUG) << "Run ad grad eliminate_forward " << eliminate_forward_;
  auto grad_graph = ad::Grad(is_control_flow ? BasicClone(jit_forward_graph) : jit_forward_graph,
                             opt::Optimizer::MakeEmptyOptimizer(resource));
  MS_EXCEPTION_IF_NULL(grad_graph);
  graph_executor->SetJitGradGraph(grad_graph, graph_phase_);
  ModifyOutputNode(jit_forward_graph);

  // Keep roots so that only the forward func graph is kept in the resource.
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({jit_forward_graph});
  eliminate_forward_ = true;
  return true;
}

void Jit::Reset() { graph_phase_.clear(); }

void Jit::Clear() {
  for (auto &t : graph_phase_with_replace_info_) {
    t.second.clear();
  }
}

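// Analyze the compiled forward graph for the expander-based bprop path: bail out (return nullptr) on control flow,
// dynamic shape, or primitives the expander cannot handle; otherwise record the cnodes whose outputs bprop needs
// and return a clone of the forward graph taken before its output is modified.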
FuncGraphPtr Jit::GetJitForwardGraphCNodeInfo(const FuncGraphPtr &jit_forward_graph) {
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  PyNativeAlgo::Common::DumpGraphIR("jit_modify_before_forward_graph.ir", jit_forward_graph);
  if (PyNativeAlgo::Common::IsControlFlowGraph(jit_forward_graph)) {
    MS_LOG(DEBUG) << "Get control flow";
    jit_compile_info_[graph_phase_].is_control_flow_ = true;
    return nullptr;
  }
  if (IsGraphDynamic(jit_forward_graph)) {
    MS_LOG(DEBUG) << "Get dynamic shape";
    jit_compile_info_[graph_phase_].is_dynamic_shape_ = true;
    return nullptr;
  }
  jit_compile_info_[graph_phase_] = JitCompileInfo();
  AnfNodePtrList node_list{};
  const auto &order = TopoSort(jit_forward_graph->output());
  for (const auto &node : order) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &prim = GetCNodePrimitive(cnode);
    if (prim == nullptr) {
      MS_LOG(EXCEPTION) << "Should be a primitive, but got: " << node->DebugString();
    }
    if (!PyNativeAlgo::GradCommon::IsRealOp(cnode)) {
      continue;
    }
    MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString();
    const auto &unused_inputs = BpropExpander::GetUnusedInputs(prim->name());
    if (!unused_inputs.empty() && unused_inputs.find(INT_MAX) != unused_inputs.end() &&
        kExpanderWhiteList.find(prim->name()) == kExpanderWhiteList.end()) {
      MS_LOG(DEBUG) << "Prim " << prim->name() << " is not supported by the expander";
      jit_compile_info_[graph_phase_].is_control_flow_ = true;
      return nullptr;
    }
    pynative::PyNativeAlgo::GradCommon::GetUsedCNodeInBpropGraph(cnode, unused_inputs, &node_list);
  }
  if (node_list.empty()) {
    MS_LOG(DEBUG) << "No need to do replacement";
    // Make sure the forward graph does not change.
    return BasicClone(jit_forward_graph);
  }
  pynative::PyNativeAlgo::GradCommon::SetForward(node_list);
  // The output of jit_forward_graph will be changed below, so clone it first.
  auto clone_graph = BasicClone(jit_forward_graph);
  jit_forward_graph->set_used_forward_nodes(node_list);
  ModifyOutputNode(jit_forward_graph);
  PyNativeAlgo::Common::DumpGraphIR("jit_modify_after_forward_graph.ir", jit_forward_graph);
  return clone_graph;
}

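// Entry point called after a jit forward run in PyNative mode: restores the real python output and, when grad is
// required, builds the op run info and records the jit call for gradient computation.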
py::object Jit::GradJit(const py::object &out, const py::args &args) {
  if (graph_phase_.empty()) {
    MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain jit func graph.";
  }
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->WaitForwardTask();
  // Get forward graph
  MS_LOG(DEBUG) << "jit func graph phase: " << graph_phase_;
  auto executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  FuncGraphPtr jit_forward_graph = executor->GetFuncGraph(graph_phase_);
  MS_EXCEPTION_IF_NULL(jit_forward_graph);
  // Get actual forward output object.
  py::object ret = out;
  if (jit_forward_graph->modify_output()) {
    auto tuple_out = py::cast<py::tuple>(out);
    ret = tuple_out[0];
  }
  // Save dynamic shape info if output tensors of forward graph have dynamic shapes
  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  // Make Adjoint for grad graph of jit.
  if (!grad_executor->RequiresGrad()) {
    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
    graph_phase_.clear();
    return ret;
  }
  compile_info_ = jit_compile_info_.at(graph_phase_);
  ValuePtr added_out_v = nullptr;
  const auto &op_run_info =
    GetOpRunInfo(out, args, graph_phase_, jit_forward_graph->modify_output(), jit_forward_graph, &added_out_v);
  PyNativeAlgo::Common::DumpGraphIR("jit_forward_graph.ir", jit_forward_graph);
  auto jit_grad_graph = executor->GetJitGradGraph(graph_phase_);
  if (compile_info_.is_dynamic_shape_) {
    grad_executor->set_use_dynamic_shape_process(true);
  }
  GradJitInner(op_run_info, grad_executor.get(), executor->GetJitPrimalFuncGraph(graph_phase_), jit_grad_graph,
               GetAddedNode(jit_forward_graph), added_out_v);
  Reset();
  return ret;
}
}  // namespace pynative
}  // namespace mindspore
603