/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/forward/forward.h"
#include <set>
#include <algorithm>
#include <unordered_set>
#include <vector>
#include "ops/structure_op_name.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "pipeline/pynative/pynative_utils.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "include/common/utils/python_fallback_running.h"
#include "backend/graph_compiler/transform.h"
#include "utils/ms_context.h"
#include "pipeline/pynative/forward/forward_task.h"
#include "pipeline/pynative/predict_out_type_map.h"
#include "include/common/utils/stub_tensor.h"
#include "runtime/pynative/op_executor.h"
#ifndef ENABLE_SECURITY
#include "include/backend/debug/profiler/profiling.h"
using mindspore::profiler::ProfilerManager;
#endif
#include "frontend/operator/ops_front_infer_function.h"
#include "runtime/pipeline/pipeline.h"
#include "runtime/device/device_address_utils.h"

namespace mindspore {
namespace pynative {
enum class RunOpArgsEnum : size_t { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
namespace {
const std::set<std::string> kVmOperators = {"InsertGradientOf", "StopGradient", "HookBackward", "CellBackwardHook"};
constexpr char kBegin[] = "Begin";
constexpr char kEnd[] = "End";
constexpr auto kOpNameCustom = "Custom";

// Shallow Copy Value and change shape
ValuePtr ShallowCopyValue(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(value);
  auto tensor_abs = op_run_info->base_op_run_info.abstract;
  MS_EXCEPTION_IF_NULL(tensor_abs);
  if (tensor_abs->isa<abstract::AbstractRefTensor>()) {
    tensor_abs = tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
  }
  auto new_shape = tensor_abs->BuildShape()->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(new_shape);
  if (value->isa<mindspore::tensor::BaseTensor>()) {
    auto tensor_value = value->cast<mindspore::tensor::BaseTensorPtr>();
    return std::make_shared<mindspore::tensor::Tensor>(tensor_value->data_type(), new_shape->shape(),
                                                       tensor_value->data_c(), tensor_value->Size());
  }
  if (value->isa<ValueTuple>()) {
    std::vector<ValuePtr> values;
    auto value_tuple = value->cast<ValueTuplePtr>();
    (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
                         [op_run_info](const ValuePtr &elem) { return ShallowCopyValue(op_run_info, elem); });
    return std::make_shared<ValueTuple>(values);
  }
  return value;
}

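// Recursively copy tensor values so that each copied tensor gets a new tensor id while sharing the
// source tensor's data pointer, device address and sync status; non-tensor values are returned unchanged.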
ValuePtr CopyTensorValueWithNewId(const ValuePtr &v) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<tensor::BaseTensor>()) {
    auto tensor = v->cast<tensor::BaseTensorPtr>();
    // This constructor will make a tensor with the new id
    auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
    new_tensor->set_need_pipeline_sync(true);
    new_tensor->set_device_address(tensor->device_address());
    new_tensor->set_sync_status(tensor->sync_status());
    return new_tensor;
  }
  if (v->isa<ValueTuple>()) {
    const auto &v_tup = v->cast<ValueTuplePtr>();
    ValuePtrList list;
    for (const auto &ele : v_tup->value()) {
      (void)list.emplace_back(CopyTensorValueWithNewId(ele));
    }
    return std::make_shared<ValueTuple>(list);
  }
  if (v->isa<ValueList>()) {
    const auto &v_list = v->cast<ValueListPtr>();
    ValuePtrList list;
    for (const auto &ele : v_list->value()) {
      (void)list.emplace_back(CopyTensorValueWithNewId(ele));
    }
    return std::make_shared<ValueList>(list);
  }
  return v;
}

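// Propagate the inferred abstract to the output stub node; a mismatch between the predicted type in
// predict_out_type_map.h and the inferred type raises a TypeError.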
void UpdateOutputStubNodeAbs(const FrontendOpRunInfoPtr &op_run_info) {
  if (op_run_info->stub_output == nullptr) {
    return;
  }
  const auto &abs = op_run_info->base_op_run_info.abstract;
  MS_EXCEPTION_IF_NULL(abs);
  auto success = op_run_info->stub_output->SetAbstract(abs);
  if (!success) {
    const auto &op_name = op_run_info->base_op_run_info.op_name;
    MS_EXCEPTION(TypeError) << "The predict type and infer type do not match, predict type is "
                            << PredictOutType(op_run_info) << ", infer type is " << abs->BuildType()
                            << ", the name of operator is [" << op_name
                            << "]. Please modify or add predict type of operator in predict_out_type_map.h.";
  }
  MS_LOG(DEBUG) << "Update StubNode abstract " << abs->ToString();
}

void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
  // Clone a new prim
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto prim = op_run_info->op_grad_info->op_prim;
  auto prim_py = prim->cast<PrimitivePyPtr>();
  if (prim_py == nullptr) {
    return;
  }
  auto new_adapter = std::make_shared<PrimitivePyAdapter>(*prim_py->adapter());
  auto new_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_grad_info->op_prim->cast<PrimitivePyPtr>()));
  new_prim->EnableSharedMutex();
  op_run_info->op_grad_info->op_prim = new_prim;
  MS_EXCEPTION_IF_NULL(new_adapter);
  new_adapter->set_attached_primitive(new_prim);
}

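// An op has dynamic inputs if any input is a stub sequence node or a value sequence whose first
// element is a tensor (or tensor stub), such as the tuple input of Concat.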
bool IsDynamicInputs(const FrontendOpRunInfoPtr &op_run_info) {
  for (const auto &value : op_run_info->op_grad_info->input_value) {
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<stub::SequenceNode>()) {
      return true;
    }
    if (!value->isa<ValueSequence>()) {
      continue;
    }
    auto value_seq = value->cast<ValueSequencePtr>();
    MS_EXCEPTION_IF_NULL(value_seq);

    const auto &tuple_inputs = value_seq->value();
    if (tuple_inputs.empty()) {
      continue;
    }
    if (tuple_inputs[0]->isa<tensor::BaseTensor>() || tuple_inputs[0]->isa<stub::TensorNode>()) {
      return true;
    }
  }
  return false;
}

ValuePtr ConstructOutputInVM(const std::vector<ValuePtr> &result) {
  if (result.size() == 1) {
    return result[kIndex0];
  }
  return std::make_shared<ValueTuple>(result);
}

void UpdateOutputStubNodeValue(const FrontendOpRunInfoPtr &op_run_info) {
  if (op_run_info->stub_output != nullptr) {
    op_run_info->stub_output->SetValue(op_run_info->real_out);
  }
}

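// Build the backend op run info from the frontend op run info with a copied primitive, and carry over
// the cache-erase and dynamic-shape flags the backend needs.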
BackendOpRunInfoPtr CreateBackendOpRunInfo(const FrontendOpRunInfoPtr &op_run_info) {
  auto backend_op_run_info = std::make_shared<BackendOpRunInfo>(
    op_run_info->base_op_run_info, std::make_shared<Primitive>(*op_run_info->op_grad_info->op_prim), true, false);
  // Erase RandomOp cache to avoid memory leak.
  if (AnfAlgo::NeedEraseCache(backend_op_run_info->op_prim)) {
    backend_op_run_info->base_op_run_info.need_earse_cache = true;
  }
  if (op_run_info->base_op_run_info.has_dynamic_output) {
    backend_op_run_info->base_op_run_info.use_dynamic_shape_process = true;
  }
  return backend_op_run_info;
}

void UpdateStubTensor(const FrontendOpRunInfoPtr &op_run_info) {
  // Some operators do not have StubNodes, such as the Cast inserted for automatic mixed precision.
  if (op_run_info->stub_output != nullptr) {
    if (op_run_info->base_op_run_info.has_dynamic_output ||
        OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name)) {
      UpdateOutputStubNodeAbs(op_run_info);
    }
    op_run_info->stub_output->SetValue(op_run_info->real_out);
  }
}

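// CopyWithSlice is dispatched as a copy kernel task; every other view op is a normal view task.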
runtime::KernelTaskType GetViewOpTaskType(const std::string &op_name) {
  if (op_name == kCopyWithSliceOpName) {
    return runtime::KernelTaskType::kCOPY_TASK;
  }
  return runtime::KernelTaskType::kNORMAL_VIEW_TASK;
}

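// Collect the data inputs and slice index inputs for a slice op, and append the five default
// StridedSlice mask inputs when grad is required.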
void EmplaceSliceInputs(const FrontendOpRunInfoPtr &op_run_info, const std::vector<ValuePtr> &input_values,
                        const SliceOpInfoPtr &slice_op_info) {
  for (auto idx : slice_op_info->data_indexs) {
    if (idx >= input_values.size()) {
      MS_LOG(EXCEPTION) << "data_idx is out of bounds, data_idx:" << idx
                        << " input_values.size():" << input_values.size();
    }
    (void)op_run_info->op_grad_info->input_value.emplace_back(input_values[idx]);
  }

  for (const auto &slice_index : slice_op_info->slice_index_inputs) {
    ValuePtr v = nullptr;
    if (slice_index->is_int()) {
      v = MakeValue(slice_index->int_value());
    } else {
      v = MakeValue(slice_index->vec_value());
    }

    (void)op_run_info->op_grad_info->input_value.emplace_back(v);
  }

  if (op_run_info->requires_grad && op_run_info->base_op_run_info.op_name == kStridedSliceOpName) {
    // StridedSlice mask inputs
    int64_t v = 0;

    (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // begin_mask
    (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // end_mask
    (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // ellipsis_mask
    (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // new_axis_mask
    (void)op_run_info->op_grad_info->input_value.emplace_back(MakeValue(v));  // shrink_axis_mask
  }

  op_run_info->input_size = op_run_info->op_grad_info->input_value.size();
  op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
}

#ifndef ENABLE_TEST
size_t GetCurStreamId(const std::string &device_target) {
  auto device_ctx = runtime::OpRunner::GetDeviceContext(device_target);
  return device_ctx->device_res_manager_->GetCurrentStreamId();
}
#endif
}  // namespace

void ForwardExecutor::WaitForwardTask() {
  GilReleaseWithCheck gil_release;
  runtime::Pipeline::Get().frontend_stage()->Wait();
}

bool ForwardExecutor::IsVmOp(const std::string &op_name) const {
  return kVmOperators.find(op_name) != kVmOperators.end();
}

std::string ForwardExecutor::GetCurrentCellObjId() const {
  if (forward_cell_stack_.empty()) {
    return "";
  }
  auto &cell = forward_cell_stack_.top();
  return cell->id();
}

GradExecutorPtr ForwardExecutor::grad() const {
  auto grad_executor = grad_executor_.lock();
  MS_EXCEPTION_IF_NULL(grad_executor);
  return grad_executor;
}

void ForwardExecutor::InitOpRunInfo(const FrontendOpRunInfoPtr &op_run_info) {
  Init();
  // Used for async run
  op_run_info->requires_grad = grad()->RequiresGrad();
  if (op_run_info->requires_grad) {
    op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
  } else {
    op_run_info->base_op_run_info.use_dynamic_shape_process =
      grad()->forward_use_dynamic_shape_process() || grad()->use_dynamic_shape_process();
  }
  op_run_info->base_op_run_info.device_target = GetCurrentDeviceTarget(op_run_info->op_grad_info->op_prim);
  op_run_info->cell_obj_id = GetCurrentCellObjId();
  auto device_context = runtime::OpRunner::GetDeviceContext(op_run_info->base_op_run_info.device_target);
  op_run_info->base_op_run_info.stream_id = device_context->device_res_manager_->GetCurrentStreamId();
}

void ForwardExecutor::ReInit() {
  device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  enable_async_ = !MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}

void ForwardExecutor::Init() {
  ReInit();
  if (init_) {
    return;
  }
  init_ = true;
  MS_LOG(DEBUG) << "Init ForwardExecutor";
  compile::SetMindRTEnable();
  python_adapter::set_python_env_flag(true);
  runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() {
    runtime::Pipeline::Get().frontend_stage()->Wait();
    grad()->WaitBpropTask();
  });
}

void ForwardExecutor::RefreshForwardCallback() {
#if defined(_WIN32) || defined(_WIN64)
  runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() {
    runtime::Pipeline::Get().frontend_stage()->Wait();
    grad()->WaitBpropTask();
  });
#endif
  // ForwardCallback has been set in ForwardExecutor::Init, no need to refresh anymore.
}

bool ForwardExecutor::enable_async() const {
#if defined(ENABLE_TEST) || defined(__APPLE__)
  return false;
#else
  return enable_async_;
#endif
}

bool ForwardExecutor::EnablePipeline(const std::string &op_name) const {
  return enable_async() && !IsVmOp(op_name) && op_name != kOpNameCustom && !ScopedFallbackRunning::on() &&
         !runtime::OpExecutor::NeedSync();
}

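// Wrap RunOpFrontend into a FrontendTask and push it to the frontend pipeline stage for asynchronous execution.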
void ForwardExecutor::DispatchFrontendTask(const FrontendOpRunInfoPtr &op_run_info) {
  auto forward_task = std::make_shared<FrontendTask>(
    [this](const FrontendOpRunInfoPtr &op_run_info) { RunOpFrontend(op_run_info); }, op_run_info);
  runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(forward_task->task_id());
  runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
}

void ForwardExecutor::ForwardOpGradImpl(const FrontendOpRunInfoPtr &op_run_info) const {
  // If jit is compiled in the first step, the op info will not be found in the second training step.
  MS_LOG(DEBUG) << "Current custom bprop cell count " << op_run_info->async_status.custom_bprop_cell_count;
  if (!op_run_info->async_status.is_jit_compiling && op_run_info->async_status.custom_bprop_cell_count <= 0) {
    grad()->ProcessOpGradInfo(op_run_info);
  }
}

void ForwardExecutor::ForwardRunViewKernelTask(const FrontendOpRunInfoPtr &op_run_info,
                                               const runtime::KernelTaskType &task_type, bool enable_async) {
  if (task_type == runtime::KernelTaskType::kNORMAL_VIEW_TASK) {
    return;
  }
  MS_LOG(DEBUG) << "Start, task_type:" << task_type;

  const auto &cur_mind_rt_backend = GetMindRtBackend(op_run_info->base_op_run_info.device_target);
  MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
  cur_mind_rt_backend->RunViewKernelTask(op_run_info->base_op_run_info, task_type, enable_async);

  MS_LOG(DEBUG) << "End";
}

void ForwardExecutor::CreateViewOpOutputs(const FrontendOpRunInfoPtr &op_run_info,
                                          const tensor::BaseTensorPtr &view_input_tensor,
                                          runtime::KernelTaskType task_type,
                                          const TensorStorageInfoPtrList &storage_infos, bool is_tuple_output) {
  const bool is_single_tensor_output = storage_infos.size() == 1 && !is_tuple_output;
  // Generate output abs by storage_info.
  if (is_single_tensor_output) {
    op_run_info->base_op_run_info.abstract = abstract::MakeAbstractTensor(
      std::make_shared<abstract::Shape>(storage_infos[0]->shape), view_input_tensor->Dtype());
  } else {
    AbstractBasePtrList abs_list;
    for (const auto &storage_info : storage_infos) {
      auto abs = abstract::MakeAbstractTensor(std::make_shared<abstract::Shape>(storage_info->shape),
                                              view_input_tensor->Dtype());
      (void)abs_list.emplace_back(abs);
    }
    op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
  }

  UpdateOutputStubNodeAbs(op_run_info);
  CreateInputAddressForViewOp(view_input_tensor, op_run_info);

  for (size_t i = 0; i < storage_infos.size(); i++) {
    MS_LOG(DEBUG) << "View op " << op_run_info->base_op_run_info.op_name << ", i:" << i
                  << ", storage_info:" << storage_infos[i]->ToString();
    CreateViewOutputTensor(op_run_info, view_input_tensor, storage_infos[i], task_type);
  }

  if (is_single_tensor_output) {
    op_run_info->real_out = op_run_info->base_op_run_info.output_tensors[0];
    op_run_info->op_grad_info->output_size = 1;
  } else {
    std::vector<ValuePtr> output_values;
    (void)std::transform(op_run_info->base_op_run_info.output_tensors.begin(),
                         op_run_info->base_op_run_info.output_tensors.end(), std::back_inserter(output_values),
                         [](const auto &t) {
                           MS_EXCEPTION_IF_NULL(t);
                           return t;
                         });
    op_run_info->real_out = std::make_shared<ValueTuple>(output_values);
    op_run_info->op_grad_info->output_size = output_values.size();
  }

  UpdateOutputStubNodeValue(op_run_info);
}

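// Try to run the op as a view op: compute the output storage infos from the input tensor, create view
// outputs that alias the input storage, and launch the view kernel task. Returns false when the strides
// calculation decides the op is not a view op for these inputs.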
bool ForwardExecutor::ProcessViewOp(const FrontendOpRunInfoPtr &op_run_info,
                                    const ops::StridesCalcFunc &strides_calc_func, bool is_tuple_output) {
  MS_LOG(DEBUG) << "Start, op:" << op_run_info->base_op_run_info.op_name;
  if (op_run_info->op_grad_info->input_value.empty()) {
    MS_LOG(EXCEPTION) << "op_run_info->op_grad_info->input_value is empty";
  }

  // Only Split and Chunk have multiple outputs, and the input tensor is the first input.
  auto view_value = op_run_info->op_grad_info->input_value[0];
  MS_EXCEPTION_IF_NULL(view_value);
  if (!view_value->isa<tensor::BaseTensor>()) {
    MS_EXCEPTION(TypeError) << "For primitive[" << op_run_info->base_op_run_info.op_name
                            << "], the input[0] should be Tensor, but got:" << view_value->ToString();
  }
  auto view_input_tensor = view_value->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(view_input_tensor);

  auto storage_infos = strides_calc_func(op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_value);
  if (storage_infos.empty()) {
    MS_LOG(DEBUG) << "Not View op " << op_run_info->base_op_run_info.op_name;
    return false;
  }

  // Reuse SetInputAbstract; the abs of inputs is needed when requires_grad is true.
  InferOutputAbstract(op_run_info);
  runtime::KernelTaskType task_type = GetViewOpTaskType(op_run_info->base_op_run_info.op_name);

  // Create view output tensor
  CreateViewOpOutputs(op_run_info, view_input_tensor, task_type, storage_infos, is_tuple_output);

  if (op_run_info->requires_grad || task_type != runtime::KernelTaskType::kNORMAL_VIEW_TASK) {
    const auto &top_cell = op_run_info->requires_grad ? grad()->top_cell() : nullptr;
    for (size_t index = 0; index < op_run_info->input_size; ++index) {
      const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
      PyNativeAlgo::DataConvert::MarkInputs(op_run_info, input_object, index, top_cell);
    }
  }

  // The GIL might be released by ACL, so release it here to reduce conflicts.
  GilReleaseWithCheck release_gil;
  ForwardRunViewKernelTask(op_run_info, task_type, false);
  if (op_run_info->requires_grad) {
    ForwardOpGradImpl(op_run_info);
  }
  MS_LOG(DEBUG) << "End";
  return true;
}

void ForwardExecutor::DispatchSilceOpFrontendTask(const std::vector<ValuePtr> &input_values,
                                                  const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                                                  const stub::StubNodePtr &stub_output) {
  auto forward_task = std::make_shared<SliceOpFrontendTask>(
    [this](const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos,
           bool requires_grad, const stub::StubNodePtr &stub_output) {
      (void)RunSliceOpFrontend(input_values, slice_op_infos, requires_grad, stub_output);
    },
    input_values, slice_op_infos, requires_grad, stub_output);
  runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
}

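// Run a chain of slice/view ops (and Cast) described by slice_op_infos; the output of each op replaces
// its first data input for the next op, and only the last op updates the stub node.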
ValuePtr ForwardExecutor::RunSliceOpFrontend(const std::vector<ValuePtr> &input_values,
                                             const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                                             const stub::StubNodePtr &stub_output) {
  if (input_values.empty()) {
    MS_LOG(EXCEPTION) << "input_values is empty.";
  }

  MS_LOG(DEBUG) << "Start, slice_op_infos size:" << slice_op_infos.size();
  auto intermediate_tensor = input_values;
  auto last_tensor = input_values[0];

  for (size_t i = 0; i < slice_op_infos.size(); i++) {
    const auto &slice_op_info = slice_op_infos[i];
    MS_EXCEPTION_IF_NULL(slice_op_info);
    MS_LOG(DEBUG) << "Run slice op name:" << slice_op_info->slice_op_name;
    MS_EXCEPTION_IF_CHECK_FAIL(!slice_op_info->data_indexs.empty(), "data_indexs can not be empty");
    auto first_data_idx = slice_op_info->data_indexs[0];
    if (first_data_idx >= intermediate_tensor.size()) {
      MS_LOG(EXCEPTION) << "data_idx is out of bounds, data_idx:" << first_data_idx
                        << " intermediate_tensor.size():" << intermediate_tensor.size();
    }

    // Only the last op needs to update the stub node.
    auto cur_op_stub_output = (i + 1 == slice_op_infos.size() ? stub_output : nullptr);
    auto op_run_info = GenerateSliceOpRunInfo(slice_op_info->slice_op_name, requires_grad, cur_op_stub_output);
    if (slice_op_info->slice_op_name == kCastOpName) {
      // The slice_index_inputs of the Cast op is the target type.
      MS_EXCEPTION_IF_CHECK_FAIL(slice_op_info->slice_index_inputs.size() == 1, "Size of cast type input should be 1");
      auto type_value = slice_op_info->slice_index_inputs[0];
      MS_EXCEPTION_IF_CHECK_FAIL(type_value->is_int(), "type_value should be int.");
      auto type_id = static_cast<TypeId>(type_value->int_value());
      (void)cast_operation()->DoNormalCast(op_run_info, intermediate_tensor[first_data_idx], type_id);
    } else {
      EmplaceSliceInputs(op_run_info, intermediate_tensor, slice_op_info);

      auto strides_calc_info =
        ops::ViewStridesCalcFactory::GetInstance().GetStridesCalcFunc(op_run_info->base_op_run_info.op_name);
      if (!strides_calc_info.first.has_value()) {
        MS_LOG(EXCEPTION) << "op:" << op_run_info->base_op_run_info.op_name << " is not view.";
      }
      op_run_info->is_view_op = true;
      PyNativeAlgo::Common::StubNodeToValue(op_run_info);
      if (!ProcessViewOp(op_run_info, strides_calc_info.first.value(), strides_calc_info.second)) {
        MS_EXCEPTION(ValueError) << "op:" << op_run_info->base_op_run_info.op_name << " inputs is not for view.";
      }
    }
    intermediate_tensor[first_data_idx] = op_run_info->real_out;
    last_tensor = op_run_info->real_out;
  }
  MS_LOG(DEBUG) << "End";
  return last_tensor;
}

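// Frontend entry for running a single op: convert stub inputs to values, apply implicit casts, take the
// view op fast path when possible, infer the output abstract and then run the backend.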
void ForwardExecutor::RunOpFrontend(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_run_info->base_op_run_info.op_name;
#ifndef ENABLE_TEST
  auto strides_calc_info =
    ops::ViewStridesCalcFactory::GetInstance().GetStridesCalcFunc(op_run_info->base_op_run_info.op_name);
  op_run_info->is_view_op = strides_calc_info.first.has_value();
#endif

  // Convert StubNode to Tensor; there is no need to deal with input StubNodes in this thread afterwards.
  PyNativeAlgo::Common::StubNodeToValue(op_run_info);
  // 1. Set cast for inputs
  SetCastForInputs(op_run_info);

#ifndef ENABLE_TEST
  if (op_run_info->is_view_op &&
      ProcessViewOp(op_run_info, strides_calc_info.first.value(), strides_calc_info.second)) {
    return;
  }
#endif

  if (op_run_info->is_view_op) {
    // Some special inputs cannot run the view op, so make the inputs contiguous first and set the flag to false.
    for (size_t i = 0; i < op_run_info->op_grad_info->input_value.size(); i++) {
      op_run_info->op_grad_info->input_value[i] = PyNativeAlgo::Common::ConvertToContiguousValue(
        op_run_info->op_grad_info->input_value[i], op_run_info->requires_grad);
    }
    op_run_info->is_view_op = false;
  }

  // Infer output abstract
  InferOutputAbstract(op_run_info);

  if (!(op_run_info->base_op_run_info.has_dynamic_output ||
        OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name))) {
    // If the output is dynamic shape, SetAbstract needs to be done after RunOp instead.
    UpdateOutputStubNodeAbs(op_run_info);
  }

  if (op_run_info->output_get_by_infer_value) {
    UpdateOutputStubNodeValue(op_run_info);
    MS_LOG(DEBUG) << "Grad flag: " << op_run_info->requires_grad
                  << " output_get_by_infer_value: " << op_run_info->output_get_by_infer_value;
    return;
  }

  PrepareOpInputs(op_run_info);

  RunOpBackendSync(op_run_info);
}

void ForwardExecutor::RunOpBackendSync(const FrontendOpRunInfoPtr &op_run_info) {
  const auto &backend_op_run_info = CreateBackendOpRunInfo(op_run_info);
  RunOpBackend(op_run_info, backend_op_run_info);
  if (!op_run_info->requires_grad) {
    MS_LOG(DEBUG) << "Grad flag is false";
    UpdateStubTensor(op_run_info);
    return;
  }
  // Do op grad and record op info
  ForwardOpGradImpl(op_run_info);
  // The output is dynamic shape; the abstract and value need to be updated.
  UpdateStubTensor(op_run_info);
}

void ForwardExecutor::OpRunInfoUsePrimC(const FrontendOpRunInfoPtr &op_run_info) const {
  auto prim = op_run_info->op_grad_info->op_prim;
  auto op_name = prim->name();
  if (EnablePipeline(op_name) && expander::bprop::HasBpropExpander(op_name) &&
      abstract::GetFrontendPrimitiveInferImpl(prim).has_value()) {
    auto new_prim = std::make_shared<Primitive>(*prim);
    new_prim->EnableSharedMutex();
    op_run_info->op_grad_info->op_prim = new_prim;
  }
}

PrimitivePtr ForwardExecutor::GetSlicePrimFromCache(const std::string &op_name) {
  auto iter = slice_prim_cache_.find(op_name);
  if (iter != slice_prim_cache_.end()) {
    return iter->second;
  }

  auto prim = std::make_shared<Primitive>(op_name);
  slice_prim_cache_[op_name] = prim;
  return prim;
}

FrontendOpRunInfoPtr ForwardExecutor::GenerateSliceOpRunInfo(const std::string &op_name, bool requires_grad,
                                                             const stub::StubNodePtr &stub_output) {
  Init();
  const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
  op_run_info->base_op_run_info.op_name = op_name;
  op_run_info->requires_grad = requires_grad;
  op_run_info->base_op_run_info.device_target = device_target_;

  if (op_name == kCastOpName) {
    // Cast prim will be set in DoNormalCast.
    return op_run_info;
  }

  if (op_run_info->requires_grad) {
    op_run_info->op_grad_info->op_prim = GetSlicePrimFromCache(op_name);
  }
  op_run_info->stub_output = stub_output;
  return op_run_info;
}

FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args, bool stub) {
  if (args.size() != static_cast<size_t>(RunOpArgsEnum::PY_ARGS_NUM)) {
    MS_LOG(EXCEPTION) << "Three args are needed by RunOp";
  }
  Init();
  const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
  // Used for async run
  op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
  op_run_info->requires_grad = grad()->RequiresGrad();
  if (op_run_info->requires_grad) {
    op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
  } else {
    op_run_info->base_op_run_info.use_dynamic_shape_process =
      grad()->forward_use_dynamic_shape_process() || grad()->use_dynamic_shape_process();
  }
  PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
  OpRunInfoUsePrimC(op_run_info);
  PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_INPUTS)],
                                                  stub);
  op_run_info->base_op_run_info.device_target = GetCurrentDeviceTarget(op_run_info->op_grad_info->op_prim);
  bool is_dynamic_shape =
    op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.use_dynamic_shape_process;
  PyNativeAlgo::Common::GetConstInputToAttr(op_run_info->op_grad_info->op_prim, op_run_info->base_op_run_info.op_name,
                                            op_run_info->base_op_run_info.device_target, is_dynamic_shape,
                                            &op_run_info->input_to_attr);
  bool is_dynamic_inputs = IsDynamicInputs(op_run_info);
  if (!op_run_info->input_to_attr.empty() || is_dynamic_inputs) {
    MS_LOG(DEBUG) << "Op_prim need clone:" << op_run_info->base_op_run_info.op_name
                  << ", is_dynamic_inputs:" << is_dynamic_inputs
                  << ", input_to_attr is not empty:" << (!op_run_info->input_to_attr.empty());
    ClonePrim(op_run_info);
  }
#ifndef ENABLE_TEST
  // Obtaining device context may fail in UT
  op_run_info->base_op_run_info.stream_id = GetCurStreamId(op_run_info->base_op_run_info.device_target);
#endif
  op_run_info->cell_obj_id = GetCurrentCellObjId();
  return op_run_info;
}

void ForwardExecutor::SetCastForInputs(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  // No need to cast the Cast op itself
  if (op_run_info->base_op_run_info.op_name == prim::kPrimCast->name()) {
    return;
  }
  cast_operation()->DoCast(op_run_info);
}

void ForwardExecutor::ClearNodeAbsMap() const { infer_operation()->ClearNodeAbsCache(); }

void ForwardExecutor::SetNodeAbsMapByValue(const FrontendOpRunInfoPtr &op_run_info) const {
  infer_operation()->SetNodeAbsCacheByValue(op_run_info);
}

void ForwardExecutor::SetNodeAbsMapById(const std::string &id, const abstract::AbstractBasePtr &abs) const {
  infer_operation()->SetNodeAbsCacheById(id, abs);
}

AbstractBasePtr ForwardExecutor::GetNodeAbsById(const std::string &id) const {
  return infer_operation()->GetNodeAbsById(id);
}

void ForwardExecutor::InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) const {
  infer_operation()->DoInfer(op_run_info);
}

VectorRef ForwardExecutor::RunOpBackendInner(const FrontendOpRunInfoPtr &op_run_info,
                                             const BackendOpRunInfoPtr &backend_op_run_info) {
  MS_LOG(DEBUG) << "RunOpBackendInner start";
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);

  VectorRef outputs;
  const auto &cur_mind_rt_backend = GetMindRtBackend(backend_op_run_info->base_op_run_info.device_target);
  MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
  bool use_dynamic_shape_process = backend_op_run_info->base_op_run_info.use_dynamic_shape_process;
  if (use_dynamic_shape_process) {
    cur_mind_rt_backend->RunOpDynamic(backend_op_run_info, &outputs);
  } else {
    cur_mind_rt_backend->RunOp(backend_op_run_info, &outputs);
  }

  if (op_run_info->base_op_run_info.has_dynamic_output ||
      OpCompiler::GetInstance().IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name)) {
    op_run_info->base_op_run_info.abstract = backend_op_run_info->base_op_run_info.abstract;
  }
  op_run_info->op_grad_info->output_size = outputs.size();
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  MS_LOG(DEBUG) << "RunOpBackendInner end";
  return outputs;
}

void ForwardExecutor::RunOpBackend(const FrontendOpRunInfoPtr &op_run_info,
                                   const BackendOpRunInfoPtr &backend_op_run_info) {
  // Run op with the selected backend; a nop does not need to run on the backend.
  op_run_info->real_out = RunOpWithBackendPolicy(op_run_info, backend_op_run_info);
  // Do not cache the abs of GetNext.
  if (op_run_info->base_op_run_info.op_name != kGetNextOpName) {
    op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
    SetNodeAbsMapByValue(op_run_info);
  }
}

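// Get the MindRT backend for the given device target, creating and caching it on first use.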
compile::MindRTBackendPtr ForwardExecutor::GetMindRtBackend(const string &cur_device_target) {
  const auto iter = mindrt_backends_.find(cur_device_target);
  if (iter != mindrt_backends_.end()) {
    return iter->second;
  } else {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    auto backend = std::make_shared<compile::MindRTBackend>("ms", cur_device_target, device_id);
    MS_EXCEPTION_IF_NULL(backend);
    mindrt_backends_[cur_device_target] = backend;
    return backend;
  }
}

ValuePtr ForwardExecutor::RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info,
                                                 const BackendOpRunInfoPtr &backend_op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
#ifndef ENABLE_TEST
  if (IsVmOp(op_run_info->base_op_run_info.op_name)) {
    return RunOpInVM(op_run_info);
  } else {
    return RunOpInMs(op_run_info, backend_op_run_info);
  }
#else
  return RunOpInVM(op_run_info);
#endif
}

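// Run the op through the Python VM path: VM operators (e.g. HookBackward) just copy their inputs as
// outputs, other ops call the primitive's Python compute function.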
ValuePtr ForwardExecutor::RunOpInVM(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_LOG(DEBUG) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(op_run_info);
  op_run_info->run_in_vm = true;
  if (op_run_info->requires_grad) {
    for (size_t i = 0; i < op_run_info->input_size; i++) {
      op_run_info->op_grad_info->input_value_grad_type[i] = PyNativeAlgo::Common::SetValueGradInfo(
        op_run_info->op_grad_info->input_value[i], nullptr, InputType::kConstant);
      (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(op_run_info->op_grad_info->input_value[i]);
    }
  }
  if (IsVmOp(op_run_info->base_op_run_info.op_name)) {
    std::vector<ValuePtr> result(op_run_info->input_size);
    for (size_t i = 0; i < op_run_info->input_size; i++) {
      result[i] = CopyTensorValueWithNewId(op_run_info->op_grad_info->input_value[i]);
    }
    auto result_v = ConstructOutputInVM(result);
    if (op_run_info->requires_grad) {
      op_run_info->op_grad_info->output_size = result.size();
      (void)PyNativeAlgo::Common::SetValueGradInfo(result_v, nullptr, InputType::kOpOutput);
    }
    MS_LOG(DEBUG) << "RunOpInVM end";
    return result_v;
  }

  MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
  py::list vm_op_inputs = py::list(op_run_info->input_size);
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    vm_op_inputs[i] = PyNativeAlgo::DataConvert::ValueToPyObj(op_run_info->op_grad_info->input_value[i]);
  }
  if (!utils::isa<PrimitivePy>(op_run_info->op_grad_info->op_prim)) {
    MS_LOG(EXCEPTION) << "Not a PrimitivePy, " << op_run_info->op_grad_info->op_prim->ToString();
  }
  auto result = utils::cast<PrimitivePyPtr>(op_run_info->op_grad_info->op_prim)->RunPyComputeFunction(vm_op_inputs);
  if (py::isinstance<py::none>(result)) {
    MS_LOG(EXCEPTION) << "VM op " << op_run_info->base_op_run_info.op_name << " run failed!";
  }
  ValuePtr result_v = PyNativeAlgo::DataConvert::PyObjToValue(result);
  if (!result_v->isa<ValueSequence>() && (op_run_info->base_op_run_info.abstract == nullptr ||
                                          op_run_info->base_op_run_info.abstract->isa<abstract::AbstractSequence>())) {
    result_v = std::make_shared<ValueTuple>(std::vector{result_v});
  }
  if (op_run_info->requires_grad) {
    (void)PyNativeAlgo::Common::SetValueGradInfo(result_v, nullptr, InputType::kOpOutput);
  }
  op_run_info->op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(result_v);
  MS_LOG(DEBUG) << "RunOpInVM end";
  return result_v;
}

bool ForwardExecutor::CellNotSetMixedPrecision(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &cur_cell = forward_cell_stack_.top();
  MS_EXCEPTION_IF_NULL(cur_cell);
  MixedPrecisionType mix_type = cur_cell->GetMixedPrecisionType();
  if (mix_type == kNotSet) {
    return true;
  }
  op_run_info->mix_type = mix_type;
  return false;
}

void ForwardExecutor::ExecuteLazyTask() const {
  runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
  GilReleaseWithCheck gil_release;
  runtime::OpExecutor::GetInstance().WaitAll();
}

void ForwardExecutor::PrintPyObjInfo(const py::object &obj, const std::string &str, bool is_cell) const {
  if (is_cell) {
    MS_LOG(DEBUG) << str << " run " << obj.cast<CellPtr>()->ToString();
    return;
  }
  MS_LOG(DEBUG) << str << " run python function " << py::getattr(obj, "__name__").cast<std::string>();
}

void ForwardExecutor::ProcessBeforeNewGraph(const py::object &obj, const py::args &args) {
  bool is_cell = py::isinstance<Cell>(obj);
  if (is_cell) {
    auto cell = obj.cast<CellPtr>();
    MS_EXCEPTION_IF_NULL(cell);
    PushForwardCell(cell);
    if (!grad()->RequiresGrad()) {
      if (grad()->is_cell_has_dynamic_inputs(cell->id())) {
        MS_LOG(DEBUG) << "obj id:" << cell->id() << " set forward use dynamic shape process true";
        grad()->set_forward_use_dynamic_shape_process(true);
#ifndef ENABLE_SECURITY
        ProfilerManager::GetInstance()->SetNetDynamicShapeStatus();
#endif
      }
    } else {
      PrintPyObjInfo(obj, kBegin, is_cell);
    }
  }
}

void ForwardExecutor::ProcessAfterNewGraph(const py::object &obj) const { grad()->SetTopCellDynamicAttr(obj); }

void ForwardExecutor::ProcessBeforeEndGraph(const py::object &obj, bool is_cell) {
  if (is_cell) {
    PopForwardCell();
  }

  // Do some finishing work before the end of the graph
  if (IsFirstCell()) {
    {
      runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
      GilReleaseWithCheck gil_release;
      runtime::Pipeline::Get().frontend_stage()->Wait();
    }
    // Finish lazy task
    ExecuteLazyTask();
    if (!grad()->RequiresGrad()) {
      ClearNodeAbsMap();
    }
    if (grad()->forward_use_dynamic_shape_process()) {
      MS_LOG(DEBUG) << "first cell run end, set forward use dynamic shape process false";
      grad()->set_forward_use_dynamic_shape_process(false);
    }
  }
}

void ForwardExecutor::ProcessAfterEndGraph(const py::object &obj, bool is_cell) const {
  if (IsFirstCell()) {
#if defined(__APPLE__)
    ClearNodeAbsMap();
#else
    static const auto op_run_info = std::make_shared<FrontendOpRunInfo>();
    auto forward_task = std::make_shared<FrontendTask>([this](...) { ClearNodeAbsMap(); }, op_run_info);
    runtime::Pipeline::Get().frontend_stage()->Push(forward_task);
#endif
  }
  PrintPyObjInfo(obj, kEnd, is_cell);
}

std::string ForwardExecutor::GetCurrentDeviceTarget(const PrimitivePtr &op_prim) const {
  MS_EXCEPTION_IF_NULL(op_prim);
  PrimitiveReadLock read_lock(op_prim->shared_mutex());
  const auto &attr_map = op_prim->attrs();
  auto iter = attr_map.find("primitive_target");
  if (iter != attr_map.end()) {
    return GetValue<std::string>(iter->second);
  }
  return device_target_;
}

void ForwardExecutor::Sync() {
  ExecuteLazyTask();

  runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kSyncStream);
  device::DeviceContextManager::GetInstance().SyncAllStreams();
}

ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info,
                                    const BackendOpRunInfoPtr &backend_op_run_info) {
  if (!ScopedFallbackRunning::on()) {
    GilReleaseWithCheck gil_release;
    return RunOpInMsInner(op_run_info, backend_op_run_info);
  }
  // Print the op running in JIT Fallback.
  static const auto dump_fallback = (common::GetEnv("MS_DEV_FALLBACK_DUMP_NODE") == "1");
  if (dump_fallback) {
    MS_LOG(ERROR) << "NOTICE: The op is running in JIT Fallback:\n"
                  << "primitive: " << op_run_info->op_grad_info->op_prim->ToString();
  } else {
    MS_LOG(INFO) << "NOTICE: The op is running in JIT Fallback:\n"
                 << "primitive: " << op_run_info->op_grad_info->op_prim->ToString();
  }
  return RunOpInMsInner(op_run_info, backend_op_run_info);
}

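// Make sure the view input tensor owns a device address on the target device (creating one when it is
// missing or only a CPU address exists) and dispatch a memory allocation task for it.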
void ForwardExecutor::CreateInputAddressForViewOp(const tensor::BaseTensorPtr &input_tensor,
                                                  const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  bool is_cpu_address_exist = false;
  const auto &device_sync = input_tensor->device_address();
  if (device_sync != nullptr) {
    auto tensor_address = std::static_pointer_cast<device::DeviceAddress>(device_sync);
    MS_EXCEPTION_IF_NULL(tensor_address);
    if (tensor_address->GetDeviceType() != device::DeviceType::kCPU) {
      // If the address is a CPU address, check whether the device-ptr is from the pool in the device thread
      // (the flag is not set yet). If the device-ptr is not from the pool, it needs to be recopied.
      tensor_address->set_is_view(true);
      return;
    }
    is_cpu_address_exist = true;
  }

  const auto &device_context = runtime::OpRunner::GetDeviceContext(op_run_info->base_op_run_info.device_target);
  MS_EXCEPTION_IF_NULL(device_context);

  // If the address exists, it means the address is not from the pool, so there is no need to create an address
  // repeatedly. Just copy the data.
  if (!is_cpu_address_exist) {
    MS_LOG(DEBUG) << "Input_tensor address is nullptr, need create address.";
    auto address_size = GetTypeByte(input_tensor->Dtype()) * static_cast<size_t>(input_tensor->ElementsNum());
    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
      nullptr, address_size, Format::DEFAULT_FORMAT, input_tensor->data_type(), input_tensor->shape(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->SetType(std::make_shared<TensorType>(input_tensor->Dtype()));
    kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(input_tensor->shape()));
    kernel_tensor->set_stream_id(op_run_info->base_op_run_info.stream_id);

    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    input_tensor->set_device_address(device_address);
  }

  const auto &cur_mind_rt_backend = GetMindRtBackend(op_run_info->base_op_run_info.device_target);
  MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
  cur_mind_rt_backend->RunAllocMemTask(device_context, input_tensor, EnablePipeline(""), is_cpu_address_exist);
}

device::DeviceAddressPtr ForwardExecutor::TensorContiguousCallback(const DeviceSyncPtr &device_address,
                                                                   const TensorStorageInfoPtr &storage_info) {
  MS_EXCEPTION_IF_NULL(device_address);
  // The GIL might be released by ACL, so release it here to reduce conflicts.
  auto device_addr = std::dynamic_pointer_cast<device::DeviceAddress>(device_address);
  MS_EXCEPTION_IF_NULL(device_addr);
  if (storage_info == nullptr) {
    return device_addr;
  }

  // as_numpy sync promise contiguous run_sync
  return runtime::DeviceAddressUtils::ConvertContiguousDeviceAddress(nullptr, device_addr, true);
}

void ForwardExecutor::PrepareOpInputs(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  PyNativeAlgo::DataConvert::GetInputTensor(op_run_info, op_run_info->requires_grad ? grad()->top_cell() : nullptr);
  for (const auto &value : op_run_info->base_op_run_info.expanded_input_values) {
    if (!value->isa<tensor::BaseTensor>()) {
      continue;
    }
  }
}

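// Create one view output tensor whose device address shares the input's memory through the pointer
// ref count and carries the view's storage info.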
void ForwardExecutor::CreateViewOutputTensor(const FrontendOpRunInfoPtr &op_run_info,
                                             const tensor::BaseTensorPtr &input_tensor,
                                             const TensorStorageInfoPtr &storage_info,
                                             runtime::KernelTaskType task_type) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  MS_EXCEPTION_IF_NULL(storage_info);
  auto output_tensor = std::make_shared<tensor::Tensor>(input_tensor->data_type(), storage_info->shape);
  output_tensor->set_need_pipeline_sync(true);
  output_tensor->set_contiguous_callback([this](const DeviceSyncPtr &device_address) -> DeviceSyncPtr {
    return TensorContiguousCallback(device_address, device_address->GetTensorStorageInfo());
  });

  auto input_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  MS_EXCEPTION_IF_NULL(input_device_address);
  if (task_type == runtime::KernelTaskType::kCOPY_TASK) {
    input_device_address->kernel_tensor()->set_tensor_storage_info(storage_info);
  }

  // Create view output address
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, input_device_address->GetSize(), Format::DEFAULT_FORMAT, output_tensor->data_type(),
    output_tensor->shape(), input_device_address->device_name(), input_device_address->device_id());
  if (input_device_address->GetDeviceType() != device::DeviceType::kAscend) {
    // Not transmitting host shape information under Ascend for better performance.
    kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(output_tensor->data_type())));
    kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(output_tensor->shape()));
  }
  kernel_tensor->set_tensor_storage_info(storage_info);
  kernel_tensor->set_size(input_device_address->GetSize());
  kernel_tensor->set_stream_id(input_device_address->stream_id());

  const auto &device_context = runtime::OpRunner::GetDeviceContext(input_device_address->device_name());
  MS_EXCEPTION_IF_NULL(device_context);
  auto output_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(output_device_address);

  output_device_address->set_pointer_ref_count(input_device_address->pointer_ref_count());
  output_tensor->set_device_address(output_device_address);
  if (op_run_info->requires_grad) {
    output_tensor->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
    output_tensor->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
  }
  (void)op_run_info->base_op_run_info.output_tensors.emplace_back(output_tensor);
}

ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info,
                                         const BackendOpRunInfoPtr &backend_op_run_info) {
  const auto &outputs = RunOpBackendInner(op_run_info, backend_op_run_info);
  bool is_out_sequence = (op_run_info->base_op_run_info.abstract == nullptr ||
                          op_run_info->base_op_run_info.abstract->isa<abstract::AbstractSequence>());
  const auto &result_v =
    PyNativeAlgo::DataConvert::VectorRefToValue(outputs, op_run_info->requires_grad, is_out_sequence);
  MS_LOG(DEBUG) << "RunOpInMs end";
  return result_v;
}

void ForwardExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear forward res";
  {
    GilReleaseWithCheck gil_release;
    runtime::Pipeline::Get().frontend_stage()->Clear();
  }
  for (const auto &item : mindrt_backends_) {
    MS_EXCEPTION_IF_NULL(item.second);
    item.second->ClearOpExecutorResource();
  }
  init_ = false;
  enable_async_ = false;
  is_jit_compiling_ = false;
  last_target_ = "Unknown";
  cast_operation()->ClearRes();
  ClearNodeAbsMap();
  infer_operation()->ClearPrimAbsList();
  infer_operation()->ClearConstFlagPrimCache();
  std::stack<CellPtr>().swap(forward_cell_stack_);
  mindrt_backends_.clear();
  slice_prim_cache_.clear();
}

void ForwardExecutor::ChildAfterFork() {
  MS_LOG(DEBUG) << "ForwardExecutor reinitialize after fork.";
  MS_LOG(DEBUG) << "Reinitialize frontend_queue_.";
  runtime::Pipeline::Get().frontend_stage()->ChildAfterFork();
  MS_LOG(DEBUG) << "ForwardExecutor reinitialize after fork done.";
}
}  // namespace pynative
}  // namespace mindspore