1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include "ir/anf.h" 24 #include "ir/tensor.h" 25 #include "pipeline/pynative/base.h" 26 #include "pipeline/pynative/grad/top_cell.h" 27 #include "pipeline/pynative/grad/auto_grad.h" 28 #include "pipeline/pynative/grad/ir/bprop_tensor_replace.h" 29 #include "pipeline/jit/ps/pipeline.h" 30 #include "pipeline/jit/ps/resource.h" 31 32 namespace mindspore { 33 namespace pynative { 34 class GradExecutor; 35 struct JitCompileInfo { 36 bool is_control_flow_{false}; 37 bool is_dynamic_shape_{false}; 38 }; 39 40 class Jit { 41 public: 42 Jit() = default; 43 ~Jit() = default; set_graph_phase(const std::string & graph_phase)44 inline void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; } 45 py::object GradJit(const py::object &out, const py::args &args); 46 void SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph); 47 void ProcessCnodeFromAdGrad(const CNodePtr &k_app, const CNodePtr &cnode_morph); 48 bool GetJitGradGraph(const pipeline::ResourcePtr &resource); eliminate_forward()49 inline bool eliminate_forward() const { return eliminate_forward_; } set_eliminate_forward(bool eliminate_forward)50 inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; } 51 void Clear(); 52 53 private: 54 void GradJitInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 55 const FuncGraphPtr &primal_func_graph, const FuncGraphPtr &jit_grad_graph, 56 const CNodePtr &added_node, const ValuePtr &added_out_v); 57 // Update device address of value node in grad graph by forward tensors. 58 void RunReplace(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const; 59 void ReplaceAddedCnodeActualOutput(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const; 60 // Make CNode for jit forward graph. 61 void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 62 AnfNodePtrList *input_nodes) const; 63 void GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 64 const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes) const; 65 void MakeCNodeForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 66 const FuncGraphPtr &ms_func_graph, CNodePtr *jit_cnode) const; 67 // Make adjoint for jit fprop graph and connect it with previous op 68 void MakeAdjointForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 69 const FuncGraphPtr &jit_forward_graph, const FuncGraphPtr &jit_grad_graph, 70 bool has_added_v) const; 71 void KPynativeWithFProp(const GradExecutor *grad_executor, const autograd::AutoGradPtr &auto_grad_cell_ptr, 72 const GradParamPtr &grad_param) const; 73 void RecordForwardGraphForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor, 74 const FuncGraphPtr &ms_func_graph) const; 75 void UpdateJitForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v, const size_t &stream_id); 76 FuncGraphPtr GetJitForwardGraphCNodeInfo(const FuncGraphPtr &jit_forward_graph); 77 void Reset(); 78 79 bool eliminate_forward_{true}; 80 // The graph phase is used to obtain backend graph that is complied by jit 81 std::string graph_phase_; 82 JitCompileInfo compile_info_; 83 mindspore::HashMap<std::string, TensorReplaceInfo> graph_phase_with_replace_info_{}; 84 mindspore::HashMap<std::string, JitCompileInfo> jit_compile_info_{}; 85 }; 86 using JitPtr = std::shared_ptr<Jit>; 87 } // namespace pynative 88 } // namespace mindspore 89 90 #endif // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_ 91