1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/forward/forward.h"
18 #include <set>
19 #include <algorithm>
20 #include <unordered_set>
21 #include <vector>
22 #include "ops/structure_op_name.h"
23 #include "ops/array_ops.h"
24 #include "ops/framework_ops.h"
25 #include "pipeline/pynative/pynative_utils.h"
26 #include "pybind_api/gil_scoped_long_running.h"
27 #include "include/common/utils/python_fallback_running.h"
28 #include "backend/graph_compiler/transform.h"
29 #include "utils/ms_context.h"
30 #include "pipeline/pynative/forward/forward_task.h"
31 #include "pipeline/pynative/predict_out_type_map.h"
32 #include "include/common/utils/stub_tensor.h"
33 #include "runtime/pynative/op_executor.h"
34 #ifndef ENABLE_SECURITY
35 #include "include/backend/debug/profiler/profiling.h"
36 using mindspore::profiler::ProfilerManager;
37 #endif
38 #include "frontend/operator/ops_front_infer_function.h"
39 #include "runtime/pipeline/pipeline.h"
40 #include "runtime/device/device_address_utils.h"
41 
42 namespace mindspore {
43 namespace pynative {
44 enum class RunOpArgsEnum : size_t { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
45 namespace {
46 const std::set<std::string> kVmOperators = {"InsertGradientOf", "StopGradient", "HookBackward", "CellBackwardHook"};
47 constexpr char kBegin[] = "Begin";
48 constexpr char kEnd[] = "End";
49 constexpr auto kOpNameCustom = "Custom";
50 
51 // Shallow Copy Value and change shape
52 ValuePtr ShallowCopyValue(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &value) {
53   MS_EXCEPTION_IF_NULL(op_run_info);
54   MS_EXCEPTION_IF_NULL(value);
55   auto tensor_abs = op_run_info->base_op_run_info.abstract;
56   MS_EXCEPTION_IF_NULL(tensor_abs);
57   if (tensor_abs->isa<abstract::AbstractRefTensor>()) {
58     tensor_abs = tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
59   }
60   auto new_shape = tensor_abs->BuildShape()->cast<abstract::ShapePtr>();
61   MS_EXCEPTION_IF_NULL(new_shape);
62   if (value->isa<mindspore::tensor::BaseTensor>()) {
63     auto tensor_value = value->cast<mindspore::tensor::BaseTensorPtr>();
64     return std::make_shared<mindspore::tensor::Tensor>(tensor_value->data_type(), new_shape->shape(),
65                                                        tensor_value->data_c(), tensor_value->Size());
66   }
67   if (value->isa<ValueTuple>()) {
68     std::vector<ValuePtr> values;
69     auto value_tuple = value->cast<ValueTuplePtr>();
70     (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
71                          [op_run_info](const ValuePtr &elem) { return ShallowCopyValue(op_run_info, elem); });
72     return std::make_shared<ValueTuple>(values);
73   }
74   return value;
75 }
76 
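// Recursively copies tensors (including those nested in tuples and lists) so each copy gets a new tensor id
// while reusing the original data pointer, device address and sync status.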
77 ValuePtr CopyTensorValueWithNewId(const ValuePtr &v) {
78   MS_EXCEPTION_IF_NULL(v);
79   if (v->isa<tensor::BaseTensor>()) {
80     auto tensor = v->cast<tensor::BaseTensorPtr>();
81     // This constructor will make a tensor with a new id
82     auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
83     new_tensor->set_need_pipeline_sync(true);
84     new_tensor->set_device_address(tensor->device_address());
85     new_tensor->set_sync_status(tensor->sync_status());
86     return new_tensor;
87   }
88   if (v->isa<ValueTuple>()) {
89     const auto &v_tup = v->cast<ValueTuplePtr>();
90     ValuePtrList list;
91     for (const auto &ele : v_tup->value()) {
92       (void)list.emplace_back(CopyTensorValueWithNewId(ele));
93     }
94     return std::make_shared<ValueTuple>(list);
95   }
96   if (v->isa<ValueList>()) {
97     const auto &v_list = v->cast<ValueListPtr>();
98     ValuePtrList list;
99     for (const auto &ele : v_list->value()) {
100       (void)list.emplace_back(CopyTensorValueWithNewId(ele));
101     }
102     return std::make_shared<ValueList>(list);
103   }
104   return v;
105 }
106 
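// Propagates the inferred abstract to the Python-side stub node; throws if it conflicts with the type
// registered in predict_out_type_map.h.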
107 void UpdateOutputStubNodeAbs(const FrontendOpRunInfoPtr &op_run_info) {
108   if (op_run_info->stub_output == nullptr) {
109     return;
110   }
111   const auto &abs = op_run_info->base_op_run_info.abstract;
112   MS_EXCEPTION_IF_NULL(abs);
113   auto success = op_run_info->stub_output->SetAbstract(abs);
114   if (!success) {
115     const auto &op_name = op_run_info->base_op_run_info.op_name;
116                             << "The predict type and infer type do not match, predict type is "
117                             << PredictOutType(op_run_info) << ", infer type is " << abs->BuildType()
118                             << ", the name of operator is [" << op_name
119                             << "]. Please modify or add predict type of operator in predict_out_type_map.h.";
120   }
121   MS_LOG(DEBUG) << "Update StubNode abstract " << abs->ToString();
122 }
123 
124 void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
125   // Clone a new prim
126   MS_EXCEPTION_IF_NULL(op_run_info);
127   auto prim = op_run_info->op_grad_info->op_prim;
128   auto prim_py = prim->cast<PrimitivePyPtr>();
129   if (prim_py == nullptr) {
130     return;
131   }
132   auto new_adapter = std::make_shared<PrimitivePyAdapter>(*prim_py->adapter());
133   auto new_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_grad_info->op_prim->cast<PrimitivePyPtr>()));
134   new_prim->EnableSharedMutex();
135   op_run_info->op_grad_info->op_prim = new_prim;
136   MS_EXCEPTION_IF_NULL(new_adapter);
137   new_adapter->set_attached_primitive(new_prim);
138 }
139 
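// Treats the op as having dynamic inputs when any input is a stub sequence node, or a non-empty value
// sequence whose first element is a tensor (or tensor stub node).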
140 bool IsDynamicInputs(const FrontendOpRunInfoPtr &op_run_info) {
141   for (const auto &value : op_run_info->op_grad_info->input_value) {
142     MS_EXCEPTION_IF_NULL(value);
143     if (value->isa<stub::SequenceNode>()) {
144       return true;
145     }
146     if (!value->isa<ValueSequence>()) {
147       continue;
148     }
149     auto value_seq = value->cast<ValueSequencePtr>();
150     MS_EXCEPTION_IF_NULL(value_seq);
151 
152     const auto &tuple_inputs = value_seq->value();
153     if (tuple_inputs.empty()) {
154       continue;
155     }
156     if (tuple_inputs[0]->isa<tensor::BaseTensor>() || tuple_inputs[0]->isa<stub::TensorNode>()) {
157       return true;
158     }
159   }
160   return false;
161 }
162 
163 ValuePtr ConstructOutputInVM(const std::vector<ValuePtr> &result) {
164   if (result.size() == 1) {
165     return result[kIndex0];
166   }
167   return std::make_shared<ValueTuple>(result);
168 }
169 
170 void UpdateOutputStubNodeValue(const FrontendOpRunInfoPtr &op_run_info) {
171   if (op_run_info->stub_output != nullptr) {
172     op_run_info->stub_output->SetValue(op_run_info->real_out);
173   }
174 }
175 
176 BackendOpRunInfoPtr CreateBackendOpRunInfo(const FrontendOpRunInfoPtr &op_run_info) {
177   auto backend_op_run_info = std::make_shared<BackendOpRunInfo>(
178     op_run_info->base_op_run_info, std::make_shared<Primitive>(*op_run_info->op_grad_info->op_prim), true, false);
179   // Erase RandomOp cache to avoid memory leak.
180   if (AnfAlgo::NeedEraseCache(backend_op_run_info->op_prim)) {
181     backend_op_run_info->base_op_run_info.need_earse_cache = true;
182   }
183   if (op_run_info->base_op_run_info.has_dynamic_output) {
184     backend_op_run_info->base_op_run_info.use_dynamic_shape_process = true;
185   }
186   return backend_op_run_info;
187 }
188 
189 void UpdateStubTensor(const FrontendOpRunInfoPtr &op_run_info) {
190   // Some operators do not have StubNodes, such as Cast inserted for automatic mixed precision.
191   if (op_run_info->stub_output != nullptr) {
192     if (op_run_info->base_op_run_info.has_dynamic_output ||
193         OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name)) {
194       UpdateOutputStubNodeAbs(op_run_info);
195     }
196     op_run_info->stub_output->SetValue(op_run_info->real_out);
197   }
198 }
199 
200 runtime::KernelTaskType GetViewOpTaskType(const std::string &op_name) {
201   if (op_name == kCopyWithSliceOpName) {
202     return runtime::KernelTaskType::kCOPY_TASK;
203   }
204   return runtime::KernelTaskType::kNORMAL_VIEW_TASK;
205 }
206 
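// Appends the selected data inputs and the slice-index inputs to the op; for StridedSlice under grad,
// also appends the five zero mask inputs.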
207 void EmplaceSliceInputs(const FrontendOpRunInfoPtr &op_run_info, const std::vector<ValuePtr> &input_values,
208                         const SliceOpInfoPtr &slice_op_info) {
209   for (auto idx : slice_op_info->data_indexs) {
210     if (idx >= input_values.size()) {
211       MS_LOG(EXCEPTION) << "data_idx is out of bounds, data_idx:" << idx
212                         << " input_values.size():" << input_values.size();
213     }
214     (void)op_run_info->op_grad_info->input_value.emplace_back(input_values[idx]);
215   }
216 
217   for (const auto &slice_index : slice_op_info->slice_index_inputs) {
218     ValuePtr v = nullptr;
219     if (slice_index->is_int()) {
220       v = MakeValue(slice_index->int_value());
221     } else {
222       v = MakeValue(slice_index->vec_value());
223     }
224 
225     (void)op_run_info->op_grad_info->input_value.emplace_back(v);
226   }
227 
228   if (op_run_info->requires_grad && op_run_info->base_op_run_info.op_name == kStridedSliceOpName) {
229     // StridedSlice mask input
230     int64_t v = 0;
231 
232     (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // begin_mask
233     (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // end_mask
234     (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // ellipsis_mask
235     (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // new_axis_mask
236     (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // shrink_axis_mask
237   }
238 
239   op_run_info->input_size = op_run_info->op_grad_info->input_value.size();
240   op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
241 }
242 
243 #ifndef ENABLE_TEST
244 size_t GetCurStreamId(const std::string &device_target) {
245   auto device_ctx = runtime::OpRunner::GetDeviceContext(device_target);
246   return device_ctx->device_res_manager_->GetCurrentStreamId();
247 }
248 #endif
249 }  // namespace
250 
251 void ForwardExecutor::WaitForwardTask() {
252   GilReleaseWithCheck gil_release;
253   runtime::Pipeline::Get().frontend_stage()->Wait();
254 }
255 
256 bool ForwardExecutor::IsVmOp(const std::string &op_name) const {
257   return kVmOperators.find(op_name) != kVmOperators.end();
258 }
259 
260 std::string ForwardExecutor::GetCurrentCellObjId() const {
261   if (forward_cell_stack_.empty()) {
262     return "";
263   }
264   auto &cell = forward_cell_stack_.top();
265   return cell->id();
266 }
267 
268 GradExecutorPtr ForwardExecutor::grad() const {
269   auto grad_executor = grad_executor_.lock();
270   MS_EXCEPTION_IF_NULL(grad_executor);
271   return grad_executor;
272 }
273 
274 void ForwardExecutor::InitOpRunInfo(const FrontendOpRunInfoPtr &op_run_info) {
275   Init();
276   // Used for async run
277   op_run_info->requires_grad = grad()->RequiresGrad();
278   if (op_run_info->requires_grad) {
279     op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
280   } else {
281     op_run_info->base_op_run_info.use_dynamic_shape_process =
282       grad()->forward_use_dynamic_shape_process() || grad()->use_dynamic_shape_process();
283   }
284   op_run_info->base_op_run_info.device_target = GetCurrentDeviceTarget(op_run_info->op_grad_info->op_prim);
285   op_run_info->cell_obj_id = GetCurrentCellObjId();
286   auto device_context = runtime::OpRunner::GetDeviceContext(op_run_info->base_op_run_info.device_target);
287   op_run_info->base_op_run_info.stream_id = device_context->device_res_manager_->GetCurrentStreamId();
288 }
289 
290 void ForwardExecutor::ReInit() {
291   device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
292   enable_async_ = !MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
293 }
294 
295 void ForwardExecutor::Init() {
296   ReInit();
297   if (init_) {
298     return;
299   }
300   init_ = true;
301   MS_LOG(DEBUG) << "Init ForwardExecutor";
302   compile::SetMindRTEnable();
303   python_adapter::set_python_env_flag(true);
304   runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() {
305     runtime::Pipeline::Get().frontend_stage()->Wait();
306     grad()->WaitBpropTask();
307   });
308 }
309 
310 void ForwardExecutor::RefreshForwardCallback() {
311 #if defined(_WIN32) || defined(_WIN64)
312   runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() {
313     runtime::Pipeline::Get().frontend_stage()->Wait();
314     grad()->WaitBpropTask();
315   });
316 #endif
317   // ForwardCallback has been set in ForwardExecutor::Init, no need to refresh anymore.
318 }
319 
320 bool ForwardExecutor::enable_async() const {
321 #if defined(ENABLE_TEST) || defined(__APPLE__)
322   return false;
323 #else
324   return enable_async_;
325 #endif
326 }
327 
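// The async pipeline is disabled for VM ops, Custom ops, JIT fallback, and when the op executor needs to run
// synchronously.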
328 bool ForwardExecutor::EnablePipeline(const std::string &op_name) const {
329   return enable_async() && !IsVmOp(op_name) && op_name != kOpNameCustom && !ScopedFallbackRunning::on() &&
330          !runtime::OpExecutor::NeedSync();
331 }
332 
333 void ForwardExecutor::DispatchFrontendTask(const FrontendOpRunInfoPtr &op_run_info) {
334   auto forward_task = std::make_shared<FrontendTask>(
335     [this](const FrontendOpRunInfoPtr &op_run_info) { RunOpFrontend(op_run_info); }, op_run_info);
336   runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(forward_task->task_id());
337   runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
338 }
339 
340 void ForwardExecutor::ForwardOpGradImpl(const FrontendOpRunInfoPtr &op_run_info) const {
341   // If jit is compiled in the first step, the op info will not be found in the second training step
342   MS_LOG(DEBUG) << "Current custom bprop cell count " << op_run_info->async_status.custom_bprop_cell_count;
343   if (!op_run_info->async_status.is_jit_compiling && op_run_info->async_status.custom_bprop_cell_count <= 0) {
344     grad()->ProcessOpGradInfo(op_run_info);
345   }
346 }
347 
348 void ForwardExecutor::ForwardRunViewKernelTask(const FrontendOpRunInfoPtr &op_run_info,
349                                                const runtime::KernelTaskType &task_type, bool enable_async) {
350   if (task_type == runtime::KernelTaskType::kNORMAL_VIEW_TASK) {
351     return;
352   }
353   MS_LOG(DEBUG) << "Start, task_type:" << task_type;
354 
355   const auto &cur_mind_rt_backend = GetMindRtBackend(op_run_info->base_op_run_info.device_target);
356   MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
357   cur_mind_rt_backend->RunViewKernelTask(op_run_info->base_op_run_info, task_type, enable_async);
358 
359   MS_LOG(DEBUG) << "End";
360 }
361 
362 void ForwardExecutor::CreateViewOpOutputs(const FrontendOpRunInfoPtr &op_run_info,
363                                           const tensor::BaseTensorPtr &view_input_tensor,
364                                           runtime::KernelTaskType task_type,
365                                           const TensorStorageInfoPtrList &storage_infos, bool is_tuple_output) {
366   const bool is_single_tensor_output = storage_infos.size() == 1 && !is_tuple_output;
367   // Generate output abs by storage_info.
368   if (is_single_tensor_output) {
369     op_run_info->base_op_run_info.abstract = abstract::MakeAbstractTensor(
370       std::make_shared<abstract::Shape>(storage_infos[0]->shape), view_input_tensor->Dtype());
371   } else {
372     AbstractBasePtrList abs_list;
373     for (const auto &storage_info : storage_infos) {
374       auto abs = abstract::MakeAbstractTensor(std::make_shared<abstract::Shape>(storage_info->shape),
375                                               view_input_tensor->Dtype());
376       (void)abs_list.emplace_back(abs);
377     }
378     op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
379   }
380 
381   UpdateOutputStubNodeAbs(op_run_info);
382   CreateInputAddressForViewOp(view_input_tensor, op_run_info);
383 
384   for (size_t i = 0; i < storage_infos.size(); i++) {
385     MS_LOG(DEBUG) << "View op " << op_run_info->base_op_run_info.op_name << ", i:" << i
386                   << ", storage_info:" << storage_infos[i]->ToString();
387     CreateViewOutputTensor(op_run_info, view_input_tensor, storage_infos[i], task_type);
388   }
389 
390   if (is_single_tensor_output) {
391     op_run_info->real_out = op_run_info->base_op_run_info.output_tensors[0];
392     op_run_info->op_grad_info->output_size = 1;
393   } else {
394     std::vector<ValuePtr> output_values;
395     (void)std::transform(op_run_info->base_op_run_info.output_tensors.begin(),
396                          op_run_info->base_op_run_info.output_tensors.end(), std::back_inserter(output_values),
397                          [](const auto &t) {
398                            MS_EXCEPTION_IF_NULL(t);
399                            return t;
400                          });
401     op_run_info->real_out = std::make_shared<ValueTuple>(output_values);
402     op_run_info->op_grad_info->output_size = output_values.size();
403   }
404 
405   UpdateOutputStubNodeValue(op_run_info);
406 }
407 
408 bool ForwardExecutor::ProcessViewOp(const FrontendOpRunInfoPtr &op_run_info,
409                                     const ops::StridesCalcFunc &strides_calc_func, bool is_tuple_output) {
410   MS_LOG(DEBUG) << "Start, op:" << op_run_info->base_op_run_info.op_name;
411   if (op_run_info->op_grad_info->input_value.empty()) {
412     MS_LOG(EXCEPTION) << "op_run_info->op_grad_info->input_value is empty";
413   }
414 
415   // Only Split and Chunk have multiple outputs, and the input tensor is the first input.
416   auto view_value = op_run_info->op_grad_info->input_value[0];
417   MS_EXCEPTION_IF_NULL(view_value);
418   if (!view_value->isa<tensor::BaseTensor>()) {
419     MS_EXCEPTION(TypeError) << "For primitive[" << op_run_info->base_op_run_info.op_name
420                             << "],  the input[0] should be Tensor, but got:" << view_value->ToString();
421   }
422   auto view_input_tensor = view_value->cast<tensor::BaseTensorPtr>();
423   MS_EXCEPTION_IF_NULL(view_input_tensor);
424 
425   auto storage_infos = strides_calc_func(op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_value);
426   if (storage_infos.empty()) {
427     MS_LOG(DEBUG) << "Not View op " << op_run_info->base_op_run_info.op_name;
428     return false;
429   }
430 
431   // Reuse SetInputAbstract; the abs of the inputs is needed when requires_grad is true.
432   InferOutputAbstract(op_run_info);
433   runtime::KernelTaskType task_type = GetViewOpTaskType(op_run_info->base_op_run_info.op_name);
434 
435   // Create view output tensor
436   CreateViewOpOutputs(op_run_info, view_input_tensor, task_type, storage_infos, is_tuple_output);
437 
438   if (op_run_info->requires_grad || task_type != runtime::KernelTaskType::kNORMAL_VIEW_TASK) {
439     const auto &top_cell = op_run_info->requires_grad ? grad()->top_cell() : nullptr;
440     for (size_t index = 0; index < op_run_info->input_size; ++index) {
441       const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
442       PyNativeAlgo::DataConvert::MarkInputs(op_run_info, input_object, index, top_cell);
443     }
444   }
445 
446   // GIL might be released by ACL, so release it here to reduce conflicts
447   GilReleaseWithCheck release_gil;
448   ForwardRunViewKernelTask(op_run_info, task_type, false);
449   if (op_run_info->requires_grad) {
450     ForwardOpGradImpl(op_run_info);
451   }
452   MS_LOG(DEBUG) << "End";
453   return true;
454 }
455 
456 void ForwardExecutor::DispatchSilceOpFrontendTask(const std::vector<ValuePtr> &input_values,
457                                                   const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
458                                                   const stub::StubNodePtr &stub_output) {
459   auto forward_task = std::make_shared<SliceOpFrontendTask>(
460     [this](const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos,
461            bool requires_grad, const stub::StubNodePtr &stub_output) {
462       (void)RunSliceOpFrontend(input_values, slice_op_infos, requires_grad, stub_output);
463     },
464     input_values, slice_op_infos, requires_grad, stub_output);
465   runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
466 }
467 
468 ValuePtr ForwardExecutor::RunSliceOpFrontend(const std::vector<ValuePtr> &input_values,
469                                              const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
470                                              const stub::StubNodePtr &stub_output) {
471   if (input_values.empty()) {
472     MS_LOG(EXCEPTION) << "input_values is empty.";
473   }
474 
475   MS_LOG(DEBUG) << "Start, slice_op_infos size:" << slice_op_infos.size();
476   auto intermediate_tensor = input_values;
477   auto last_tensor = input_values[0];
478 
479   for (size_t i = 0; i < slice_op_infos.size(); i++) {
480     const auto &slice_op_info = slice_op_infos[i];
481     MS_EXCEPTION_IF_NULL(slice_op_info);
482     MS_LOG(DEBUG) << "Run slice op name:" << slice_op_info->slice_op_name;
483     MS_EXCEPTION_IF_CHECK_FAIL(!slice_op_info->data_indexs.empty(), "data_indexs can not be empty");
484     auto first_data_idx = slice_op_info->data_indexs[0];
485     if (first_data_idx >= intermediate_tensor.size()) {
486       MS_LOG(EXCEPTION) << "data_idx is out of bounds, data_idx:" << first_data_idx
487                         << " intermediate_tensor.size():" << intermediate_tensor.size();
488     }
489 
490     // Only the last op needs to update the stub node.
491     auto cur_op_stub_output = (i + 1 == slice_op_infos.size() ? stub_output : nullptr);
492     auto op_run_info = GenerateSliceOpRunInfo(slice_op_info->slice_op_name, requires_grad, cur_op_stub_output);
493     if (slice_op_info->slice_op_name == kCastOpName) {
494       // For the Cast op, slice_index_inputs holds the target type
495       MS_EXCEPTION_IF_CHECK_FAIL(slice_op_info->slice_index_inputs.size() == 1, "Size of cast type input should be 1");
496       auto type_value = slice_op_info->slice_index_inputs[0];
497       MS_EXCEPTION_IF_CHECK_FAIL(type_value->is_int(), "type_value should be int.");
498       auto type_id = static_cast<TypeId>(type_value->int_value());
499       (void)cast_operation()->DoNormalCast(op_run_info, intermediate_tensor[first_data_idx], type_id);
500     } else {
501       EmplaceSliceInputs(op_run_info, intermediate_tensor, slice_op_info);
502 
503       auto strides_calc_info =
504         ops::ViewStridesCalcFactory::GetInstance().GetStridesCalcFunc(op_run_info->base_op_run_info.op_name);
505       if (!strides_calc_info.first.has_value()) {
506         MS_LOG(EXCEPTION) << "op:" << op_run_info->base_op_run_info.op_name << " is not view.";
507       }
508       op_run_info->is_view_op = true;
509       PyNativeAlgo::Common::StubNodeToValue(op_run_info);
510       if (!ProcessViewOp(op_run_info, strides_calc_info.first.value(), strides_calc_info.second)) {
511         MS_EXCEPTION(ValueError) << "op:" << op_run_info->base_op_run_info.op_name << " inputs is not for view.";
512       }
513     }
514     intermediate_tensor[first_data_idx] = op_run_info->real_out;
515     last_tensor = op_run_info->real_out;
516   }
517   MS_LOG(DEBUG) << "End";
518   return last_tensor;
519 }
520 
521 void ForwardExecutor::RunOpFrontend(const FrontendOpRunInfoPtr &op_run_info) {
522   MS_EXCEPTION_IF_NULL(op_run_info);
523   MS_LOG(DEBUG) << "RunOp name: " << op_run_info->base_op_run_info.op_name;
524 #ifndef ENABLE_TEST
525   auto strides_calc_info =
526     ops::ViewStridesCalcFactory::GetInstance().GetStridesCalcFunc(op_run_info->base_op_run_info.op_name);
527   op_run_info->is_view_op = strides_calc_info.first.has_value();
528 #endif
529 
530   // Convert StubNode to Tensor; no need to worry about input StubNode in this thread anymore.
531   PyNativeAlgo::Common::StubNodeToValue(op_run_info);
532   // 1.Set cast for inputs
533   SetCastForInputs(op_run_info);
534 
535 #ifndef ENABLE_TEST
536   if (op_run_info->is_view_op &&
537       ProcessViewOp(op_run_info, strides_calc_info.first.value(), strides_calc_info.second)) {
538     return;
539   }
540 #endif
541 
542   if (op_run_info->is_view_op) {
543     // Some special inputs cannot run the view op, so make the inputs contiguous first and set the flag to false.
544     for (size_t i = 0; i < op_run_info->op_grad_info->input_value.size(); i++) {
545       op_run_info->op_grad_info->input_value[i] = PyNativeAlgo::Common::ConvertToContiguousValue(
546         op_run_info->op_grad_info->input_value[i], op_run_info->requires_grad);
547     }
548     op_run_info->is_view_op = false;
549   }
550 
551   // Infer output abstract
552   InferOutputAbstract(op_run_info);
553 
554   if (!(op_run_info->base_op_run_info.has_dynamic_output ||
555         OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name))) {
556     // If the output is dynamic shape, SetAbstract must be done after RunOp; otherwise update the stub abstract now.
557     UpdateOutputStubNodeAbs(op_run_info);
558   }
559 
560   if (op_run_info->output_get_by_infer_value) {
561     UpdateOutputStubNodeValue(op_run_info);
562     MS_LOG(DEBUG) << "Grad flag: " << op_run_info->requires_grad
563                   << " output_get_by_infer_value: " << op_run_info->output_get_by_infer_value;
564     return;
565   }
566 
567   PrepareOpInputs(op_run_info);
568 
569   RunOpBackendSync(op_run_info);
570 }
571 
572 void ForwardExecutor::RunOpBackendSync(const FrontendOpRunInfoPtr &op_run_info) {
573   const auto &backend_op_run_info = CreateBackendOpRunInfo(op_run_info);
574   RunOpBackend(op_run_info, backend_op_run_info);
575   if (!op_run_info->requires_grad) {
576     MS_LOG(DEBUG) << "Grad flag is false";
577     UpdateStubTensor(op_run_info);
578     return;
579   }
580   // Do op grad and record op info
581   ForwardOpGradImpl(op_run_info);
582   // Output may be dynamic shape; need to update the abstract and value.
583   UpdateStubTensor(op_run_info);
584 }
585 
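// If the op can be handled on the C++ side (pipeline enabled, bprop expander and frontend infer impl available),
// replace the PrimitivePy with a plain Primitive copy that has its shared mutex enabled.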
586 void ForwardExecutor::OpRunInfoUsePrimC(const FrontendOpRunInfoPtr &op_run_info) const {
587   auto prim = op_run_info->op_grad_info->op_prim;
588   auto op_name = prim->name();
589   if (EnablePipeline(op_name) && expander::bprop::HasBpropExpander(op_name) &&
590       abstract::GetFrontendPrimitiveInferImpl(prim).has_value()) {
591     auto new_prim = std::make_shared<Primitive>(*prim);
592     new_prim->EnableSharedMutex();
593     op_run_info->op_grad_info->op_prim = new_prim;
594   }
595 }
596 
597 PrimitivePtr ForwardExecutor::GetSlicePrimFromCache(const std::string &op_name) {
598   auto iter = slice_prim_cache_.find(op_name);
599   if (iter != slice_prim_cache_.end()) {
600     return iter->second;
601   }
602 
603   auto prim = std::make_shared<Primitive>(op_name);
604   slice_prim_cache_[op_name] = prim;
605   return prim;
606 }
607 
608 FrontendOpRunInfoPtr ForwardExecutor::GenerateSliceOpRunInfo(const std::string &op_name, bool requires_grad,
609                                                              const stub::StubNodePtr &stub_output) {
610   Init();
611   const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
612   op_run_info->base_op_run_info.op_name = op_name;
613   op_run_info->requires_grad = requires_grad;
614   op_run_info->base_op_run_info.device_target = device_target_;
615 
616   if (op_name == kCastOpName) {
617     // Cast prim will be set in DoNormalCast.
618     return op_run_info;
619   }
620 
621   if (op_run_info->requires_grad) {
622     op_run_info->op_grad_info->op_prim = GetSlicePrimFromCache(op_name);
623   }
624   op_run_info->stub_output = stub_output;
625   return op_run_info;
626 }
627 
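// Builds a FrontendOpRunInfo from the Python call: parses the primitive, op name and inputs, resolves the device
// target and stream id, and clones the primitive when inputs are dynamic or converted to attributes.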
628 FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args, bool stub) {
629   if (args.size() != static_cast<size_t>(RunOpArgsEnum::PY_ARGS_NUM)) {
630     MS_LOG(EXCEPTION) << "Three args are needed by RunOp";
631   }
632   Init();
633   const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
634   // Used for async run
635   op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
636   op_run_info->requires_grad = grad()->RequiresGrad();
637   if (op_run_info->requires_grad) {
638     op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
639   } else {
640     op_run_info->base_op_run_info.use_dynamic_shape_process =
641       grad()->forward_use_dynamic_shape_process() || grad()->use_dynamic_shape_process();
642   }
643   PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
644   OpRunInfoUsePrimC(op_run_info);
645   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_INPUTS)],
646                                                   stub);
647   op_run_info->base_op_run_info.device_target = GetCurrentDeviceTarget(op_run_info->op_grad_info->op_prim);
648   bool is_dynamic_shape =
649     op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.use_dynamic_shape_process;
650   PyNativeAlgo::Common::GetConstInputToAttr(op_run_info->op_grad_info->op_prim, op_run_info->base_op_run_info.op_name,
651                                             op_run_info->base_op_run_info.device_target, is_dynamic_shape,
652                                             &op_run_info->input_to_attr);
653   bool is_dynamic_inputs = IsDynamicInputs(op_run_info);
654   if (!op_run_info->input_to_attr.empty() || is_dynamic_inputs) {
655     MS_LOG(DEBUG) << "Op_prim need clone:" << op_run_info->base_op_run_info.op_name
656                   << ", is_dynamic_inputs:" << is_dynamic_inputs
657                   << ", input_to_attr is not empty:" << (!op_run_info->input_to_attr.empty());
658     ClonePrim(op_run_info);
659   }
660 #ifndef ENABLE_TEST
661   // Obtaining device context may fail in UT
662   op_run_info->base_op_run_info.stream_id = GetCurStreamId(op_run_info->base_op_run_info.device_target);
663 #endif
664   op_run_info->cell_obj_id = GetCurrentCellObjId();
665   return op_run_info;
666 }
667 
668 void ForwardExecutor::SetCastForInputs(const FrontendOpRunInfoPtr &op_run_info) const {
669   MS_EXCEPTION_IF_NULL(op_run_info);
670   // No need to cast the Cast op itself
671   if (op_run_info->base_op_run_info.op_name == prim::kPrimCast->name()) {
672     return;
673   }
674   cast_operation()->DoCast(op_run_info);
675 }
676 
677 void ForwardExecutor::ClearNodeAbsMap() const { infer_operation()->ClearNodeAbsCache(); }
678 
679 void ForwardExecutor::SetNodeAbsMapByValue(const FrontendOpRunInfoPtr &op_run_info) const {
680   infer_operation()->SetNodeAbsCacheByValue(op_run_info);
681 }
682 
683 void ForwardExecutor::SetNodeAbsMapById(const std::string &id, const abstract::AbstractBasePtr &abs) const {
684   infer_operation()->SetNodeAbsCacheById(id, abs);
685 }
686 
687 AbstractBasePtr ForwardExecutor::GetNodeAbsById(const std::string &id) const {
688   return infer_operation()->GetNodeAbsById(id);
689 }
690 
691 void ForwardExecutor::InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) const {
692   infer_operation()->DoInfer(op_run_info);
693 }
694 
695 VectorRef ForwardExecutor::RunOpBackendInner(const FrontendOpRunInfoPtr &op_run_info,
696                                              const BackendOpRunInfoPtr &backend_op_run_info) {
697   MS_LOG(DEBUG) << "RunOpBackendInner start";
698   MS_EXCEPTION_IF_NULL(op_run_info);
699   auto ms_context = MsContext::GetInstance();
700   MS_EXCEPTION_IF_NULL(ms_context);
701   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
702 
703   VectorRef outputs;
704   const auto &cur_mind_rt_backend = GetMindRtBackend(backend_op_run_info->base_op_run_info.device_target);
705   MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
706   bool use_dynamic_shape_process = backend_op_run_info->base_op_run_info.use_dynamic_shape_process;
707   if (use_dynamic_shape_process) {
708     cur_mind_rt_backend->RunOpDynamic(backend_op_run_info, &outputs);
709   } else {
710     cur_mind_rt_backend->RunOp(backend_op_run_info, &outputs);
711   }
712 
713   if (op_run_info->base_op_run_info.has_dynamic_output ||
714       OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name)) {
715     op_run_info->base_op_run_info.abstract = backend_op_run_info->base_op_run_info.abstract;
716   }
717   op_run_info->op_grad_info->output_size = outputs.size();
718   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
719   MS_LOG(DEBUG) << "RunOpBackendInner end";
720   return outputs;
721 }
722 
723 void ForwardExecutor::RunOpBackend(const FrontendOpRunInfoPtr &op_run_info,
724                                    const BackendOpRunInfoPtr &backend_op_run_info) {
725   // Run op with the selected backend; a nop does not need to run on the backend
726   op_run_info->real_out = RunOpWithBackendPolicy(op_run_info, backend_op_run_info);
727   // Do not cache the abs of GetNext
728   if (op_run_info->base_op_run_info.op_name != kGetNextOpName) {
729     op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
730     SetNodeAbsMapByValue(op_run_info);
731   }
732 }
733 
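// Returns the MindRT backend cached for the given device target, creating and caching one on first use.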
734 compile::MindRTBackendPtr ForwardExecutor::GetMindRtBackend(const string &cur_device_target) {
735   const auto iter = mindrt_backends_.find(cur_device_target);
736   if (iter != mindrt_backends_.end()) {
737     return iter->second;
738   } else {
739     auto ms_context = MsContext::GetInstance();
740     MS_EXCEPTION_IF_NULL(ms_context);
741     auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
742     auto backend = std::make_shared<compile::MindRTBackend>("ms", cur_device_target, device_id);
743     MS_EXCEPTION_IF_NULL(backend);
744     mindrt_backends_[cur_device_target] = backend;
745     return backend;
746   }
747 }
748 
749 ValuePtr ForwardExecutor::RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info,
750                                                  const BackendOpRunInfoPtr &backend_op_run_info) {
751   MS_EXCEPTION_IF_NULL(op_run_info);
752 #ifndef ENABLE_TEST
753   if (IsVmOp(op_run_info->base_op_run_info.op_name)) {
754     return RunOpInVM(op_run_info);
755   } else {
756     return RunOpInMs(op_run_info, backend_op_run_info);
757   }
758 #else
759   return RunOpInVM(op_run_info);
760 #endif
761 }
762 
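// Runs the op on the VM path: VM ops (e.g. HookBackward) simply copy their inputs to new tensors, while other
// ops call the primitive's Python compute function.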
763 ValuePtr ForwardExecutor::RunOpInVM(const FrontendOpRunInfoPtr &op_run_info) const {
764   MS_LOG(DEBUG) << "RunOpInVM start";
765   MS_EXCEPTION_IF_NULL(op_run_info);
766   op_run_info->run_in_vm = true;
767   if (op_run_info->requires_grad) {
768     for (size_t i = 0; i < op_run_info->input_size; i++) {
769       op_run_info->op_grad_info->input_value_grad_type[i] = PyNativeAlgo::Common::SetValueGradInfo(
770         op_run_info->op_grad_info->input_value[i], nullptr, InputType::kConstant);
771       (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(op_run_info->op_grad_info->input_value[i]);
772     }
773   }
774   if (IsVmOp(op_run_info->base_op_run_info.op_name)) {
775     std::vector<ValuePtr> result(op_run_info->input_size);
776     for (size_t i = 0; i < op_run_info->input_size; i++) {
777       result[i] = CopyTensorValueWithNewId(op_run_info->op_grad_info->input_value[i]);
778     }
779     auto result_v = ConstructOutputInVM(result);
780     if (op_run_info->requires_grad) {
781       op_run_info->op_grad_info->output_size = result.size();
782       (void)PyNativeAlgo::Common::SetValueGradInfo(result_v, nullptr, InputType::kOpOutput);
783     }
784     MS_LOG(DEBUG) << "RunOpInVM end";
785     return result_v;
786   }
787 
788   MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
789   py::list vm_op_inputs = py::list(op_run_info->input_size);
790   for (size_t i = 0; i < op_run_info->input_size; ++i) {
791     vm_op_inputs[i] = PyNativeAlgo::DataConvert::ValueToPyObj(op_run_info->op_grad_info->input_value[i]);
792   }
793   if (!utils::isa<PrimitivePy>(op_run_info->op_grad_info->op_prim)) {
794     MS_LOG(EXCEPTION) << "Not a PrimitivePy, " << op_run_info->op_grad_info->op_prim->ToString();
795   }
796   auto result = utils::cast<PrimitivePyPtr>(op_run_info->op_grad_info->op_prim)->RunPyComputeFunction(vm_op_inputs);
797   if (py::isinstance<py::none>(result)) {
798     MS_LOG(EXCEPTION) << "VM op " << op_run_info->base_op_run_info.op_name << " run failed!";
799   }
800   ValuePtr result_v = PyNativeAlgo::DataConvert::PyObjToValue(result);
801   if (!result_v->isa<ValueSequence>() && (op_run_info->base_op_run_info.abstract == nullptr ||
802                                           op_run_info->base_op_run_info.abstract->isa<abstract::AbstractSequence>())) {
803     result_v = std::make_shared<ValueTuple>(std::vector{result_v});
804   }
805   if (op_run_info->requires_grad) {
806     (void)PyNativeAlgo::Common::SetValueGradInfo(result_v, nullptr, InputType::kOpOutput);
807   }
808   op_run_info->op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(result_v);
809   MS_LOG(DEBUG) << "RunOpInVM end";
810   return result_v;
811 }
812 
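// Returns true when the cell at the top of the stack has no mixed precision type set; otherwise records the mix
// type on the op and returns false.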
813 bool ForwardExecutor::CellNotSetMixedPrecision(const FrontendOpRunInfoPtr &op_run_info) {
814   MS_EXCEPTION_IF_NULL(op_run_info);
815   const auto &cur_cell = forward_cell_stack_.top();
816   MS_EXCEPTION_IF_NULL(cur_cell);
817   MixedPrecisionType mix_type = cur_cell->GetMixedPrecisionType();
818   if (mix_type == kNotSet) {
819     return true;
820   }
821   op_run_info->mix_type = mix_type;
822   return false;
823 }
824 
825 void ForwardExecutor::ExecuteLazyTask() const {
826   runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
827   GilReleaseWithCheck gil_release;
828   runtime::OpExecutor::GetInstance().WaitAll();
829 }
830 
831 void ForwardExecutor::PrintPyObjInfo(const py::object &obj, const std::string &str, bool is_cell) const {
832   if (is_cell) {
833     MS_LOG(DEBUG) << str << " run " << obj.cast<CellPtr>()->ToString();
834     return;
835   }
836   MS_LOG(DEBUG) << str << " run python function " << py::getattr(obj, "__name__").cast<std::string>();
837 }
838 
839 void ForwardExecutor::ProcessBeforeNewGraph(const py::object &obj, const py::args &args) {
840   bool is_cell = py::isinstance<Cell>(obj);
841   if (is_cell) {
842     auto cell = obj.cast<CellPtr>();
843     MS_EXCEPTION_IF_NULL(cell);
844     PushForwardCell(cell);
845     if (!grad()->RequiresGrad()) {
846       if (grad()->is_cell_has_dynamic_inputs(cell->id())) {
847         MS_LOG(DEBUG) << "obj id:" << cell->id() << " set forward use dynamic shape process true";
848         grad()->set_forward_use_dynamic_shape_process(true);
849 #ifndef ENABLE_SECURITY
850         ProfilerManager::GetInstance()->SetNetDynamicShapeStatus();
851 #endif
852       }
853     } else {
854       PrintPyObjInfo(obj, kBegin, is_cell);
855     }
856   }
857 }
858 
859 void ForwardExecutor::ProcessAfterNewGraph(const py::object &obj) const { grad()->SetTopCellDynamicAttr(obj); }
860 
861 void ForwardExecutor::ProcessBeforeEndGraph(const py::object &obj, bool is_cell) {
862   if (is_cell) {
863     PopForwardCell();
864   }
865 
866   // Do some finishing work before end graph
867   if (IsFirstCell()) {
868     {
869       runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
870       GilReleaseWithCheck gil_release;
871       runtime::Pipeline::Get().frontend_stage()->Wait();
872     }
873     // Finish lazy task
874     ExecuteLazyTask();
875     if (!grad()->RequiresGrad()) {
876       ClearNodeAbsMap();
877     }
878     if (grad()->forward_use_dynamic_shape_process()) {
879       MS_LOG(DEBUG) << "first cell run end, set forward use dynamic shape process false";
880       grad()->set_forward_use_dynamic_shape_process(false);
881     }
882   }
883 }
884 
885 void ForwardExecutor::ProcessAfterEndGraph(const py::object &obj, bool is_cell) const {
886   if (IsFirstCell()) {
887 #if defined(__APPLE__)
888     ClearNodeAbsMap();
889 #else
890     static const auto op_run_info = std::make_shared<FrontendOpRunInfo>();
891     auto forward_task = std::make_shared<FrontendTask>([this](...) { ClearNodeAbsMap(); }, op_run_info);
892     runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
893 #endif
894   }
895   PrintPyObjInfo(obj, kEnd, is_cell);
896 }
897 
898 std::string ForwardExecutor::GetCurrentDeviceTarget(const PrimitivePtr &op_prim) const {
899   MS_EXCEPTION_IF_NULL(op_prim);
900   PrimitiveReadLock read_lock(op_prim->shared_mutex());
901   const auto &attr_map = op_prim->attrs();
902   auto iter = attr_map.find("primitive_target");
903   if (iter != attr_map.end()) {
904     return GetValue<std::string>(iter->second);
905   }
906   return device_target_;
907 }
908 
909 void ForwardExecutor::Sync() {
910   ExecuteLazyTask();
911 
912   runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kSyncStream);
913   device::DeviceContextManager::GetInstance().SyncAllStreams();
914 }
915 
916 ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info,
917                                     const BackendOpRunInfoPtr &backend_op_run_info) {
918   if (!ScopedFallbackRunning::on()) {
919     GilReleaseWithCheck gil_relase;
920     return RunOpInMsInner(op_run_info, backend_op_run_info);
921   }
922   // Print the op running in JIT Fallback.
923   static const auto dump_fallback = (common::GetEnv("MS_DEV_FALLBACK_DUMP_NODE") == "1");
924   if (dump_fallback) {
925     MS_LOG(ERROR) << "NOTICE: The op is running in JIT Fallback:\n"
926                   << "primitive: " << op_run_info->op_grad_info->op_prim->ToString();
927   } else {
928     MS_LOG(INFO) << "NOTICE: The op is running in JIT Fallback:\n"
929                  << "primitive: " << op_run_info->op_grad_info->op_prim->ToString();
930   }
931   return RunOpInMsInner(op_run_info, backend_op_run_info);
932 }
933 
934 void ForwardExecutor::CreateInputAddressForViewOp(const tensor::BaseTensorPtr &input_tensor,
935                                                   const FrontendOpRunInfoPtr &op_run_info) {
936   MS_EXCEPTION_IF_NULL(input_tensor);
937   bool is_cpu_address_exist = false;
938   const auto &device_sync = input_tensor->device_address();
939   if (device_sync != nullptr) {
940     auto tensor_address = std::static_pointer_cast<device::DeviceAddress>(device_sync);
941     MS_EXCEPTION_IF_NULL(tensor_address);
942     if (tensor_address->GetDeviceType() != device::DeviceType::kCPU) {
943       // If the address is a CPU address, we need to check in the device thread whether the device-ptr is from the
944       // pool (the flag is not set yet). If the device-ptr is not from the pool, it needs to be copied again.
945       tensor_address->set_is_view(true);
946       return;
947     }
948     is_cpu_address_exist = true;
949   }
950 
951   const auto &device_context = runtime::OpRunner::GetDeviceContext(op_run_info->base_op_run_info.device_target);
952   MS_EXCEPTION_IF_NULL(device_context);
953 
954   // If the address already exists, it is not from the pool, so there is no need to create an address repeatedly.
955   // Just copy data.
956   if (!is_cpu_address_exist) {
957     MS_LOG(DEBUG) << "Input_tensor address is nullptr, need create address.";
958     auto address_size = GetTypeByte(input_tensor->Dtype()) * static_cast<size_t>(input_tensor->ElementsNum());
959     auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
960       nullptr, address_size, Format::DEFAULT_FORMAT, input_tensor->data_type(), input_tensor->shape(),
961       device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
962     kernel_tensor->SetType(std::make_shared<TensorType>(input_tensor->Dtype()));
963     kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(input_tensor->shape()));
964     kernel_tensor->set_stream_id(op_run_info->base_op_run_info.stream_id);
965 
966     auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
967     input_tensor->set_device_address(device_address);
968   }
969 
970   const auto &cur_mind_rt_backend = GetMindRtBackend(op_run_info->base_op_run_info.device_target);
971   MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
972   cur_mind_rt_backend->RunAllocMemTask(device_context, input_tensor, EnablePipeline(""), is_cpu_address_exist);
973 }
974 
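// Returns the device address as-is when no storage info is attached; otherwise converts the strided (view)
// address into a contiguous one.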
975 device::DeviceAddressPtr ForwardExecutor::TensorContiguousCallback(const DeviceSyncPtr &device_address,
976                                                                    const TensorStorageInfoPtr &storage_info) {
977   MS_EXCEPTION_IF_NULL(device_address);
978   // Gil might be release  by ACL, so release here to reduce conflict
979   auto device_addr = std::dynamic_pointer_cast<device::DeviceAddress>(device_address);
980   MS_EXCEPTION_IF_NULL(device_addr);
981   if (storage_info == nullptr) {
982     return device_addr;
983   }
984 
985   // as_numpy requires synchronous completion, so run the contiguous conversion with run_sync set to true
986   return runtime::DeviceAddressUtils::ConvertContiguousDeviceAddress(nullptr, device_addr, true);
987 }
988 
989 void ForwardExecutor::PrepareOpInputs(const FrontendOpRunInfoPtr &op_run_info) {
990   MS_EXCEPTION_IF_NULL(op_run_info);
991   PyNativeAlgo::DataConvert::GetInputTensor(op_run_info, op_run_info->requires_grad ? grad()->top_cell() : nullptr);
992   for (const auto &value : op_run_info->base_op_run_info.expanded_input_values) {
993     if (!value->isa<tensor::BaseTensor>()) {
994       continue;
995     }
996   }
997 }
998 
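// Creates an output tensor for a view op: its device address carries the view's storage info but shares the
// input address's pointer ref count.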
999 void ForwardExecutor::CreateViewOutputTensor(const FrontendOpRunInfoPtr &op_run_info,
1000                                              const tensor::BaseTensorPtr &input_tensor,
1001                                              const TensorStorageInfoPtr &storage_info,
1002                                              runtime::KernelTaskType task_type) {
1003   MS_EXCEPTION_IF_NULL(input_tensor);
1004   MS_EXCEPTION_IF_NULL(storage_info);
1005   auto output_tensor = std::make_shared<tensor::Tensor>(input_tensor->data_type(), storage_info->shape);
1006   output_tensor->set_need_pipeline_sync(true);
1007   output_tensor->set_contiguous_callback([this](const DeviceSyncPtr &device_address) -> DeviceSyncPtr {
1008     return TensorContiguousCallback(device_address, device_address->GetTensorStorageInfo());
1009   });
1010 
1011   auto input_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
1012   MS_EXCEPTION_IF_NULL(input_device_address);
1013   if (task_type == runtime::KernelTaskType::kCOPY_TASK) {
1014     input_device_address->kernel_tensor()->set_tensor_storage_info(storage_info);
1015   }
1016 
1017   // Create view output address
1018   auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
1019     nullptr, input_device_address->GetSize(), Format::DEFAULT_FORMAT, output_tensor->data_type(),
1020     output_tensor->shape(), input_device_address->device_name(), input_device_address->device_id());
1021   if (input_device_address->GetDeviceType() != device::DeviceType::kAscend) {
1022     // Not transmitting host shape information under Ascend for better performance.
1023     kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(output_tensor->data_type())));
1024     kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(output_tensor->shape()));
1025   }
1026   kernel_tensor->set_tensor_storage_info(storage_info);
1027   kernel_tensor->set_size(input_device_address->GetSize());
1028   kernel_tensor->set_stream_id(input_device_address->stream_id());
1029 
1030   const auto &device_context = runtime::OpRunner::GetDeviceContext(input_device_address->device_name());
1031   MS_EXCEPTION_IF_NULL(device_context);
1032   auto output_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
1033   MS_EXCEPTION_IF_NULL(output_device_address);
1034 
1035   output_device_address->set_pointer_ref_count(input_device_address->pointer_ref_count());
1036   output_tensor->set_device_address(output_device_address);
1037   if (op_run_info->requires_grad) {
1038     output_tensor->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1039     output_tensor->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1040   }
1041   (void)op_run_info->base_op_run_info.output_tensors.emplace_back(output_tensor);
1042 }
1043 
1044 ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info,
1045                                          const BackendOpRunInfoPtr &backend_op_run_info) {
1046   const auto &outputs = RunOpBackendInner(op_run_info, backend_op_run_info);
1047   bool is_out_sequence = (op_run_info->base_op_run_info.abstract == nullptr ||
1048                           op_run_info->base_op_run_info.abstract->isa<abstract::AbstractSequence>());
1049   const auto &result_v =
1050     PyNativeAlgo::DataConvert::VectorRefToValue(outputs, op_run_info->requires_grad, is_out_sequence);
1051   MS_LOG(DEBUG) << "RunOpInMs end";
1052   return result_v;
1053 }
1054 
1055 void ForwardExecutor::ClearRes() {
1056   MS_LOG(DEBUG) << "Clear forward res";
1057   {
1058     GilReleaseWithCheck gil_release;
1059     runtime::Pipeline::Get().frontend_stage()->Clear();
1060   }
1061   for (const auto &item : mindrt_backends_) {
1062     MS_EXCEPTION_IF_NULL(item.second);
1063     item.second->ClearOpExecutorResource();
1064   }
1065   init_ = false;
1066   enable_async_ = false;
1067   is_jit_compiling_ = false;
1068   last_target_ = "Unknown";
1069   cast_operation()->ClearRes();
1070   ClearNodeAbsMap();
1071   infer_operation()->ClearPrimAbsList();
1072   infer_operation()->ClearConstFlagPrimCache();
1073   std::stack<CellPtr>().swap(forward_cell_stack_);
1074   mindrt_backends_.clear();
1075   slice_prim_cache_.clear();
1076 }
1077 
1078 void ForwardExecutor::ChildAfterFork() {
1079   MS_LOG(DEBUG) << "ForwardExecutor reinitialize after fork.";
1080   MS_LOG(DEBUG) << "Reinitialize frontend_queue_.";
1081   runtime::Pipeline::Get().frontend_stage()->ChildAfterFork();
1082   MS_LOG(DEBUG) << "ForwardExecutor reinitialize after fork done.";
1083 }
1084 }  // namespace pynative
1085 }  // namespace mindspore
1086