/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/any_type_kernel_actor.h"
#include <set>
#include <functional>
#include "include/common/debug/anf_ir_dump.h"
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/fallback.h"
#include "include/common/utils/stub_tensor.h"
#include "include/backend/py_execute_utils.h"

namespace mindspore {
namespace runtime {
namespace {
using AddressPtr = kernel::AddressPtr;
using PyExecuteOutputUserData = kernel::PyExecuteOutputUserData;
}  // namespace

std::mutex AnyTypeKernelActor::instance_lock_;

AnyTypeKernelActor::AnyTypeKernelActor(const std::string &name, const KernelGraphPtr &graph,
                                       const DeviceContext *device_context, const AID &memory_manager_aid,
                                       const AID *debug_aid, const AID *recorder_aid, KernelTransformType type)
    : SuperKernelActor(name, graph, device_context, memory_manager_aid, debug_aid, recorder_aid, type) {}

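// Entry point for incoming data arrows. An any type kernel actor receives two kinds of data: arrows whose
// index is smaller than the graph input count carry model graph inputs, while larger indexes carry outputs
// produced by the compiled real graph. Each kind is collected per sequential number and triggers
// RunForGraphInput/RunForGraphOutput once the corresponding running condition is satisfied.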
void AnyTypeKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(input_data);
  MS_EXCEPTION_IF_NULL(input_data->data_);
  MS_EXCEPTION_IF_NULL(input_data->data_->kernel_tensor());
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  auto &sequential_num = context->sequential_num_;
  if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
      !TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed)) {
    MS_LOG(EXCEPTION) << "The input_data does not have a valid ptr of actor:" << GetAID().Name()
                      << " with index:" << input_data->index_ << ", flag:" << input_data->data_->flag()
                      << " device address:" << input_data->data_ << " ref count:" << input_data->data_->ref_count()
                      << " dynamic ref count:" << input_data->data_->dynamic_ref_count()
                      << " origin ref count:" << input_data->data_->original_ref_count();
  }
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data:" << input_data->data_
                << " index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
                << " ptr:" << input_data->data_->GetPtr() << " user data:" << input_data->data_->user_data()
                << " input num:" << input_datas_num_ << " input device tensor size:" << input_device_tensors_.size()
                << " ref count:" << input_data->data_->ref_count()
                << " dynamic ref count:" << input_data->data_->dynamic_ref_count()
                << " origin ref count:" << input_data->data_->original_ref_count()
                << " type:" << input_data->data_->kernel_tensor()->GetType()
                << " type id:" << input_data->data_->kernel_tensor()->type_id();
  if (input_data->index_ < SizeToLong(graph()->input_nodes().size())) {
    // Collect graph input data.
    input_op_datas_[sequential_num].emplace_back(input_data);
    if (CheckRunningCondition(context)) {
      MS_LOG(DEBUG) << "Begin wait runtime pipeline to run for graph input for actor: " << GetAID().Name();
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      MS_LOG(DEBUG) << "End wait runtime pipeline to run for graph input for actor: " << GetAID().Name();
      RunForGraphInput(context);
    }
  } else {
    // Collect graph output data.
    graph_output_op_data_[sequential_num].emplace_back(input_data);
    if (CheckGraphOutputRunningCondition(context)) {
      MS_LOG(DEBUG) << "Begin wait runtime pipeline to run for graph output for actor: " << GetAID().Name();
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      MS_LOG(DEBUG) << "End wait runtime pipeline to run for graph output for actor: " << GetAID().Name();
      RunForGraphOutput(context);
    }
  }
}

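// Entry point for incoming control arrows. Controls sent by the recorded input control arrow senders are
// treated as graph input controls; all other controls are treated as graph output controls from the actors
// of the compiled real graph.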
void AnyTypeKernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(input_control);
  auto &sequential_num = context->sequential_num_;
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op control:" << input_control->Name();
  if (std::any_of(
        input_control_arrow_aids_.begin(), input_control_arrow_aids_.end(),
        [input_control](const auto &arrow_pair) { return arrow_pair.first.Name() == input_control->Name(); })) {
    (void)input_op_controls_[sequential_num].emplace_back(input_control);
    if (CheckRunningCondition(context)) {
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      RunForGraphInput(context);
    }
  } else {
    graph_output_op_control_[sequential_num].emplace_back(input_control);
    if (CheckGraphOutputRunningCondition(context)) {
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      RunForGraphOutput(context);
    }
  }
}

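// Collect the device tensors of the received input data into input_device_tensors_ by arrow index, build the
// memory free list (which always starts from the cached graph output device tensors), and fill the parameters
// that are fetched from the device tensor store.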
void AnyTypeKernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  std::vector<DeviceTensor *> memory_free_list = graph_ouput_device_tensors_;
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter == input_op_datas_.end()) {
    memory_free_lists_.push(memory_free_list);
    return;
  }
  for (auto &input_data : data_iter->second) {
    MS_EXCEPTION_IF_NULL(input_data);
    MS_EXCEPTION_IF_NULL(input_data->data_);
    size_t index = IntToSize(input_data->index_);
    if (index >= input_device_tensors_.size()) {
      std::string error_info = "Invalid input index:" + std::to_string(index) +
                               " total:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[index] = input_data->data_;
    if (input_data->data_->ref_count() != SIZE_MAX) {
      (void)memory_free_list.emplace_back(input_data->data_);
    }
  }
  memory_free_lists_.push(memory_free_list);

  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                    "Invalid device context for any type actor:" + GetAID().Name());
    }
    auto device_tensor = DeviceTensorStore::GetInstance()
                           .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
                           .get();
    if (device_tensor == nullptr) {
      MS_LOG_WITH_NODE(EXCEPTION, device_tensor_store_key.second)
        << "Failed to get device tensor for node:" << device_tensor_store_key.second->DebugString()
        << " index:" << device_tensor_store_key.first << " device type:" << device_contexts_[0]->GetDeviceType();
      continue;
    }
    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info = "Invalid input index:" + std::to_string(device_tensor_store_key.first) +
                               " total:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }
}

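// Analogue of AbstractActor::CheckRunningCondition for the graph output stage: returns true only when all
// output data and output controls expected for the current data type have arrived for this sequential number.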
bool AnyTypeKernelActor::CheckGraphOutputRunningCondition(const OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "graph output data num:" << graph_output_data_num_[current_data_type_]
                << " control num:" << graph_output_control_num_[current_data_type_];
  if (graph_output_data_num_[current_data_type_] != 0) {
    const auto &data_iter = graph_output_op_data_.find(context->sequential_num_);
    if (data_iter == graph_output_op_data_.end()) {
      return false;
    }
    if (data_iter->second.size() < graph_output_data_num_[current_data_type_]) {
      return false;
    } else if (data_iter->second.size() > graph_output_data_num_[current_data_type_]) {
      MS_LOG(ERROR) << "Invalid graph output data num:" << data_iter->second.size()
                    << " need:" << graph_output_data_num_[current_data_type_] << " for actor:" << GetAID()
                    << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }

  if (graph_output_control_num_[current_data_type_] != 0) {
    const auto &control_iter = graph_output_op_control_.find(context->sequential_num_);
    if (control_iter == graph_output_op_control_.end()) {
      return false;
    }
    if (control_iter->second.size() < graph_output_control_num_[current_data_type_]) {
      return false;
    } else if (control_iter->second.size() > graph_output_control_num_[current_data_type_]) {
      MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size()
                    << " need:" << graph_output_control_num_[current_data_type_] << " for actor:" << GetAID()
                    << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }
  return true;
}
namespace {
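// Pack all CNodes of the graph (except Return) in topological order into a GraphSegment, which is later
// passed to compile_func_ to build the type-specialized real graph.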
GraphSegmentPtr BuildSegmentByGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> nodes;
  std::vector<AnfNodePtr> all_nodes = TopoSort(graph->get_return());
  for (const auto &node : all_nodes) {
    if (node == nullptr || (!node->isa<CNode>()) || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }
    MS_LOG(DEBUG) << "build new segment node:" << node->DebugString();
    nodes.emplace_back(node);
  }
  return std::make_shared<GraphSegment>(nodes, false);
}

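// Generate a string key that encodes, for every any type parameter index, the shape and type of the incoming
// device tensor (or the kind of Python object carried in PyExecute user data). The key identifies the
// concrete input types and is used to cache one compiled real graph per type combination.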
std::string GenerateIDForGraph(const std::vector<DeviceTensor *> &device_tensors, const std::vector<size_t> &indexes) {
  std::string id;
  auto get_shape_and_type_string = [&id](const ShapeVector &shape_vector, TypeId type_id) {
    id += "shape_";
    (void)std::for_each(shape_vector.begin(), shape_vector.end(), [&id](int64_t shape) {
      id += std::to_string(shape);
      id += "_";
    });
    id = id + "type_" + std::to_string(type_id) + "_";
  };
  for (const auto &index : indexes) {
    if (index >= device_tensors.size()) {
      MS_LOG(EXCEPTION) << "Invalid parameter index:" << index << " for device tensor num:" << device_tensors.size();
    }
    id = id + "index_" + std::to_string(index) + "_";
    const auto &device_tensor = device_tensors[index];
    if (device_tensor == nullptr) {
      MS_LOG(EXCEPTION) << "Empty device tensor index:" << index;
    }
    if (device_tensor->user_data() == nullptr) {
      device_tensor->kernel_tensor()->SetType(device_tensor->kernel_tensor()->GetType());
      device_tensor->kernel_tensor()->SetShape(device_tensor->kernel_tensor()->GetShape());
      get_shape_and_type_string(device_tensor->host_shape(), device_tensor->type_id());
      continue;
    }

    const auto &user_data_obj =
      device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
    MS_EXCEPTION_IF_NULL(user_data_obj);
    const auto &obj = user_data_obj->obj;
    py::gil_scoped_acquire gil_acquire;
    const auto &abstract = pyexecute::GenerateAbstractFromPyObject(obj);
    MS_EXCEPTION_IF_NULL(abstract);
    if (abstract->isa<abstract::AbstractSequence>()) {
      auto sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
      MS_EXCEPTION_IF_NULL(sequence_abs);
      id = id + "Tuple_" + std::to_string(sequence_abs->size()) + "_";
    } else if (abstract->isa<abstract::AbstractScalar>()) {
      id = id + "Scalar_";
    } else if (abstract->isa<abstract::AbstractTensor>()) {
      id = id + "Tensor_";
    }
    device_tensor->kernel_tensor()->SetType(abstract->BuildType());
    device_tensor->kernel_tensor()->SetShape(abstract->BuildShape());
    get_shape_and_type_string(device_tensor->host_shape(), device_tensor->type_id());
  }
  return id;
}

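// Refresh the abstract of each any type input parameter of the model graph from the incoming device tensor,
// or from the Python object in its PyExecute user data. Sequence abstracts are marked as dynamic length
// before being attached to the parameter.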
void InferParameterAbstractForModelGraph(const KernelGraphPtr &graph, const std::vector<DeviceTensor *> &device_tensors,
                                         const std::vector<size_t> &indexes) {
  MS_EXCEPTION_IF_NULL(graph);
  for (size_t index : indexes) {
    if (index >= device_tensors.size() || index >= graph->input_nodes().size()) {
      MS_LOG(EXCEPTION) << "Invalid index:" << index << " for input device tensor size:" << device_tensors.size()
                        << " for graph:" << graph->ToString();
    }
    const auto &device_tensor = device_tensors[index];
    MS_EXCEPTION_IF_NULL(device_tensor);
    MS_EXCEPTION_IF_NULL(device_tensor->kernel_tensor());
    auto input_node = graph->input_nodes()[index];
    MS_EXCEPTION_IF_NULL(input_node);
    abstract::AbstractBasePtr abstract;
    if (device_tensor->user_data() != nullptr &&
        device_tensor->user_data()->has(kernel::PyExecuteOutputUserData::key)) {
      MS_LOG(DEBUG) << "User data:" << device_tensor->user_data() << " in device address:" << device_tensor
                    << " for input:" << input_node->DebugString();
      const auto &user_data_obj =
        device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
      MS_EXCEPTION_IF_NULL(user_data_obj);
      const auto &obj = user_data_obj->obj;
      py::gil_scoped_acquire gil_acquire;
      abstract = pyexecute::GenerateAbstractFromPyObject(obj);
    } else {
      abstract =
        abstract::MakeAbstract(device_tensor->kernel_tensor()->GetShape(), device_tensor->kernel_tensor()->GetType());
    }
    MS_EXCEPTION_IF_NULL(abstract);
    MS_LOG(DEBUG) << "Infer parameter by abstract:" << abstract->ToString();
    if (!abstract->isa<abstract::AbstractSequence>()) {
      MS_LOG(DEBUG) << "Set abstract:" << abstract->ToString() << " for input node:" << input_node->DebugString()
                    << " device tensor:" << device_tensor << " type id:" << device_tensor->type_id();
      input_node->set_abstract(abstract);
      continue;
    }
    MS_LOG(DEBUG) << "Sequence abstract:" << abstract->ToString();
    auto new_abstract = abstract->Clone();
    MS_EXCEPTION_IF_NULL(new_abstract);
    auto seq_abstract = new_abstract->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abstract);
    seq_abstract->set_dynamic_len(true);
    // Dynamic len element is used to check if the sequence is dynamic len.
    if (!seq_abstract->elements().empty() && seq_abstract->elements()[0] != nullptr) {
      seq_abstract->set_dynamic_len_element_abs(seq_abstract->elements()[0]->Clone());
    }
    MS_LOG(DEBUG) << "Set abstract:" << seq_abstract->ToString() << " for input node:" << input_node->DebugString()
                  << " device tensor:" << device_tensor << " type id:" << device_tensor->type_id();
    input_node->set_abstract(seq_abstract);
  }
}

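// Recursively fetch the element TypeId from an abstract: the value type for scalars, the element type for
// tensors, and the first element's type for sequences (falling back to int64 for empty or dynamic-length
// sequences).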
TypeId GetElementType(const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  TypePtr type = nullptr;
  if (abstract->isa<abstract::AbstractScalar>()) {
    type = abstract->BuildType();
  } else if (abstract->isa<abstract::AbstractTensor>()) {
    const auto &tensor_abs = abstract->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor_abs);
    MS_EXCEPTION_IF_NULL(tensor_abs->element());
    type = tensor_abs->element()->BuildType();
  } else if (abstract->isa<abstract::AbstractSequence>()) {
    const auto &sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(sequence_abs);
    if (sequence_abs->dynamic_len() || sequence_abs->elements().empty() || sequence_abs->elements()[0] == nullptr) {
      MS_LOG(INFO) << "Invalid abstract:" << abstract->ToString();
      return TypeId::kNumberTypeInt64;
    }
    return GetElementType(sequence_abs->elements()[0]);
  } else {
    MS_LOG(EXCEPTION) << "Invalid abstract:" << abstract->ToString();
  }
  MS_EXCEPTION_IF_NULL(type);
  return type->type_id();
}
}  // namespace

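// For graph inputs that carry PyExecute user data, re-derive the abstract from the Python object and write
// the resulting type and shape back into the kernel tensor of the input device address.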
void AnyTypeKernelActor::UpdataDynamicShapeParameterForGraphInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (graph_input_backend_parameters_.find(current_data_type_) == graph_input_backend_parameters_.end()) {
    return;
  }
  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    if (input_device_tensors_[i] != nullptr && input_device_tensors_[i]->user_data() != nullptr) {
      MS_EXCEPTION_IF_NULL(input_device_tensors_[i]->kernel_tensor());
      const auto &user_data_obj = input_device_tensors_[i]->user_data()->get<kernel::PyExecuteOutputUserData>(
        kernel::PyExecuteOutputUserData::key);
      MS_EXCEPTION_IF_NULL(user_data_obj);
      const auto &obj = user_data_obj->obj;
      auto abstract = pyexecute::GenerateAbstractFromPyObject(obj);
      MS_EXCEPTION_IF_NULL(abstract);
      MS_EXCEPTION_IF_NULL(abstract->BuildType());
      MS_EXCEPTION_IF_NULL(abstract->BuildShape());
      MS_LOG(DEBUG) << "actor:" << GetAID() << " set shape by abstract:" << abstract->ToString()
                    << " shape:" << abstract->BuildShape()->ToString() << " type:" << abstract->BuildType()->ToString()
                    << " for device address:" << input_device_tensors_[i];
      input_device_tensors_[i]->kernel_tensor()->SetType(abstract->BuildType());
      input_device_tensors_[i]->kernel_tensor()->SetShape(abstract->BuildShape());
      MS_LOG(DEBUG) << "Infer abstract:" << abstract->ToString();
    }
  }
}

namespace {
void ClearAttrForGraph(const KernelGraphPtr &graph, const std::string &attr_name) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node_pair : graph->front_backend_anf_map()) {
    MS_EXCEPTION_IF_NULL(node_pair.second);
    if (!node_pair.second->isa<CNode>()) {
      continue;
    }
    MS_LOG(DEBUG) << "Check for node:" << node_pair.second->DebugString() << " attr name:" << attr_name;
    const auto &cnode = node_pair.second->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::HasNodeAttr(attr_name, cnode)) {
      MS_LOG(DEBUG) << "Erase flag for node:" << node_pair.second->DebugString() << " attr name:" << attr_name;
      common::AnfAlgo::EraseNodeAttr(attr_name, cnode);
    }
  }
}
}  // namespace

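// Handle a complete set of graph inputs: generate the type key for the current inputs and, on a cache miss,
// infer the parameter abstracts, compile a new real kernel graph through compile_func_, transform it into
// actors via transform_func_ and schedule them. Afterwards the dynamic shape parameters are updated and
// execution proceeds directly to OnMemoryAllocFinish, since this actor allocates no memory itself.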
void AnyTypeKernelActor::RunForGraphInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  actor_state_ = AnyTypeKernelActorState::kAnyTypeKernelActorSendInput;
  MS_LOG(DEBUG) << "Any type kernel actor:" << GetAID() << " run for graph input.";
  FetchInputDeviceTensor(context);
  current_data_type_ = GenerateIDForGraph(input_device_tensors_, any_type_parameter_indexes_);
  MS_LOG(DEBUG) << "Current data type:" << current_data_type_ << " for actor:" << GetAID();
  std::vector<AbstractActorPtr> actors;
  if (real_graphs_.find(current_data_type_) == real_graphs_.end()) {
    try {
      std::lock_guard<std::mutex> lock(instance_lock_);
      InferParameterAbstractForModelGraph(graph(), input_device_tensors_, any_type_parameter_indexes_);
      ClearAttrForGraph(graph(), kAttrInputIsDynamicShape);
      ClearAttrForGraph(graph(), kAttrOutputIsDynamicShape);
      graph()->InferType();
      const auto &return_node = graph()->get_return();
      MS_EXCEPTION_IF_NULL(return_node);
      if (!return_node->isa<CNode>() || return_node->cast<CNodePtr>()->size() <= 1) {
        MS_LOG_WITH_NODE(EXCEPTION, return_node)
          << "Invalid return node:" << return_node->DebugString() << " for graph:" << graph()->ToString();
      }
      if (device_contexts().empty() || device_contexts()[0] == nullptr) {
        MS_LOG(EXCEPTION) << "Invalid device context for actor:" << GetAID();
      }
      AnfNodePtrList inputs{};
      AnfNodePtrList outputs{return_node->cast<CNodePtr>()->input(1)};
      auto io_nodes = std::make_pair(inputs, outputs);
      auto new_graph =
        compile_func_(BuildSegmentByGraph(graph()), io_nodes, device_contexts()[0], device::RunMode::kKernelMode);
      MS_EXCEPTION_IF_NULL(new_graph);
      MS_LOG(INFO) << "Add new kernel graph:" << new_graph->ToString() << " for graph:" << graph()->ToString();
      real_graphs_[current_data_type_] = new_graph;
      actors = transform_func_(graph(), new_graph, device_contexts()[0]);
      actors_[current_data_type_] = actors;
      schedule_func_(actors);

      for (const auto &node_pair : new_graph->front_backend_anf_map()) {
        MS_EXCEPTION_IF_NULL(node_pair.first);
        if (!node_pair.first->isa<CNode>()) {
          continue;
        }
        MS_LOG(DEBUG) << "Check for node:" << node_pair.first->DebugString();
        const auto &cnode = node_pair.first->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cnode);
        if (cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
          MS_LOG(DEBUG) << "Erase flag for node:" << node_pair.first->DebugString();
          cnode->EraseAttr(kAttrReplaceRealKernelInBackend);
        }
      }
    } catch (const std::exception &e) {
      MsException::Instance().SetException();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), e.what());
    }
  }
  UpdataDynamicShapeParameterForGraphInput(context);
  EraseInput(context);
  if (memory_alloc_list_.size() > 0) {
    MS_LOG(EXCEPTION) << "Any type kernel actor:" << GetAID() << " cannot send memory alloc message.";
  } else {
    OnMemoryAllocFinish(context);
  }
}

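// Map a backend parameter of the real graph back to the position of its front node in the model graph's
// input list.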
size_t FetchInputIndexByBackendParameter(const AnfNodePtr &backend_node, const KernelGraphPtr &front_graph,
                                         const KernelGraphPtr &backend_graph) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_graph);
  MS_EXCEPTION_IF_NULL(backend_graph);
  const auto &front_node = backend_graph->GetFrontAnfByBackendAnf(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  const auto &front_parameters = front_graph->input_nodes();
  const auto &iter = find(front_parameters.begin(), front_parameters.end(), front_node);
  if (iter == front_parameters.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, front_node)
      << "Invalid front parameter:" << front_node->DebugString() << " for graph:" << front_graph->ToString();
  }
  return iter - front_parameters.begin();
}
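
// Bind the parameter device addresses of the compiled real graph into node_device_tensors_, synchronize
// their type and shape from the incoming model graph inputs, copy the input data into the real graph, and
// then send the graph input data and controls to the real graph's actors.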
void AnyTypeKernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph());
  if (real_graphs_.find(current_data_type_) == real_graphs_.end()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << current_data_type_ << " for any type kernel actor:" << GetAID();
  }
  const auto &real_graph = real_graphs_[current_data_type_];
  MS_EXCEPTION_IF_NULL(real_graph);
  if (real_graph->input_nodes().size() != graph()->input_nodes().size()) {
    MS_LOG(EXCEPTION) << "Invalid input node num:" << real_graph->input_nodes().size()
                      << " in graph:" << real_graph->ToString() << " for model graph:" << graph()->ToString()
                      << " input num:" << graph()->input_nodes().size() << " for actor:" << GetAID();
  }
  for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
    const auto &input_node = real_graph->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (HasAbstractMonad(input_node)) {
      continue;
    }
    size_t from_index = FetchInputIndexByBackendParameter(input_node, graph(), real_graph);
    if (!AnfAlgo::OutputAddrExist(input_node, 0, false)) {
      MS_LOG_WITH_NODE(EXCEPTION, input_node)
        << "Input node:" << input_node->DebugString() << " has no device address for actor:" << GetAID();
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    if (from_index >= node_device_tensors_.size() || from_index >= input_device_tensors_.size()) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
                        << " node device tensor size:" << node_device_tensors_.size()
                        << " input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
    }
    node_device_tensors_[from_index] = device_address;
    if (input_device_tensors_[from_index] == nullptr) {
      MS_LOG_WITH_NODE(EXCEPTION, input_node)
        << "actor:" << GetAID() << " real graph:" << real_graph->ToString()
        << " input node:" << input_node->DebugString() << " index:" << i << " is nullptr";
    }
    node_device_tensors_[from_index]->SetNodeIndex(input_device_tensors_[from_index]->node_index().first.lock(),
                                                   input_device_tensors_[from_index]->node_index().second);
    MS_LOG(DEBUG) << "Actor:" << GetAID() << " input " << from_index << ":"
                  << " device address:" << device_address
                  << " original ref count:" << device_address->original_ref_count()
                  << " ref count:" << device_address->ref_count()
                  << " dynamic ref count:" << device_address->dynamic_ref_count()
                  << " real shape:" << node_device_tensors_[from_index]->kernel_tensor()->GetShape()->ToString()
                  << " model shape:" << input_device_tensors_[from_index]->kernel_tensor()->GetShape()->ToString();
  }
  if (node_device_tensors_.size() != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "Invalid device tensor num:" << input_device_tensors_.size() << " and "
                      << node_device_tensors_.size() << " for actor:" << GetAID();
  }
  for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
    if (node_device_tensors_[i] != nullptr && input_device_tensors_[i] != nullptr) {
      MS_EXCEPTION_IF_NULL(input_device_tensors_[i]->kernel_tensor());
      MS_EXCEPTION_IF_NULL(node_device_tensors_[i]->kernel_tensor());
      MS_LOG(DEBUG) << "set shape:"
                    << (input_device_tensors_[i]->kernel_tensor()->GetShape() == nullptr
                          ? "null"
                          : input_device_tensors_[i]->kernel_tensor()->GetShape()->ToString())
                    << " type:"
                    << (input_device_tensors_[i]->kernel_tensor()->GetType() == nullptr
                          ? "null"
                          : input_device_tensors_[i]->kernel_tensor()->GetType()->ToString())
                    << " from device address:" << input_device_tensors_[i]
                    << " to device address:" << node_device_tensors_[i];
      node_device_tensors_[i]->kernel_tensor()->SetType(input_device_tensors_[i]->kernel_tensor()->GetType());
      node_device_tensors_[i]->kernel_tensor()->SetShape(input_device_tensors_[i]->kernel_tensor()->GetShape());
      MS_LOG(DEBUG) << "set shape:" << input_device_tensors_[i]->kernel_tensor()->GetShape()->ToString()
                    << " from device address:" << input_device_tensors_[i]
                    << " to device address:" << node_device_tensors_[i];
    }
  }
  CopyInputData(context, real_graphs_[current_data_type_]);
  if (!memory_free_lists_.empty()) {
    for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
      if (node_device_tensors_[i] != nullptr) {
        memory_free_lists_.back().emplace_back(node_device_tensors_[i].get());
      }
    }
  }
  SendOutput(context);
}

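// Erase the graph output data and controls collected for this sequential number, mirroring
// AbstractActor::EraseInput for the output stage.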
void AnyTypeKernelActor::EraseGraphOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if ((graph_output_data_num_[current_data_type_] != 0) && (!graph_output_op_data_.empty())) {
    auto ret = graph_output_op_data_.erase(context->sequential_num_);
    if (ret == 0) {
      MS_LOG(WARNING) << "Erase graph output data failed: " << GetAID().Name()
                      << ", sequential_num: " << context->sequential_num_;
      return;
    }
  }

  if ((graph_output_control_num_[current_data_type_] != 0) && (!graph_output_op_control_.empty())) {
    auto ret = graph_output_op_control_.erase(context->sequential_num_);
    if (ret == 0) {
      MS_LOG(WARNING) << "Erase graph output controls failed: " << GetAID().Name()
                      << ", sequential_num: " << context->sequential_num_;
      return;
    }
  }
}

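// Handle a complete set of graph outputs: fetch them into the cached output device tensors, erase the
// collected output data, free the input memory, and forward the result through AbstractActor::SendOutput.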
void AnyTypeKernelActor::RunForGraphOutput(OpContext<DeviceTensor> *const context) {
  MS_LOG(DEBUG) << "actor:" << GetAID() << " run for graph output start";
  actor_state_ = AnyTypeKernelActorState::kAnyTypeKernelActorSendOutput;
  FetchGraphOutput(context);
  EraseGraphOutput(context);
  SendMemoryFreeReq(context);
  AbstractActor::SendOutput(context);
}

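// Besides the base-class initialization, record the indexes of parameters with AbstractAny abstracts (these
// decide which inputs participate in the type key) and cache the mutable device addresses of all graph
// outputs.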
void AnyTypeKernelActor::Init() {
  MS_EXCEPTION_IF_NULL(graph());
  MS_LOG(DEBUG) << "actor:" << GetAID() << " init";
  SuperKernelActor::Init();
  memory_alloc_list_.clear();
  for (size_t i = 0; i < graph()->input_nodes().size(); ++i) {
    const auto &input = graph()->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input);
    const auto &abs = input->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<abstract::AbstractAny>()) {
      any_type_parameter_indexes_.emplace_back(i);
      MS_LOG(DEBUG) << "Add any type parameter index:" << i << " by parameter:" << input->DebugString()
                    << " for actor:" << GetAID();
    }
  }
  for (const auto &node_with_index : common::AnfAlgo::GetAllOutputWithOutMonadAndParameter(graph()->output())) {
    MS_EXCEPTION_IF_NULL(node_with_index.first);
    if (!AnfAlgo::OutputAddrExist(node_with_index.first, node_with_index.second)) {
      MS_LOG_WITH_NODE(EXCEPTION, node_with_index.first)
        << "Failed to get output address from node:" << node_with_index.first->DebugString()
        << " index:" << node_with_index.second << " for actor:" << GetAID();
    }
    graph_ouput_device_tensors_.emplace_back(
      AnfAlgo::GetMutableOutputAddr(node_with_index.first, node_with_index.second, false).get());
  }
  fallback_device_tensors_.resize(graph_ouput_device_tensors_.size());
}

namespace {
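// Release the memory held by a device tensor through the resource manager of its owning device context.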
void FreeMemory(DeviceTensor *device_tensor) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {device_tensor->device_name(), device_tensor->device_id()});
  if (device_context == nullptr || device_context->device_res_manager_ == nullptr) {
    return;
  }
  MS_LOG(DEBUG) << "Device tensor:" << device_tensor << " release memory:" << device_tensor->GetMutablePtr();
  device_context->device_res_manager_->FreeMemory(device_tensor->GetMutablePtr());
  device_tensor->set_ptr(nullptr);
}
}  // namespace

void AnyTypeKernelActor::CheckParams(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for any type actor:" + GetAID().Name());
  }
}

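// Move each received real graph output into the cached model graph output device address: transfer the
// pointer, size, shape, type and user data, create a fallback device address first when the device types
// differ, and finally clear the source pointers so the memory is not freed twice.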
void AnyTypeKernelActor::FetchGraphOutput(OpContext<DeviceTensor> *const context) {
  CheckParams(context);
  const auto &data_iter = graph_output_op_data_.find(context->sequential_num_);
  if (data_iter != graph_output_op_data_.end()) {
    std::set<DeviceTensor *> clear_device_tensors;
    for (auto &graph_output_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(graph_output_data);
      MS_EXCEPTION_IF_NULL(graph_output_data->data_);
      size_t index = IntToSize(graph_output_data->index_);
      if (index < graph()->input_nodes().size()) {
        MS_LOG(WARNING) << "Invalid graph output index:" << index << " input num:" << input_datas_num_
                        << " for actor:" << GetAID();
        continue;
      }
      index -= graph()->input_nodes().size();
      if (index >= graph_ouput_device_tensors_.size() ||
          graph_ouput_device_tensors_.size() != fallback_device_tensors_.size()) {
        std::string error_info = "Invalid input index:" + std::to_string(index) +
                                 " total:" + std::to_string(graph_ouput_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_LOG(DEBUG) << "Fetch graph output index:" << index << " set ptr:" << graph_output_data->data_->GetMutablePtr()
                    << " size:" << graph_output_data->data_->GetSize()
                    << " from device address:" << graph_output_data->data_
                    << " to:" << graph_ouput_device_tensors_[index] << " for actor:" << GetAID();
      MS_EXCEPTION_IF_NULL(graph_ouput_device_tensors_[index]);
      if (graph_ouput_device_tensors_[index]->GetDeviceType() != graph_output_data->data_->GetDeviceType()) {
        MS_LOG(INFO) << "Different device type for actor:" << GetAID()
                     << " front device address:" << graph_ouput_device_tensors_[index]
                     << " device type:" << graph_ouput_device_tensors_[index]->GetDeviceType()
                     << " backend device address:" << graph_output_data->data_
                     << " device type:" << graph_output_data->data_->GetDeviceType();
        if (fallback_device_tensors_[index] != nullptr) {
          if (fallback_device_tensors_[index]->GetDeviceType() != graph_output_data->data_->GetDeviceType()) {
            MS_LOG(ERROR) << "Invalid device type for actor:" << GetAID()
                          << " fallback device address:" << fallback_device_tensors_[index]
                          << " device type:" << fallback_device_tensors_[index]->GetDeviceType()
                          << " backend device address:" << graph_output_data->data_
                          << " device type:" << graph_output_data->data_->GetDeviceType();
            SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), GetAID().Name() + " invalid device type.");
          }
        } else {
          auto tmp_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
            {graph_output_data->data_->device_name(), graph_output_data->data_->device_id()});
          MS_EXCEPTION_IF_NULL(tmp_device_context);

          const auto &graph_output_kernel_tensor = graph_output_data->data_->kernel_tensor();
          MS_EXCEPTION_IF_NULL(graph_output_kernel_tensor);
          const auto &fallback_kernel_tensor = graph_output_kernel_tensor->CloneKernelTensor();
          MS_EXCEPTION_IF_NULL(fallback_kernel_tensor);
          fallback_kernel_tensor->set_device_ptr(nullptr);
          fallback_device_tensors_[index] =
            tmp_device_context->device_res_manager_->CreateDeviceAddress(fallback_kernel_tensor);
          MS_EXCEPTION_IF_NULL(fallback_device_tensors_[index]);
          MS_LOG(DEBUG) << "Create device address:" << fallback_device_tensors_[index] << " for actor:" << GetAID()
                        << " index:" << index << " device type:" << fallback_device_tensors_[index]->GetDeviceType()
                        << " size:" << fallback_device_tensors_[index]->GetSize();
          fallback_device_tensors_[index]->set_ref_count(graph_ouput_device_tensors_[index]->ref_count());
          fallback_device_tensors_[index]->set_original_ref_count(
            graph_ouput_device_tensors_[index]->original_ref_count());
          fallback_device_tensors_[index]->set_dynamic_ref_count(
            graph_ouput_device_tensors_[index]->dynamic_ref_count());
        }
        graph_ouput_device_tensors_[index] = fallback_device_tensors_[index].get();
      }
      if (graph_ouput_device_tensors_[index]->GetPtr() != nullptr) {
        // The from-memory-pool flag of an any type kernel graph is false, so its memory cannot be released
        // automatically; release it here before the pointer is overwritten.
        FreeMemory(graph_ouput_device_tensors_[index]);
      }
      graph_ouput_device_tensors_[index]->set_ptr(graph_output_data->data_->GetMutablePtr());
      graph_ouput_device_tensors_[index]->set_need_sync_user_data(graph_output_data->data_->need_sync_user_data());
      clear_device_tensors.emplace(graph_output_data->data_);
      graph_ouput_device_tensors_[index]->SetSize(graph_output_data->data_->GetSize());

      // Update Shape.
      const auto &graph_output_device_kernel_tensor = graph_ouput_device_tensors_[index]->kernel_tensor();
      const auto &graph_output_data_kernel_tensor = graph_output_data->data_->kernel_tensor();
      MS_EXCEPTION_IF_NULL(graph_output_device_kernel_tensor);
      MS_EXCEPTION_IF_NULL(graph_output_data_kernel_tensor);
      MS_LOG(DEBUG) << "actor:" << GetAID() << " set shape from device address:" << graph_output_data->data_
                    << " to:" << graph_ouput_device_tensors_[index]
                    << " for shape:" << graph_output_data_kernel_tensor->GetShape()->ToString();
      graph_output_device_kernel_tensor->SetType(graph_output_data_kernel_tensor->GetType()->Clone());
      graph_output_device_kernel_tensor->SetShape(graph_output_data_kernel_tensor->GetShape()->Clone());

      auto node_with_index = graph_output_data->data_->node_index();
      graph_ouput_device_tensors_[index]->SetNodeIndex(node_with_index.first.lock(), node_with_index.second);
      MS_LOG(DEBUG) << "Actor:" << GetAID() << " src device address:" << graph_output_data->data_
                    << " shape:" << graph_output_data->data_->host_shape()
                    << " type:" << graph_output_data->data_->type_id()
                    << " dst device address:" << graph_ouput_device_tensors_[index]
                    << " shape:" << graph_ouput_device_tensors_[index]->host_shape()
                    << " type:" << graph_ouput_device_tensors_[index]->type_id();
      graph_ouput_device_tensors_[index]->set_type_id(graph_output_data->data_->type_id());
      graph_ouput_device_tensors_[index]->set_host_shape(graph_output_data->data_->host_shape());
      graph_ouput_device_tensors_[index]->set_user_data(graph_output_data->data_->user_data());
    }
    for_each(clear_device_tensors.begin(), clear_device_tensors.end(),
             [](DeviceTensor *device_tensor) { device_tensor->set_ptr(nullptr); });
  }
}

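// Redirect an output data message before it is sent. In the output stage the data is replaced by the cached
// graph output device tensor of the corresponding output node; in the input stage it is replaced by the node
// device tensor found through the backend-to-front parameter mapping of the real graph.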
void AnyTypeKernelActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                                          const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_data);
  MS_EXCEPTION_IF_NULL(data_arrow);
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  if (actor_state_ == AnyTypeKernelActorState::kAnyTypeKernelActorSendOutput) {
    size_t index = IntToSize(data_arrow->from_output_index_);
    const auto &real_output = common::AnfAlgo::GetAllOutputWithOutMonadAndParameter(graph()->output());
    const auto &output_iter = find(real_output.begin(), real_output.end(), std::make_pair(output_node, index));
    if (output_iter == real_output.end()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_node) << "Invalid output node:" << output_node->DebugString()
                                               << " index:" << index << " for graph:" << graph()->ToString();
    }
    size_t real_output_index = LongToSize(output_iter - real_output.begin());
    if (real_output_index >= graph_ouput_device_tensors_.size()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_node)
        << "Invalid input index:" << real_output_index << " by node:" << output_node->DebugString()
        << " for actor:" << GetAID();
    }
    MS_LOG(DEBUG) << "actor:" << GetAID() << " output node:" << output_node->DebugString()
                  << " to actor:" << data_arrow->to_op_id_ << " from index:" << real_output_index;
    MS_EXCEPTION_IF_NULL(graph_ouput_device_tensors_[real_output_index]);
    output_data->data_ = graph_ouput_device_tensors_[real_output_index];
    return;
  }

  const auto &real_graph = real_graphs_[current_data_type_];
  MS_EXCEPTION_IF_NULL(real_graph);
  const auto &front_node = real_graph->GetFrontAnfByBackendAnf(output_node);
  MS_EXCEPTION_IF_NULL(front_node);
  const auto &model_graph = SuperKernelActor::graph();
  MS_EXCEPTION_IF_NULL(model_graph);
  auto &input_nodes = model_graph->input_nodes();
  const auto &iter = find(input_nodes.begin(), input_nodes.end(), front_node);
  if (iter == input_nodes.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, output_node)
      << "Invalid input node:" << output_node->DebugString() << " front node:" << front_node->DebugString();
  }
  size_t index = LongToSize(iter - input_nodes.begin());
  if (index >= node_device_tensors_.size()) {
    MS_LOG_WITH_NODE(EXCEPTION, output_node)
      << "Invalid input index:" << index << " by node:" << output_node->DebugString() << " for actor:" << GetAID();
  }
  if (node_device_tensors_[index] == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get input index:" << index << " for actor:" << GetAID();
  }
  output_data->data_ = node_device_tensors_[index].get();
}

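// Send the graph input data and control arrows of the current data type to the actors of the compiled real
// graph, data first and controls second.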
void AnyTypeKernelActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Any type actor:" << GetAID() << " send output";
  // 1. Send output data. The execution order must be: send data --> send control, to avoid illegal timing problems.
  SendOutputData(context, graph_input_data_nodes_[current_data_type_], graph_input_data_arrows_[current_data_type_],
                 graph_input_data_[current_data_type_], data_arrow_to_graph_input_actor_indexs_[current_data_type_],
                 &batch_graph_input_data_[current_data_type_]);

  // 2. Send output control.
  if (graph_input_control_arrows_[current_data_type_].size() > 0) {
    auto from_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : graph_input_control_arrows_[current_data_type_]) {
      MS_EXCEPTION_IF_NULL(output_control);
      if (TEST_FLAG(output_control->flag_, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(output_control->to_op_id_.Name());
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpControl, from_aid, context);
      } else {
        ActorDispatcher::Send(output_control->to_op_id_, &OpActor::RunOpControl, from_aid, context);
      }
    }
  }
}
}  // namespace runtime
}  // namespace mindspore