/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_

#include <memory>
#include <string>
#include <map>
#include <utility>
#include <stack>
#include <vector>
#include "pipeline/pynative/forward/do_cast.h"
#include "pipeline/pynative/forward/do_pyboost_cast.h"
#include "pipeline/pynative/forward/do_infer.h"
#include "backend/graph_compiler/backend.h"
#include "ir/cell.h"
#include "runtime/pipeline/async_hqueue.h"
#include "ops/view/view_strides_calculator.h"
#include "runtime/pipeline/async_rqueue.h"

namespace mindspore {
namespace pynative {
// Forward declaration: the autograd-side executor this class cooperates with
// (held only weakly below, see grad_executor_).
class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;

// Backends cached per device-target name (see GetMindRtBackend / mindrt_backends_).
using MindrtBackendMap = std::map<std::string, std::shared_ptr<compile::MindRTBackend>>;

// ForwardExecutor drives forward op execution in PyNative (eager) mode.
// From the declarations below, its responsibilities are:
//  * building a FrontendOpRunInfo from Python args (GenerateOpRunInfo) and
//    running/dispatching it (RunOpFrontend / DispatchFrontendTask);
//  * mixed-precision and implicit type casting of op inputs (SetCastForInputs,
//    cast_operation_, pyboost_cast_operation_);
//  * output abstract inference and the node-abstract cache (InferOutputAbstract,
//    Set/Get/ClearNodeAbsMap*, infer_operation_);
//  * routing execution to a backend: MindRT backend (RunOpInMs*), the VM
//    (RunOpInVM / IsVmOp), or view/slice fast paths (ProcessViewOp,
//    RunSliceOpFrontend, Dispatch*Task);
//  * tracking the cell (sub-graph) nesting during forward via forward_cell_stack_
//    and notifying the paired GradExecutor (ForwardOpGradImpl, grad()).
// NOTE(review): thread-safety is not expressed by this interface; the async
// dispatch members suggest a pipelined runtime — confirm locking rules with
// the .cc file before calling from multiple threads.
class ForwardExecutor {
 public:
  ForwardExecutor()
      : cast_operation_(std::make_shared<CastOperation>()),
        pyboost_cast_operation_(std::make_shared<PyBoostCastOperation>()),
        infer_operation_(std::make_shared<InferOperation>()) {}
  ~ForwardExecutor() = default;

  void Init();
  // Callable wrapper around RunOpFrontend. It captures `this`, so copies of
  // RunOpS must not outlive the ForwardExecutor instance.
  std::function<void(const FrontendOpRunInfoPtr &)> RunOpS = [this](auto &&PH1) {
    RunOpFrontend(std::forward<decltype(PH1)>(PH1));
  };

  // Asynchronously enqueue a frontend op task (vs. the synchronous RunOpFrontend).
  void DispatchFrontendTask(const FrontendOpRunInfoPtr &op_run_info);
  void RunOpFrontend(const FrontendOpRunInfoPtr &op_run_info);
  // If sub is true, this function will not convert StubTensor to Tensor.
  // Used to reduce the overhead of StubTensor WaitValue.
  FrontendOpRunInfoPtr GenerateOpRunInfo(const py::args &args, bool stub = false);
  // Run a chain of slice ops described by slice_op_infos over input_values.
  ValuePtr RunSliceOpFrontend(const std::vector<ValuePtr> &input_values,
                              const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                              const stub::StubNodePtr &stub_output);
  // NOTE(review): "Silce" is a typo for "Slice"; renaming would break existing
  // callers, so the name is left as-is here. Consider a deprecating alias.
  void DispatchSilceOpFrontendTask(const std::vector<ValuePtr> &input_values,
                                   const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                                   const stub::StubNodePtr &stub_output);
  // Stores only a weak reference; the caller keeps ownership of grad_executor.
  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
  void RefreshForwardCallback();
  // Node-abstract cache management (backed by infer_operation_).
  void ClearNodeAbsMap() const;
  void SetNodeAbsMapByValue(const FrontendOpRunInfoPtr &op_run_info) const;
  void SetNodeAbsMapById(const std::string &id, const abstract::AbstractBasePtr &abs) const;
  AbstractBasePtr GetNodeAbsById(const std::string &id) const;
  void ClearRes();
  bool EnablePipeline(const std::string &op_name) const;
  bool enable_async() const;
  inline const std::string &device_target() const { return device_target_; }
  const MindrtBackendMap &mindrt_backend() const { return mindrt_backends_; }
  // True when no cell is currently executing forward (cell stack is empty).
  inline bool IsFirstCell() const { return forward_cell_stack_.empty(); }
  void PushForwardCell(const CellPtr &cell) { forward_cell_stack_.push(cell); }
  // Precondition: the stack is non-empty (std::stack::pop on empty is UB).
  void PopForwardCell() { forward_cell_stack_.pop(); }
  void ExecuteLazyTask() const;
  void Sync();
  void PrintPyObjInfo(const py::object &obj, const std::string &str, bool is_cell) const;
  // Hooks invoked around graph (cell) construction boundaries.
  void ProcessBeforeNewGraph(const py::object &obj, const py::args &args);
  void ProcessAfterNewGraph(const py::object &obj) const;
  void ProcessBeforeEndGraph(const py::object &obj, bool is_cell);
  void ProcessAfterEndGraph(const py::object &obj, bool is_cell) const;
  bool CellNotSetMixedPrecision(const FrontendOpRunInfoPtr &op_run_info);
  inline InferOperationPtr infer_operation() const {
    MS_EXCEPTION_IF_NULL(infer_operation_);
    return infer_operation_;
  }
  inline void set_is_jit_compiling(bool is_jit_compiling) { is_jit_compiling_ = is_jit_compiling; }
  bool is_jit_compiling() const { return is_jit_compiling_; }

