/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_

#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <map>
#include <set>
#include "runtime/hardware/device_context.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/control_node_parser.h"
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "ir/tensor.h"

namespace mindspore {
using device::DeviceContext;
using session::CallBackFunc;
using session::GraphOutputInfo;
using session::InputTensorInfo;
using session::KernelGraph;
using session::KernelWithIndex;
using session::OpRunInfo;
using tensor::TensorPtr;

namespace runtime {
// Position of a kernel with index. The value pair<branch_id, vector<pos>> means the branch id of the kernel and the
// positions of the kernel. Generally there is only one branch, and its branch id is 0. In control flow there are
// multiple branches, and pos represents the position of the kernel within its branch.
using KernelMapPosition = std::map<KernelWithIndex, std::vector<size_t>, session::KernelWithIndexCmp>;
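// Illustrative (assumed) entry, not part of the original code: if the kernel output {node, 0} appears at positions
// 1 and 3 of the flattened graph outputs, the map would hold {node, 0} -> std::vector<size_t>{1, 3}.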

// The graph compiler info generated by the graph compiler is the representation of an executable graph.
// The device context is the unified interface for interacting with the device of the corresponding graph.
// The tensors mask is used to distinguish the input tensors' types.
// The input tensors are used to link graphs in the dynamic build scenario.
// The control nodes are used to link graphs in the control flow scenario.
// The control node parser is used to parse the edge info in the control nodes.
// The origin parameters order is used to correspond to the input args.
// The origin outputs order is used to correspond to the output args.
// The need_erase flag means this GraphCompilerInfo object needs to be erased after the actor set runs.
struct GraphCompilerInfo {
  GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
                    const std::vector<std::vector<int64_t> *> &tensors_mask,
                    const std::vector<std::vector<TensorPtr> *> &input_tensors,
                    const std::vector<AnfNodePtr> &control_nodes,
                    const std::vector<AnfNodePtr> &origin_parameters_order, const ControlNodeParserPtr &parser,
                    const KernelMapPosition &origin_outputs_order, const size_t outputs_num, const std::string &name,
                    bool need_erase, GraphExecutionStrategy strategy)
      : graphs_(graphs),
        device_contexts_(device_contexts),
        tensors_mask_(tensors_mask),
        input_tensors_(input_tensors),
        control_nodes_(control_nodes),
        control_node_parser_(parser),
        origin_parameters_order_(origin_parameters_order),
        origin_outputs_order_(origin_outputs_order),
        outputs_num_(outputs_num),
        name_(name),
        need_erase_(need_erase),
        strategy_(strategy) {}
  ~GraphCompilerInfo();
  std::vector<KernelGraphPtr> graphs_;
  std::vector<DeviceContext *> device_contexts_;
  std::vector<std::vector<int64_t> *> tensors_mask_;
  std::vector<std::vector<TensorPtr> *> input_tensors_;
  std::vector<AnfNodePtr> control_nodes_;
  ControlNodeParserPtr control_node_parser_;
  std::vector<AnfNodePtr> origin_parameters_order_;
  KernelMapPosition origin_outputs_order_;
  size_t outputs_num_;
  std::string name_;
  bool need_erase_;
  GraphExecutionStrategy strategy_;
};
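
// A minimal, illustrative construction sketch (not part of the original code; the values, the 'parser' variable and
// the GraphExecutionStrategy::kPipeline enumerator are assumptions about the caller and actor_common.h):
//
//   GraphCompilerInfo info({graph}, {device_context}, {&tensors_mask}, {&input_tensors},
//                          /*control_nodes=*/{}, /*origin_parameters_order=*/{}, parser,
//                          /*origin_outputs_order=*/{}, /*outputs_num=*/1, "kernel_graph_0",
//                          /*need_erase=*/false, GraphExecutionStrategy::kPipeline);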

class GraphCompiler {
 public:
  GraphCompiler() { session_ = session::SessionFactory::Get().Create(kSessionBasic); }
  ~GraphCompiler() = default;

  // Construct a kernel graph from the anf node list and compile it in Graph mode;
  // the detailed implementation of graph compilation is in 'CompileGraphImpl'.
  GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs, const DeviceContext *device_context);

  // Construct a single op kernel graph and compile it in PyNative mode.
  GraphId CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info,
                       const std::vector<int64_t> *tensors_mask, std::vector<TensorPtr> *const input_tensors,
                       bool *single_op_cache_hit, const DeviceContext *device_context);

  // Get a graph by graph id; returns nullptr if it does not exist. Used in Graph mode.
  KernelGraphPtr Fetch(GraphId graph_id) const;

  // Get a graph by graph info; returns nullptr if it does not exist. Used in PyNative mode.
  KernelGraphPtr Fetch(const GraphInfo &graph_info) const;
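
  // A minimal, illustrative Graph-mode usage sketch (not part of the original interface; 'nodes', 'outputs' and
  // 'device_context' are hypothetical caller-provided values):
  //
  //   GraphCompiler compiler;
  //   GraphId graph_id = compiler.CompileGraph(nodes, outputs, device_context);
  //   KernelGraphPtr graph = compiler.Fetch(graph_id);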

  // The following methods are used in PyNative back propagation to split the complete kernel graph into single op
  // graphs; they will be moved to class MindRTBackend after the session module is deleted. A rough sketch of how they
  // might be combined is given after 'RecoverGraphOutput' below.

  // Cache the indexes of all parameter and output nodes of the kernel graph, used to get the parameters of a single op
  // and to recover the output of the original complete back propagation kernel graph.
  void GetParamAndOutputIndex(const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs,
                              VectorRef *const outputs, std::map<AnfNodePtr, size_t> *parameter_index,
                              std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);

  // Get the input tensors for single op compilation and run; the input tensors may come from value nodes and
  // parameters in the graph, or from a previous kernel node's output.
  void GetSingleOpInputTensors(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
                               const std::map<AnfNodePtr, size_t> &parameter_index,
                               const std::vector<TensorPtr> &graph_inputs, InputTensorInfo *const input_tensor_info);
  // Get one input tensor for a single control op, such as bprop_cut.
  TensorPtr GetSingleOpInputTensorByIndex(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
                                          const std::map<AnfNodePtr, size_t> &parameter_index,
                                          const std::vector<TensorPtr> &graph_inputs,
                                          InputTensorInfo *const input_tensor_info, size_t input_index);

  // Get the OpRunInfo and GraphInfo for single op compilation and run.
  void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
                                      OpRunInfo *const run_info, GraphInfo *const graph_info);

  // Calculate the ref count of PyNative back propagation operators.
  void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const;

  // Update the ref count of PyNative back propagation operators.
  void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                      std::map<KernelWithIndex, size_t> *ref_count,
                      std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const;

  // Handle the single op output tensor and recover the output of the original complete kernel graph.
  void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                          const std::map<KernelWithIndex, size_t> &ref_count,
                          std::map<KernelWithIndex, TensorPtr> *op_output_map,
                          GraphOutputInfo *const graph_output_info) const;
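
  // A rough, illustrative sketch (assumptions only, not the actual backend code) of how the PyNative helpers above
  // might be combined to run a back propagation graph op by op; 'op_output_map', 'graph_output_info', 'op_outputs'
  // and the field 'input_tensor_info.input_kernel' are assumed names:
  //
  //   std::map<AnfNodePtr, size_t> parameter_index;
  //   std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
  //   compiler.GetParamAndOutputIndex(graph, inputs, &outputs, &parameter_index, &output_indexes);
  //   std::map<KernelWithIndex, size_t> ref_count;
  //   compiler.CalculateRefCount(graph, &ref_count);
  //   for (const auto &kernel : graph->execution_order()) {
  //     InputTensorInfo input_tensor_info;
  //     compiler.GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
  //     // ... compile and launch the single op, collecting its outputs into 'op_outputs' ...
  //     compiler.UpdateRefCount(input_tensor_info.input_kernel, &ref_count, &op_output_map);
  //     compiler.RecoverGraphOutput(kernel, op_outputs, ref_count, &op_output_map, &graph_output_info);
  //   }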

  // Collect output tensors of the back propagation graph for allreduce operators to average gradients,
  // used in PyNative distributed training mode.
  void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);

  // Clear the resources in the buckets, such as useless tensors and the device memory of all communication operators.
  // Buckets are used in PyNative distributed training mode; one bucket handles all the resources needed to launch and
  // synchronize an allreduce operator.
  void ClearAllBucket(const GraphId &graph_id);

  const std::vector<KernelWithIndex> &GetGraphOutputNodes(GraphId graph_id) const;

  // Register a summary callback function, which is called in the final stages of summary.
  void RegisterSummaryCallBackFunc(const CallBackFunc &callback) const;
  // Execute graph summary.
  void Summary(const std::vector<KernelGraphPtr> &graphs) const;

  // Remove the single op kernel graph cache and the output nodes cache.
  void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);

 private:
  DISABLE_COPY_AND_ASSIGN(GraphCompiler);

  // The implementation of compiling a graph in Graph mode, including optimizing the graph,
  // setting operator info, creating kernels and transforming the kernel graph to an ActorSet.
  GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

  // Create device addresses for all anf nodes of the graph.
  void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

  // Single op kernel graph cache for PyNative mode.
  std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;
  // Single op kernel graph output nodes cache for PyNative mode.
  std::unordered_map<GraphId, std::vector<KernelWithIndex>> run_op_graph_output_nodes_;

  // The member variable 'session_' will be removed after the session module is removed.
  // For now, all GraphCompiler instances share the same 'session_'.
  session::SessionPtr session_;
};

}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_