/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_

#include <vector>
#include <utility>
#include <string>
#include <memory>
#include <map>
#include <mutex>
#include <unordered_map>
#include <list>

#include "pybind11/pybind11.h"

#include "ir/anf.h"
#include "ir/tensor.h"
#include "pipeline/jit/ps/action.h"
#include "abstract/abstract_value.h"
#include "backend/graph_compiler/segment_runner.h"
#include "backend/graph_compiler/transform.h"
#include "pipeline/jit/ps/base.h"
#include "frontend/parallel/strategy.h"
#include "include/common/visible.h"
#include "mindrt/include/fork_utils.h"

namespace mindspore {
// Namespace for pipeline structure definitions.
namespace pipeline {

namespace py = pybind11;

class Pipeline {
 public:
  Pipeline(const ResourcePtr &res, const std::vector<ActionItem> &actions) : resource_(res), actions_(actions) {}

  ~Pipeline() = default;

  void Run();

  ResourcePtr resource() { return resource_; }

  bool NeedCreateBackend();

 private:
  ResourcePtr resource_;
  std::vector<ActionItem> actions_;
};
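
// A minimal usage sketch (illustrative only; the real driver lives in the compile pipeline,
// and the resource/action construction below is assumed rather than quoted from it):
//   ResourcePtr res = std::make_shared<Resource>(source);   // source: a caller-provided py::object
//   std::vector<ActionItem> actions = VmPipeline(res);      // action-list builder from action.h (assumed)
//   Pipeline(res, actions).Run();                           // executes each registered action in order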

// A function pipeline: the Python-facing executor that compiles and runs function graphs per phase.
class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
 public:
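  // Lazily creates the single process-wide executor; creation is guarded by instance_lock_.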
  static std::shared_ptr<GraphExecutorPy> GetInstance() {
    std::lock_guard<std::mutex> i_lock(instance_lock_);
    if (executor_ == nullptr) {
      executor_ = std::shared_ptr<GraphExecutorPy>(new (std::nothrow) GraphExecutorPy());
    }
    return executor_;
  }

  ~GraphExecutorPy();

  bool Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
               bool use_vm);
  bool CompileInner(const FuncGraphPtr &graph, const py::tuple &args, const py::dict &kwargs, const std::string &phase,
                    bool use_vm, bool trace_flag = false);
  py::object Run(const py::tuple &args, const py::object &phase);

  const std::string &phase() const { return phase_; }
  void SaveCompiledGraph(const std::string &phase);
  void ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
                   abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments);
  void ConvertSymbolicShape(const py::tuple &args, AbstractBasePtrList *args_abs);
  void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list);
  ResourcePtr GetResource(const std::string &phase);
  FuncGraphPtr GetFuncGraph(const std::string &phase);
  void SetJitPrimalFuncGraph(const FuncGraphPtr &primal_func_graph, const std::string &phase);
  FuncGraphPtr GetJitPrimalFuncGraph(const std::string &phase);
  FuncGraphPtr GetJitGradGraph(const std::string &phase);
  void SetJitGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
  py::bytes GetFuncGraphProto(const std::string &phase, const std::string &ir_type, const bool &incremental);
  py::bytes GetObfuscateFuncGraphProto(const std::string &phase, const bool &incremental, const float obf_ratio,
                                       const int branch_control_input);
#ifndef ENABLE_SECURITY
  py::bytes GetOptimizeGraphProto(const std::string &phase);
#endif
  void SetJitConfig(const py::dict &jit_config);
  compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
  bool HasCompiled(const std::string &phase) const;

  FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase) const;
  void ExportGraph(const std::string &file_name, const std::string &phase, const py::object encrypt = py::none(),
                   char *key = nullptr);
  bool InitParams(const py::dict &init_params, const std::string &phase) const;
  py::dict GetParams(const std::string &phase);
  py::bytes GetRandomStatus(const std::string &phase) const;
  void UpdataParamNodeDefaultInput(const std::string &phase,
                                   const std::unordered_map<std::string, tensor::TensorPtr> &params_value);
  void PyExePath(const py::object &py_exe_path) const;
  void KernelBuildServerDir(const py::object &kernel_build_server_dir) const;
  py::dict GetParameterLayout(const std::string &phase);
  py::tuple FlopsCollection(const std::string &phase);
  // Get the CNode name, input node names, and attributes from each graph.
  py::dict GetParallelGraphInfo(const std::string &phase);
  py::dict GetCNodeStrategy(const std::string &phase);
  py::list GetParallelParameterNameList(const std::string &phase);
  void SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy);
  size_t GetNumOpsInfo(const std::string &phase);
  void SetNumOpsInfo(size_t num_ops);
  py::dict GetAllreduceFusion(const std::string &phase);
  void DelNetRes(const py::object &source, const py::set &id);
  void ReleaseResourceOnException(const py::object &phase);
  void CleanCompileRes(const ResourcePtr &resource);
  static void ClearRes();
  void set_queue_name(const std::string &queue_name) { queue_name_ = queue_name; }
  std::string get_queue_name(const std::string &dataset_phase);
  void set_enable_tuple_broaden(bool enable_tuple_broaden) { enable_tuple_broaden_ = enable_tuple_broaden; }
  void set_compile_cache_dep_files(const py::list &compile_cache_dep_files) {
    compile_cache_dep_files_ = compile_cache_dep_files;
  }
  void set_weights_values(const py::dict &weights) { weights_ = weights; }
#ifdef ENABLE_DEBUGGER
  void TerminateDebugger();
#endif

  // Generate a key for the function-graph map from the given object and arguments.
  py::object GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
                                  bool enable_tuple_broaden = false);
  // Check that the runtime arguments are consistent with the compile-time arguments of the mapped function graph.
  void CheckArgumentsConsistency(const py::tuple &compile_args, const py::tuple &args_list, const py::object &target);
  void ClearCompileArgumentsResource();

  void ClearCurConvertInput();
  void ParentBeforeFork();
  void ParentAfterFork();
  void ChildAfterFork();

  void IncGraphCellCount() { ++graph_cell_count_; }
  void DecGraphCellCount() { --graph_cell_count_; }
  size_t graph_cell_count() const { return graph_cell_count_; }

 private:
  GraphExecutorPy() = default;
  void ParallelPostProcess(const string &phase, bool use_compile_cache);
  void GetGeBackendPolicy() const;
  // Filter out pipeline actions according to the phase, e.g. when exporting to ONNX there is no need to execute
  // any action after the 'validate' stage.
  static std::vector<ActionItem> FilterActions(const std::vector<ActionItem> &actions, const std::string &phase);

  void DelOneNetRes(const py::handle &py_phase);
  // If the compile cache is enabled, get the compile cache resource.
  void InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase);

#ifdef WITH_BACKEND
  void GeFirstInitParams();
#endif

  bool CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
                    bool use_vm);
  py::object RunInner(const py::tuple &args, const py::object &phase);
  void ClearRunArgumentsResource(size_t input_arg_size, VectorRef *arg_list);

  std::map<std::string, ExecutorInfoPtr> info_;
  static std::shared_ptr<GraphExecutorPy> executor_;
  static std::mutex instance_lock_;
  std::map<std::string, py::dict> stra_dict_;
  std::string phase_{""};
  std::string source_{""};
  std::string obj_desc_{""};
  std::map<std::string, size_t> phase_to_num_op_info_;
  std::string queue_name_;
  bool enable_tuple_broaden_{false};
  py::list compile_cache_dep_files_;
  bool compile_cache_consistent_{true};
  py::dict weights_;
  std::map<PyObject *, std::pair<ValuePtr, AbstractBasePtr>> cur_convert_input_;
  bool executor_running_{false};
  // Temporary solution: disable boost infer when there is a graph cell instance.
  size_t graph_cell_count_{0};
};
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;
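
// A minimal end-to-end sketch of how GraphExecutorPy is assumed to be driven from the pybind11
// layer (illustrative only; source, args, kwargs and phase are caller-provided Python objects,
// not values defined in this header):
//   auto executor = GraphExecutorPy::GetInstance();
//   if (executor->Compile(source, args, kwargs, phase, /*use_vm=*/true)) {
//     py::object output = executor->Run(args, phase);
//   }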

std::string GetJitLevel();

std::string GetObjDesc(const py::object &source);
bool IsPhaseLoadFromMindIR(const std::string &phase);
void CheckArgsValid(const py::object &source, const py::tuple &args);
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs);

bool InitDistribute(const std::map<std::string, std::string> &options);

void ResetOpId();
void ResetOpIdWithOffset();
void InitHccl();
void FinalizeHccl();
uint32_t GetHcclRankId();
uint32_t GetHcclRankSize();
void InitPipeline();
void FinalizeBackend();
void ME_EXPORT ClearResAtexit();
void CloseTsd(bool force = false);
void MemoryRecycle();
void BindDeviceCtx();

FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
                        const std::string &dec_mode, const py::object decrypt = py::none(),
                        const bool obfuscated = false);

FuncGraphPtr SplitMindIR(const std::string &file_name);

FuncGraphPtr SplitDynamicMindIR(const std::string &file_name, size_t device_num, size_t rank_id, bool sapp);

// Initialize and execute the dataset subgraph.
bool ME_EXPORT InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
                               const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                               const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run);

// Build and run the dataset subgraph for the MS backend.
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                       const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                       const std::vector<int64_t> &input_indexes, bool need_run);

void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);

py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode);
py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode);
py::bytes PyDecryptData(char *model_data, size_t data_size, char *key, size_t key_len, const std::string &dec_mode);
bool PyIsCipherFile(const std::string &file_path);
void FinalizeCluster();
FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int branch_control_input,
                                    char *dec_key, const size_t key_len, const std::string &dec_mode);
void SwapCache(const tensor::TensorPtr &host, const tensor::TensorPtr &device, const tensor::TensorPtr &block_mapping,
               const bool &type);
bool IsPhaseExport(const std::string &phase);
}  // namespace pipeline
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_