/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_
#define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_

#include <vector>
#include <memory>
#include <string>
#include "ir/anf.h"
#include "ir/tensor.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/grad/auto_grad.h"
#include "pipeline/pynative/grad/ir/bprop_tensor_replace.h"
#include "pipeline/jit/ps/pipeline.h"
#include "pipeline/jit/ps/resource.h"

namespace mindspore {
namespace pynative {
class GradExecutor;
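// Compile-time flags recorded per jit graph: whether it contains control flow and whether it has dynamic shapes.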
struct JitCompileInfo {
  bool is_control_flow_{false};
  bool is_dynamic_shape_{false};
};

class Jit {
 public:
  Jit() = default;
  ~Jit() = default;
  inline void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
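  // Do grad for the graph compiled by jit; out is the forward output and args are the forward inputs.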
  py::object GradJit(const py::object &out, const py::args &args);
  void SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph);
  void ProcessCnodeFromAdGrad(const CNodePtr &k_app, const CNodePtr &cnode_morph);
  bool GetJitGradGraph(const pipeline::ResourcePtr &resource);
  inline bool eliminate_forward() const { return eliminate_forward_; }
  inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
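  // Clear cached jit grad state.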
  void Clear();

 private:
  void GradJitInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                    const FuncGraphPtr &primal_func_graph, const FuncGraphPtr &jit_grad_graph,
                    const CNodePtr &added_node, const ValuePtr &added_out_v);
  // Update device address of value node in grad graph by forward tensors.
  void RunReplace(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const;
  void ReplaceAddedCnodeActualOutput(const CNodePtr &added_node, const ValuePtrList &total_output_tensors) const;
  // Make CNode for jit forward graph.
  void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                        AnfNodePtrList *input_nodes) const;
  void GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                      const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes) const;
  void MakeCNodeForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                       const FuncGraphPtr &ms_func_graph, CNodePtr *jit_cnode) const;
  // Make adjoint for jit fprop graph and connect it with previous op.
  void MakeAdjointForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                         const FuncGraphPtr &jit_forward_graph, const FuncGraphPtr &jit_grad_graph,
                         bool has_added_v) const;
  void KPynativeWithFProp(const GradExecutor *grad_executor, const autograd::AutoGradPtr &auto_grad_cell_ptr,
                          const GradParamPtr &grad_param) const;
  void RecordForwardGraphForJit(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                const FuncGraphPtr &ms_func_graph) const;
  void UpdateJitForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v, const size_t &stream_id);
  FuncGraphPtr GetJitForwardGraphCNodeInfo(const FuncGraphPtr &jit_forward_graph);
  void Reset();

  bool eliminate_forward_{true};
  // The graph phase is used to obtain the backend graph compiled by jit.
  std::string graph_phase_;
  JitCompileInfo compile_info_;
  mindspore::HashMap<std::string, TensorReplaceInfo> graph_phase_with_replace_info_{};
  mindspore::HashMap<std::string, JitCompileInfo> jit_compile_info_{};
};
using JitPtr = std::shared_ptr<Jit>;
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_JIT_JIT_GRAD_H_