  void WaitForwardTask();
  bool IsVmOp(const std::string &op_name) const;
  std::string GetCurrentCellObjId() const;
  std::string GetCurrentDeviceTarget(const PrimitivePtr &op_prim) const;
  void ReInit();
  void ForwardOpGradImpl(const FrontendOpRunInfoPtr &op_run_info) const;
  // Locks the weak grad_executor_; presumably fails/throws when it has expired —
  // TODO(review): confirm behavior in the .cc file.
  GradExecutorPtr grad() const;
  void InitOpRunInfo(const FrontendOpRunInfoPtr &op_run_info);
  // Mix precision and Implicit transform
  void SetCastForInputs(const FrontendOpRunInfoPtr &op_run_info) const;
  inline const PyBoostCastOperationPtr &pyboost_cast_operation() const {
    MS_EXCEPTION_IF_NULL(pyboost_cast_operation_);
    return pyboost_cast_operation_;
  }
  // Re-initialization hook for the child process after fork().
  void ChildAfterFork();

 private:
  // NOTE(review): parameter uses unqualified `string` (not std::string) —
  // relies on a using-declaration pulled in by a header; prefer std::string.
  compile::MindRTBackendPtr GetMindRtBackend(const string &cur_device_target);
  inline CastOperationPtr cast_operation() const {
    MS_EXCEPTION_IF_NULL(cast_operation_);
    return cast_operation_;
  }
  // Backend execution paths: VM interpreter vs. MindRT backend.
  ValuePtr RunOpInVM(const FrontendOpRunInfoPtr &op_run_info) const;
  ValuePtr RunOpInMs(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  ValuePtr RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  ValuePtr RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info,
                                  const BackendOpRunInfoPtr &backend_op_run_info);
  void RunOpBackend(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  void RunOpBackendSync(const FrontendOpRunInfoPtr &op_run_info);

  VectorRef RunOpBackendInner(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  // Infer output abstract
  void InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) const;
  void PrepareOpInputs(const FrontendOpRunInfoPtr &op_run_info);
  void OpRunInfoUsePrimC(const FrontendOpRunInfoPtr &op_run_info) const;
  // View-op support: view ops share storage with their input instead of copying.
  void CreateInputAddressForViewOp(const tensor::BaseTensorPtr &input_tensor, const FrontendOpRunInfoPtr &op_run_info);
  void DispatchViewKernelTask(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type);
  void ForwardRunViewKernelTask(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type,
                                bool enable_async);

  bool ProcessViewOp(const FrontendOpRunInfoPtr &op_run_info, const ops::StridesCalcFunc &func_info,
                     bool is_tuple_output);
  device::DeviceAddressPtr TensorContiguousCallback(const DeviceSyncPtr &device_address,
                                                    const TensorStorageInfoPtr &storage_info);

  void CreateViewOutputTensor(const FrontendOpRunInfoPtr &op_run_info, const tensor::BaseTensorPtr &input_tensor,
                              const TensorStorageInfoPtr &storage_info, runtime::KernelTaskType task_type);

  void DispatchAllocateMemTask(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor,
                               const size_t &input_idx, bool need_wait = false);
  // Slice-op primitives are cached by op name (slice_prim_cache_).
  PrimitivePtr GetSlicePrimFromCache(const std::string &op_name);
  FrontendOpRunInfoPtr GenerateSliceOpRunInfo(const std::string &op_name, bool requires_grad,
                                              const stub::StubNodePtr &stub_output);
  void CreateViewOpOutputs(const FrontendOpRunInfoPtr &op_run_info, const tensor::BaseTensorPtr &view_input_tensor,
                           runtime::KernelTaskType task_type, const TensorStorageInfoPtrList &storage_infos,
                           bool is_tuple_output);

 private:
  bool init_{false};          // presumably set once by Init() — TODO(review) confirm
  bool enable_async_{true};   // see enable_async() / EnablePipeline()
  bool is_jit_compiling_{false};
  std::string device_target_;
  std::string last_target_{"Unknown"};
  // Cells currently executing forward, innermost on top (see Push/Pop/IsFirstCell).
  std::stack<CellPtr> forward_cell_stack_;
  // Weak: this class does not own the GradExecutor (avoids an ownership cycle,
  // assuming the GradExecutor side references this executor — confirm).
  GradExecutorWeakPtr grad_executor_;
  CastOperationPtr cast_operation_;
  PyBoostCastOperationPtr pyboost_cast_operation_;
  InferOperationPtr infer_operation_;
  // One MindRT backend per device target name.
  MindrtBackendMap mindrt_backends_;
  // Cache for slice primitives, keyed by op name.
  mindspore::HashMap<std::string, PrimitivePtr> slice_prim_cache_;
};
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_