/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/any_type_kernel_actor.h"
#include <set>
#include <functional>
#include "include/common/debug/anf_ir_dump.h"
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/fallback.h"
#include "include/common/utils/stub_tensor.h"
#include "include/backend/py_execute_utils.h"

namespace mindspore {
namespace runtime {
namespace {
using AddressPtr = kernel::AddressPtr;
using PyExecuteOutputUserData = kernel::PyExecuteOutputUserData;
} // namespace

std::mutex AnyTypeKernelActor::instance_lock_;

AnyTypeKernelActor::AnyTypeKernelActor(const std::string &name, const KernelGraphPtr &graph,
                                       const DeviceContext *device_context, const AID &memory_manager_aid,
                                       const AID *debug_aid, const AID *recorder_aid, KernelTransformType type)
    : SuperKernelActor(name, graph, device_context, memory_manager_aid, debug_aid, recorder_aid, type) {}

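// Receives an input op data from an upstream actor. Data whose index falls inside the range of the model graph's
// input nodes is buffered as graph input; data beyond that range is treated as graph output coming back from the
// dynamically compiled real graph. Once the matching running condition is satisfied, the corresponding stage runs.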
void AnyTypeKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(input_data);
  MS_EXCEPTION_IF_NULL(input_data->data_);
  MS_EXCEPTION_IF_NULL(input_data->data_->kernel_tensor());
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  auto &sequential_num = context->sequential_num_;
  if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
      !TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed)) {
    MS_LOG(EXCEPTION) << "The input_data does not have a valid ptr of actor:" << GetAID().Name()
                      << " with index:" << input_data->index_ << ", flag:" << input_data->data_->flag()
                      << " device address:" << input_data->data_ << " ref count:" << input_data->data_->ref_count()
                      << " dynamic ref count:" << input_data->data_->dynamic_ref_count()
                      << " origin ref count:" << input_data->data_->original_ref_count();
  }
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data:" << input_data->data_
                << " index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
                << " ptr:" << input_data->data_->GetPtr() << " user data:" << input_data->data_->user_data()
                << " input num:" << input_datas_num_ << " input device tensor size:" << input_device_tensors_.size()
                << " ref count:" << input_data->data_->ref_count()
                << " dynamic ref count:" << input_data->data_->dynamic_ref_count()
                << " origin ref count:" << input_data->data_->original_ref_count()
                << " user data:" << input_data->data_->user_data()
                << " type:" << input_data->data_->kernel_tensor()->GetType()
                << " type id:" << input_data->data_->kernel_tensor()->type_id();
  if (input_data->index_ < SizeToLong(graph()->input_nodes().size())) {
    // Collect graph input data.
    input_op_datas_[sequential_num].emplace_back(input_data);
    if (CheckRunningCondition(context)) {
      MS_LOG(DEBUG) << "Begin wait runtime pipeline to run for graph input for actor: " << GetAID().Name();
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      MS_LOG(DEBUG) << "End wait runtime pipeline to run for graph input for actor: " << GetAID().Name();
      RunForGraphInput(context);
    }
  } else {
    // Collect graph output data.
    graph_output_op_data_[sequential_num].emplace_back(input_data);
    if (CheckGraphOutputRunningCondition(context)) {
      MS_LOG(DEBUG) << "Begin wait runtime pipeline to run for graph output for actor: " << GetAID().Name();
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      MS_LOG(DEBUG) << "End wait runtime pipeline to run for graph output for actor: " << GetAID().Name();
      RunForGraphOutput(context);
    }
  }
}

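// Receives an input op control from an upstream actor. Controls coming from the recorded input control arrows belong
// to the graph-input stage; any other control is taken as a graph-output notification from the real graph side.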
void AnyTypeKernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(input_control);
  auto &sequential_num = context->sequential_num_;
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op control:" << input_control->Name();
  if (std::any_of(
        input_control_arrow_aids_.begin(), input_control_arrow_aids_.end(),
        [input_control](const auto &arrow_pair) { return arrow_pair.first.Name() == input_control->Name(); })) {
    (void)input_op_controls_[sequential_num].emplace_back(input_control);
    if (CheckRunningCondition(context)) {
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      RunForGraphInput(context);
    }
  } else {
    graph_output_op_control_[sequential_num].emplace_back(input_control);
    if (CheckGraphOutputRunningCondition(context)) {
      if (!WaitRuntimePipelineFinish(context)) {
        MS_LOG(INFO) << "Run failed and early stop.";
        return;
      }
      RunForGraphOutput(context);
    }
  }
}

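// Collects the device tensors for the current step: incoming op data are stored into input_device_tensors_ by index
// and queued for memory free, and the device tensors of the device-tensor-store keys (typically weights and value
// nodes) are fetched from the DeviceTensorStore.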
void AnyTypeKernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  std::vector<DeviceTensor *> memory_free_list = graph_ouput_device_tensors_;
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter == input_op_datas_.end()) {
    memory_free_lists_.push(memory_free_list);
    return;
  }
  for (auto &input_data : data_iter->second) {
    MS_EXCEPTION_IF_NULL(input_data);
    MS_EXCEPTION_IF_NULL(input_data->data_);
    size_t index = IntToSize(input_data->index_);
    if (index >= input_device_tensors_.size()) {
      std::string error_info = "Invalid input index:" + std::to_string(index) +
                               " total:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[index] = input_data->data_;
    if (input_data->data_->ref_count() != SIZE_MAX) {
      (void)memory_free_list.emplace_back(input_data->data_);
    }
  }
  memory_free_lists_.push(memory_free_list);

  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                    "Invalid device context for any type actor:" + GetAID().Name());
    }
    auto device_tensor = DeviceTensorStore::GetInstance()
                           .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
                           .get();
    if (device_tensor == nullptr) {
      MS_LOG_WITH_NODE(EXCEPTION, device_tensor_store_key.second)
        << "Failed to get device tensor for node:" << device_tensor_store_key.second->DebugString()
        << " index:" << device_tensor_store_key.first << " device type:" << device_contexts_[0]->GetDeviceType();
      continue;
    }
    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info = "Invalid input index:" + std::to_string(device_tensor_store_key.first) +
                               " total:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
  }
}

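// Checks whether all graph output data and graph output controls expected for the current data type have arrived for
// this step; logs an error when more arrive than expected.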
bool AnyTypeKernelActor::CheckGraphOutputRunningCondition(const OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "graph output data num:" << graph_output_data_num_[current_data_type_]
                << " control num:" << graph_output_control_num_[current_data_type_];
  if (graph_output_data_num_[current_data_type_] != 0) {
    const auto &data_iter = graph_output_op_data_.find(context->sequential_num_);
    if (data_iter == graph_output_op_data_.end()) {
      return false;
    }
    if (data_iter->second.size() < graph_output_data_num_[current_data_type_]) {
      return false;
    } else if (data_iter->second.size() > graph_output_data_num_[current_data_type_]) {
      MS_LOG(ERROR) << "Invalid graph output data num:" << data_iter->second.size()
                    << " need:" << graph_output_data_num_[current_data_type_] << " for actor:" << GetAID()
                    << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }

  if (graph_output_control_num_[current_data_type_] != 0) {
    const auto &control_iter = graph_output_op_control_.find(context->sequential_num_);
    if (control_iter == graph_output_op_control_.end()) {
      return false;
    }
    if (control_iter->second.size() < graph_output_control_num_[current_data_type_]) {
      return false;
    } else if (control_iter->second.size() > graph_output_control_num_[current_data_type_]) {
      MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size()
                    << " need:" << graph_output_control_num_[current_data_type_] << " for actor:" << GetAID()
                    << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }
  return true;
}
namespace {
GraphSegmentPtr BuildSegmentByGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> nodes;
  std::vector<AnfNodePtr> all_nodes = TopoSort(graph->get_return());
  for (const auto &node : all_nodes) {
    if (node == nullptr || (!node->isa<CNode>()) || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }
    MS_LOG(DEBUG) << "build new segment node:" << node->DebugString();
    nodes.emplace_back(node);
  }
  return std::make_shared<GraphSegment>(nodes, false);
}

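// Builds a string key that encodes, for every any-type parameter index, the concrete type and shape carried by the
// incoming device tensor (derived from PyExecute user data when present). The key identifies which compiled real
// kernel graph can be reused for the current inputs.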
std::string GenerateIDForGraph(const std::vector<DeviceTensor *> &device_tensors, const std::vector<size_t> &indexes) {
  std::string id;
  auto get_shape_and_type_string = [&id](const ShapeVector &shape_vector, TypeId type_id) {
    id += "shape_";
    (void)std::for_each(shape_vector.begin(), shape_vector.end(), [&id](int64_t shape) {
      id += std::to_string(shape);
      id += "_";
    });
    id = id + "type_" + std::to_string(type_id) + "_";
  };
  for (const auto &index : indexes) {
    if (index >= device_tensors.size()) {
      MS_LOG(EXCEPTION) << "Invalid parameter index:" << index << " for device tensor num:" << device_tensors.size();
    }
    id = id + "index_" + std::to_string(index) + "_";
    const auto &device_tensor = device_tensors[index];
    if (device_tensor == nullptr) {
      MS_LOG(EXCEPTION) << "Empty device tensor index:" << index;
    }
    if (device_tensor->user_data() == nullptr) {
      device_tensor->kernel_tensor()->SetType(device_tensor->kernel_tensor()->GetType());
      device_tensor->kernel_tensor()->SetShape(device_tensor->kernel_tensor()->GetShape());
      get_shape_and_type_string(device_tensor->host_shape(), device_tensor->type_id());
      continue;
    }

    const auto &user_data_obj =
      device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
    MS_EXCEPTION_IF_NULL(user_data_obj);
    const auto &obj = user_data_obj->obj;
    py::gil_scoped_acquire gil_acquire;
    const auto &abstract = pyexecute::GenerateAbstractFromPyObject(obj);
    MS_EXCEPTION_IF_NULL(abstract);
    if (abstract->isa<abstract::AbstractSequence>()) {
      auto sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
      MS_EXCEPTION_IF_NULL(sequence_abs);
      id = id + "Tuple_" + std::to_string(sequence_abs->size()) + "_";
    } else if (abstract->isa<abstract::AbstractScalar>()) {
      id = id + "Scalar_";
    } else if (abstract->isa<abstract::AbstractTensor>()) {
      id = id + "Tensor_";
    }
    device_tensor->kernel_tensor()->SetType(abstract->BuildType());
    device_tensor->kernel_tensor()->SetShape(abstract->BuildShape());
    get_shape_and_type_string(device_tensor->host_shape(), device_tensor->type_id());
  }
  return id;
}

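// Refreshes the abstracts of the model graph's input parameters at the given indexes from the runtime inputs: the
// abstract is rebuilt from PyExecute user data when available, otherwise from the kernel tensor's shape and type.
// Sequence abstracts are marked as dynamic length before being attached to the parameter.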
void InferParameterAbstractForModelGraph(const KernelGraphPtr &graph, const std::vector<DeviceTensor *> &device_tensors,
                                         const std::vector<size_t> &indexes) {
  MS_EXCEPTION_IF_NULL(graph);
  for (size_t index : indexes) {
    if (index >= device_tensors.size() || index >= graph->input_nodes().size()) {
      MS_LOG(EXCEPTION) << "Invalid index:" << index << " for input device tensor size:" << device_tensors.size()
                        << " for graph:" << graph->ToString();
    }
    const auto &device_tensor = device_tensors[index];
    MS_EXCEPTION_IF_NULL(device_tensor);
    MS_EXCEPTION_IF_NULL(device_tensor->kernel_tensor());
    auto input_node = graph->input_nodes()[index];
    MS_EXCEPTION_IF_NULL(input_node);
    abstract::AbstractBasePtr abstract;
    if (device_tensor->user_data() != nullptr &&
        device_tensor->user_data()->has(kernel::PyExecuteOutputUserData::key)) {
      MS_LOG(DEBUG) << "User data:" << device_tensor->user_data() << " in device address:" << device_tensor
                    << " for input:" << input_node->DebugString();
      const auto &user_data_obj =
        device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
      MS_EXCEPTION_IF_NULL(user_data_obj);
      const auto &obj = user_data_obj->obj;
      py::gil_scoped_acquire gil_acquire;
      abstract = pyexecute::GenerateAbstractFromPyObject(obj);
    } else {
      abstract =
        abstract::MakeAbstract(device_tensor->kernel_tensor()->GetShape(), device_tensor->kernel_tensor()->GetType());
    }
    MS_EXCEPTION_IF_NULL(abstract);
    MS_LOG(DEBUG) << "Infer parameter by abstract:" << abstract->ToString();
    if (!abstract->isa<abstract::AbstractSequence>()) {
      MS_LOG(DEBUG) << "Set abstract:" << abstract->ToString() << " for input node:" << input_node->DebugString()
                    << " device tensor:" << device_tensor << " type id:" << device_tensor->type_id();
      input_node->set_abstract(abstract);
      continue;
    }
    MS_LOG(DEBUG) << "Sequence abstract:" << abstract->ToString();
    auto new_abstract = abstract->Clone();
    MS_EXCEPTION_IF_NULL(new_abstract);
    auto seq_abstract = new_abstract->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abstract);
    seq_abstract->set_dynamic_len(true);
    // Dynamic len element is used to check if the sequence is dynamic len.
    if (!seq_abstract->elements().empty() && seq_abstract->elements()[0] != nullptr) {
      seq_abstract->set_dynamic_len_element_abs(seq_abstract->elements()[0]->Clone());
    }
    MS_LOG(DEBUG) << "Set abstract:" << seq_abstract->ToString() << " for input node:" << input_node->DebugString()
                  << " device tensor:" << device_tensor << " type id:" << device_tensor->type_id();
    input_node->set_abstract(seq_abstract);
  }
}

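// Resolves the element type id of an abstract: the scalar type itself, a tensor's element type, or, for sequences,
// the element type of the first element (falling back to int64 for dynamic-length or empty sequences).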
TypeId GetElementType(const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  TypePtr type = nullptr;
  if (abstract->isa<abstract::AbstractScalar>()) {
    type = abstract->BuildType();
  } else if (abstract->isa<abstract::AbstractTensor>()) {
    const auto &tensor_abs = abstract->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor_abs);
    MS_EXCEPTION_IF_NULL(tensor_abs->element());
    type = tensor_abs->element()->BuildType();
  } else if (abstract->isa<abstract::AbstractSequence>()) {
    const auto &sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(sequence_abs);
    if (sequence_abs->dynamic_len() || sequence_abs->elements().empty() || sequence_abs->elements()[0] == nullptr) {
      MS_LOG(INFO) << "Invalid abstract:" << abstract->ToString();
      return TypeId::kNumberTypeInt64;
    }
    return GetElementType(sequence_abs->elements()[0]);
  } else {
    MS_LOG(EXCEPTION) << "Invalid abstract:" << abstract->ToString();
  }
  MS_EXCEPTION_IF_NULL(type);
  return type->type_id();
}
} // namespace

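// For inputs that carry PyExecute user data, re-derives the abstract from the Python object and writes the resulting
// type and shape back into the input device tensor's kernel tensor, so downstream kernels see the actual dynamic
// shape of the current step.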
void AnyTypeKernelActor::UpdataDynamicShapeParameterForGraphInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (graph_input_backend_parameters_.find(current_data_type_) == graph_input_backend_parameters_.end()) {
    return;
  }
  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    if (input_device_tensors_[i] != nullptr && input_device_tensors_[i]->user_data() != nullptr) {
      MS_EXCEPTION_IF_NULL(input_device_tensors_[i]->kernel_tensor());
      const auto &user_data_obj = input_device_tensors_[i]->user_data()->get<kernel::PyExecuteOutputUserData>(
        kernel::PyExecuteOutputUserData::key);
      MS_EXCEPTION_IF_NULL(user_data_obj);
      const auto &obj = user_data_obj->obj;
      auto abstract = pyexecute::GenerateAbstractFromPyObject(obj);
      MS_EXCEPTION_IF_NULL(abstract);
      MS_EXCEPTION_IF_NULL(abstract->BuildType());
      MS_EXCEPTION_IF_NULL(abstract->BuildShape());
      MS_LOG(DEBUG) << "actor:" << GetAID() << " set shape by abstract:" << abstract->ToString()
                    << " shape:" << abstract->BuildShape()->ToString() << " type:" << abstract->BuildType()->ToString()
                    << " for device address:" << input_device_tensors_[i];
      input_device_tensors_[i]->kernel_tensor()->SetType(abstract->BuildType());
      input_device_tensors_[i]->kernel_tensor()->SetShape(abstract->BuildShape());
      MS_LOG(DEBUG) << "Infer abstract:" << abstract->ToString();
    }
  }
}

namespace {
void ClearAttrForGraph(const KernelGraphPtr &graph, const std::string &attr_name) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node_pair : graph->front_backend_anf_map()) {
    MS_EXCEPTION_IF_NULL(node_pair.second);
    if (!node_pair.second->isa<CNode>()) {
      continue;
    }
    MS_LOG(DEBUG) << "Check for node:" << node_pair.second->DebugString() << " attr name:" << attr_name;
    const auto &cnode = node_pair.second->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::HasNodeAttr(attr_name, cnode)) {
      MS_LOG(DEBUG) << "Erase flag for node:" << node_pair.second->DebugString() << " attr name:" << attr_name;
      common::AnfAlgo::EraseNodeAttr(attr_name, cnode);
    }
  }
}
} // namespace

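// Graph-input stage: fetches the inputs, generates the type/shape key for them and, if no real graph has been
// compiled for that key yet, infers the parameter abstracts, compiles a new kernel graph from the model graph,
// transforms it into actors and schedules them. Finally updates dynamic-shape parameters and triggers output sending.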
void AnyTypeKernelActor::RunForGraphInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  actor_state_ = AnyTypeKernelActorState::kAnyTypeKernelActorSendInput;
  MS_LOG(DEBUG) << "Any type kernel actor:" << GetAID() << " run for graph input.";
  FetchInputDeviceTensor(context);
  current_data_type_ = GenerateIDForGraph(input_device_tensors_, any_type_parameter_indexes_);
  MS_LOG(DEBUG) << "Current data type:" << current_data_type_ << " for actor:" << GetAID();
  vector<AbstractActorPtr> actors;
  if (real_graphs_.find(current_data_type_) == real_graphs_.end()) {
    try {
      std::lock_guard<std::mutex> lock(instance_lock_);
      InferParameterAbstractForModelGraph(graph(), input_device_tensors_, any_type_parameter_indexes_);
      ClearAttrForGraph(graph(), kAttrInputIsDynamicShape);
      ClearAttrForGraph(graph(), kAttrOutputIsDynamicShape);
      graph()->InferType();
      const auto &return_node = graph()->get_return();
      MS_EXCEPTION_IF_NULL(return_node);
      if (!return_node->isa<CNode>() || return_node->cast<CNodePtr>()->size() <= 1) {
        MS_LOG_WITH_NODE(EXCEPTION, return_node)
          << "Invalid return node:" << return_node->DebugString() << " for graph:" << graph()->ToString();
      }
      if (device_contexts().empty() || device_contexts()[0] == nullptr) {
        MS_LOG(EXCEPTION) << "Invalid device context for actor:" << GetAID();
      }
      AnfNodePtrList inputs{};
      AnfNodePtrList outputs{return_node->cast<CNodePtr>()->input(1)};
      auto io_nodes = std::make_pair(inputs, outputs);
      auto new_graph =
        compile_func_(BuildSegmentByGraph(graph()), io_nodes, device_contexts()[0], device::RunMode::kKernelMode);
      MS_EXCEPTION_IF_NULL(new_graph);
      MS_LOG(INFO) << "Add new kernel graph:" << new_graph->ToString() << " for graph:" << graph()->ToString();
      real_graphs_[current_data_type_] = new_graph;
      actors = transform_func_(graph(), new_graph, device_contexts()[0]);
      actors_[current_data_type_] = actors;
      schedule_func_(actors);

      for (const auto &node_pair : new_graph->front_backend_anf_map()) {
        MS_EXCEPTION_IF_NULL(node_pair.first);
        if (!node_pair.first->isa<CNode>()) {
          continue;
        }
        MS_LOG(DEBUG) << "Check for node:" << node_pair.first->DebugString();
        const auto &cnode = node_pair.first->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cnode);
        if (cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
          MS_LOG(DEBUG) << "Erase flag for node:" << node_pair.first->DebugString();
          cnode->EraseAttr(kAttrReplaceRealKernelInBackend);
        }
      }
    } catch (const std::exception &e) {
      MsException::Instance().SetException();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), e.what());
    }
  }
  UpdataDynamicShapeParameterForGraphInput(context);
  EraseInput(context);
  if (memory_alloc_list_.size() > 0) {
    MS_LOG(EXCEPTION) << "Any type kernel actor:" << GetAID() << " cannot send memory alloc message.";
  } else {
    OnMemoryAllocFinish(context);
  }
}

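// Maps a parameter of the compiled real graph back to the index of its corresponding front-end parameter in the
// model graph, which is the input slot the runtime data was delivered to.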
size_t FetchInputIndexByBackendParameter(const AnfNodePtr &backend_node, const KernelGraphPtr &front_graph,
                                         const KernelGraphPtr &backend_graph) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_graph);
  MS_EXCEPTION_IF_NULL(backend_graph);
  const auto &front_node = backend_graph->GetFrontAnfByBackendAnf(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  const auto &front_parameters = front_graph->input_nodes();
  const auto &iter = find(front_parameters.begin(), front_parameters.end(), front_node);
  if (iter == front_parameters.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, front_node)
      << "Invalid front parameter:" << front_node->DebugString() << " for graph:" << front_graph->ToString();
  }
  return iter - front_parameters.begin();
}
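
// Binds the real graph's input parameters to the model graph's runtime inputs: records the parameter device
// addresses in node_device_tensors_, copies type and shape information from the incoming device tensors, copies the
// input data into the real graph, and finally calls SendOutput to pass the inputs on to the real graph's actors.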
void AnyTypeKernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph());
  if (real_graphs_.find(current_data_type_) == real_graphs_.end()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << current_data_type_ << " for any type kernel actor:" << GetAID();
  }
  const auto &real_graph = real_graphs_[current_data_type_];
  MS_EXCEPTION_IF_NULL(real_graph);
  if (real_graph->input_nodes().size() != graph()->input_nodes().size()) {
    MS_LOG(EXCEPTION) << "Invalid input node num:" << real_graph->input_nodes().size()
                      << " in graph:" << real_graph->ToString() << " for model graph:" << graph()->ToString()
                      << " input num:" << graph()->input_nodes().size() << " for actor:" << GetAID();
  }
  for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
    const auto &input_node = real_graph->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (HasAbstractMonad(input_node)) {
      continue;
    }
    size_t from_index = FetchInputIndexByBackendParameter(input_node, graph(), real_graph);
    if (!AnfAlgo::OutputAddrExist(input_node, 0, false)) {
      MS_LOG_WITH_NODE(EXCEPTION, input_node)
        << "Input node:" << input_node->DebugString() << " has no device address for actor:" << GetAID();
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    if (from_index >= node_device_tensors_.size() || from_index >= input_device_tensors_.size()) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
                        << " node device tensor size:" << node_device_tensors_.size()
                        << " input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
    }
    node_device_tensors_[from_index] = device_address;
    if (input_device_tensors_[from_index] == nullptr) {
      MS_LOG_WITH_NODE(EXCEPTION, input_node)
        << "actor:" << GetAID() << " real graph:" << real_graph->ToString()
        << " input node:" << input_node->DebugString() << " index:" << i << " is nullptr";
    }
    node_device_tensors_[from_index]->SetNodeIndex(input_device_tensors_[from_index]->node_index().first.lock(),
                                                   input_device_tensors_[from_index]->node_index().second);
    MS_LOG(DEBUG) << "Actor:" << GetAID() << " input " << from_index << ":"
                  << " device address:" << device_address
                  << " original ref count:" << device_address->original_ref_count()
                  << " ref count:" << device_address->ref_count()
                  << " dynamic ref count:" << device_address->dynamic_ref_count()
                  << " real shape:" << node_device_tensors_[from_index]->kernel_tensor()->GetShape()->ToString()
                  << " model shape:" << input_device_tensors_[from_index]->kernel_tensor()->GetShape()->ToString();
  }
  if (node_device_tensors_.size() != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "Invalid device tensor num:" << input_device_tensors_.size() << " and "
                      << node_device_tensors_.size() << " for actor:" << GetAID();
  }
  for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
    if (node_device_tensors_[i] != nullptr && input_device_tensors_[i] != nullptr) {
      MS_EXCEPTION_IF_NULL(input_device_tensors_[i]->kernel_tensor());
      MS_EXCEPTION_IF_NULL(node_device_tensors_[i]->kernel_tensor());
      MS_LOG(DEBUG) << "set shape:"
                    << (input_device_tensors_[i]->kernel_tensor()->GetShape() == nullptr
                          ? "null"
                          : input_device_tensors_[i]->kernel_tensor()->GetShape()->ToString())
                    << " type:"
                    << (input_device_tensors_[i]->kernel_tensor()->GetType() == nullptr
                          ? "null"
                          : input_device_tensors_[i]->kernel_tensor()->GetType()->ToString())
                    << " from device address:" << input_device_tensors_[i]
                    << " to device address:" << node_device_tensors_[i];
      node_device_tensors_[i]->kernel_tensor()->SetType(input_device_tensors_[i]->kernel_tensor()->GetType());
      node_device_tensors_[i]->kernel_tensor()->SetShape(input_device_tensors_[i]->kernel_tensor()->GetShape());
      MS_LOG(DEBUG) << "set shape:" << input_device_tensors_[i]->kernel_tensor()->GetShape()->ToString()
                    << " from device address:" << input_device_tensors_[i]
                    << " to device address:" << node_device_tensors_[i];
    }
  }
  CopyInputData(context, real_graphs_[current_data_type_]);
  if (!memory_free_lists_.empty()) {
    for (size_t i = 0; i < node_device_tensors_.size(); ++i) {
      if (node_device_tensors_[i] != nullptr) {
        memory_free_lists_.back().emplace_back(node_device_tensors_[i].get());
      }
    }
  }
  SendOutput(context);
}

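// Drops the graph output data and controls buffered for this step once they have been consumed.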
void AnyTypeKernelActor::EraseGraphOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if ((graph_output_data_num_[current_data_type_] != 0) && (!graph_output_op_data_.empty())) {
    auto ret = graph_output_op_data_.erase(context->sequential_num_);
    if (ret == 0) {
      MS_LOG(WARNING) << "Erase graph output data failed: " << GetAID().Name()
                      << ", sequential_num: " << context->sequential_num_;
      return;
    }
  }

  if ((graph_output_control_num_[current_data_type_] != 0) && (!graph_output_op_control_.empty())) {
    auto ret = graph_output_op_control_.erase(context->sequential_num_);
    if (ret == 0) {
      MS_LOG(WARNING) << "Erase graph output controls failed: " << GetAID().Name()
                      << ", sequential_num: " << context->sequential_num_;
      return;
    }
  }
}

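// Graph-output stage: collects the outputs produced by the real graph, releases the buffered step data, frees the
// collected memory and forwards the outputs to the downstream actors of this any-type actor.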
void AnyTypeKernelActor::RunForGraphOutput(OpContext<DeviceTensor> *const context) {
  MS_LOG(DEBUG) << "actor:" << GetAID() << " run for graph output start";
  actor_state_ = AnyTypeKernelActorState::kAnyTypeKernelActorSendOutput;
  FetchGraphOutput(context);
  EraseGraphOutput(context);
  SendMemoryFreeReq(context);
  AbstractActor::SendOutput(context);
}

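// Besides the base SuperKernelActor initialization, records which graph parameters carry an Any abstract (the
// any-type parameter indexes) and caches the device addresses of the model graph's outputs.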
void AnyTypeKernelActor::Init() {
  MS_EXCEPTION_IF_NULL(graph());
  MS_LOG(DEBUG) << "actor:" << GetAID() << " init";
  SuperKernelActor::Init();
  memory_alloc_list_.clear();
  for (size_t i = 0; i < graph()->input_nodes().size(); ++i) {
    const auto &input = graph()->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input);
    const auto &abs = input->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<abstract::AbstractAny>()) {
      any_type_parameter_indexes_.emplace_back(i);
      MS_LOG(DEBUG) << "Add any type parameter index:" << i << " by parameter:" << input->DebugString()
                    << " for actor:" << GetAID();
    }
  }
  for (const auto &node_with_index : common::AnfAlgo::GetAllOutputWithOutMonadAndParameter(graph()->output())) {
    MS_EXCEPTION_IF_NULL(node_with_index.first);
    if (!AnfAlgo::OutputAddrExist(node_with_index.first, node_with_index.second)) {
      MS_LOG_WITH_NODE(EXCEPTION, node_with_index.first)
        << "Failed to get output address from node:" << node_with_index.first->DebugString()
        << " index:" << node_with_index.second << " for actor:" << GetAID();
    }
    graph_ouput_device_tensors_.emplace_back(
      AnfAlgo::GetMutableOutputAddr(node_with_index.first, node_with_index.second, false).get());
  }
  fallback_device_tensors_.resize(graph_ouput_device_tensors_.size());
}

namespace {
void FreeMemory(DeviceTensor *device_tensor) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {device_tensor->device_name(), device_tensor->device_id()});
  if (device_context == nullptr || device_context->device_res_manager_ == nullptr) {
    return;
  }
  MS_LOG(DEBUG) << "Device tensor:" << device_tensor << " release memory:" << device_tensor->GetMutablePtr();
  device_context->device_res_manager_->FreeMemory(device_tensor->GetMutablePtr());
  device_tensor->set_ptr(nullptr);
}
} // namespace

void AnyTypeKernelActor::CheckParams(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for any type actor:" + GetAID().Name());
  }
}

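// Moves the outputs of the real graph into the cached model-graph output device tensors: takes over the memory
// pointer, size, shape, type and user data of each backend output. When a backend output lives on a different
// device type, a fallback device address of the matching type is created and used instead.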
void AnyTypeKernelActor::FetchGraphOutput(OpContext<DeviceTensor> *const context) {
  CheckParams(context);
  const auto &data_iter = graph_output_op_data_.find(context->sequential_num_);
  if (data_iter != graph_output_op_data_.end()) {
    std::set<DeviceTensor *> clear_device_tensors;
    for (auto &graph_output_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(graph_output_data);
      MS_EXCEPTION_IF_NULL(graph_output_data->data_);
      size_t index = IntToSize(graph_output_data->index_);
      if (index < graph()->input_nodes().size()) {
        MS_LOG(WARNING) << "Invalid graph output index:" << index << " input num:" << input_datas_num_
                        << " for actor:" << GetAID();
        continue;
      }
      index -= graph()->input_nodes().size();
      if (index >= graph_ouput_device_tensors_.size() ||
          graph_ouput_device_tensors_.size() != fallback_device_tensors_.size()) {
        std::string error_info = "Invalid input index:" + std::to_string(index) +
                                 " total:" + std::to_string(graph_ouput_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_LOG(DEBUG) << "Fetch graph output index:" << index << " set ptr:" << graph_output_data->data_->GetMutablePtr()
                    << " size:" << graph_output_data->data_->GetSize()
                    << " from device address:" << graph_output_data->data_
                    << " to:" << graph_ouput_device_tensors_[index] << " for actor:" << GetAID();
      MS_EXCEPTION_IF_NULL(graph_ouput_device_tensors_[index]);
      if (graph_ouput_device_tensors_[index]->GetDeviceType() != graph_output_data->data_->GetDeviceType()) {
        MS_LOG(INFO) << "Different device type for actor:" << GetAID()
                     << " front device address:" << graph_ouput_device_tensors_[index]
                     << " device type:" << graph_ouput_device_tensors_[index]->GetDeviceType()
                     << " backend device address:" << graph_output_data->data_
                     << " device type:" << graph_output_data->data_->GetDeviceType();
        if (fallback_device_tensors_[index] != nullptr) {
          if (fallback_device_tensors_[index]->GetDeviceType() != graph_output_data->data_->GetDeviceType()) {
            MS_LOG(ERROR) << "Invalid device type for actor:" << GetAID()
                          << " fallback device address:" << fallback_device_tensors_[index]
                          << " device type:" << fallback_device_tensors_[index]->GetDeviceType()
                          << " backend device address:" << graph_output_data->data_
                          << " device type:" << graph_output_data->data_->GetDeviceType();
            SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), GetAID().Name() + " invalid device type.");
          }
        } else {
          auto tmp_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
            {graph_output_data->data_->device_name(), graph_output_data->data_->device_id()});
          MS_EXCEPTION_IF_NULL(tmp_device_context);

          const auto &graph_output_kernel_tensor = graph_output_data->data_->kernel_tensor();
          MS_EXCEPTION_IF_NULL(graph_output_kernel_tensor);
          const auto &fallback_kernel_tensor = graph_output_kernel_tensor->CloneKernelTensor();
          MS_EXCEPTION_IF_NULL(fallback_kernel_tensor);
          fallback_kernel_tensor->set_device_ptr(nullptr);
          fallback_device_tensors_[index] =
            tmp_device_context->device_res_manager_->CreateDeviceAddress(fallback_kernel_tensor);
          MS_EXCEPTION_IF_NULL(fallback_device_tensors_[index]);
          MS_LOG(DEBUG) << "Create device address:" << fallback_device_tensors_[index] << " for actor:" << GetAID()
                        << " index:" << index << " device type:" << fallback_device_tensors_[index]->GetDeviceType()
                        << " size:" << fallback_device_tensors_[index]->GetSize();
          fallback_device_tensors_[index]->set_ref_count(graph_ouput_device_tensors_[index]->ref_count());
          fallback_device_tensors_[index]->set_original_ref_count(
            graph_ouput_device_tensors_[index]->original_ref_count());
          fallback_device_tensors_[index]->set_dynamic_ref_count(
            graph_ouput_device_tensors_[index]->dynamic_ref_count());
        }
        graph_ouput_device_tensors_[index] = fallback_device_tensors_[index].get();
      }
      if (graph_ouput_device_tensors_[index]->GetPtr() != nullptr) {
        // The from-memory-pool flag of an any type kernel graph is false, so this memory is not released
        // automatically; release it here before the pointer is overwritten.
        FreeMemory(graph_ouput_device_tensors_[index]);
      }
      graph_ouput_device_tensors_[index]->set_ptr(graph_output_data->data_->GetMutablePtr());
      graph_ouput_device_tensors_[index]->set_need_sync_user_data(graph_output_data->data_->need_sync_user_data());
      clear_device_tensors.emplace(graph_output_data->data_);
      graph_ouput_device_tensors_[index]->SetSize(graph_output_data->data_->GetSize());

      // Update Shape.
      const auto &graph_output_device_kernel_tensor = graph_ouput_device_tensors_[index]->kernel_tensor();
      const auto &graph_output_data_kernel_tensor = graph_output_data->data_->kernel_tensor();
      MS_EXCEPTION_IF_NULL(graph_output_device_kernel_tensor);
      MS_EXCEPTION_IF_NULL(graph_output_data_kernel_tensor);
      MS_LOG(DEBUG) << "actor:" << GetAID() << " set shape from device address:" << graph_output_data->data_
                    << " to:" << graph_ouput_device_tensors_[index]
                    << " for shape:" << graph_output_data_kernel_tensor->GetShape()->ToString();
      graph_output_device_kernel_tensor->SetType(graph_output_data_kernel_tensor->GetType()->Clone());
      graph_output_device_kernel_tensor->SetShape(graph_output_data_kernel_tensor->GetShape()->Clone());

      auto node_with_index = graph_output_data->data_->node_index();
      graph_ouput_device_tensors_[index]->SetNodeIndex(node_with_index.first.lock(), node_with_index.second);
      MS_LOG(DEBUG) << "Actor:" << GetAID() << " src device address:" << graph_output_data->data_
                    << " shape:" << graph_output_data->data_->host_shape()
                    << " type:" << graph_output_data->data_->type_id()
                    << " dst device address:" << graph_ouput_device_tensors_[index]
                    << " shape:" << graph_ouput_device_tensors_[index]->host_shape()
                    << " type:" << graph_ouput_device_tensors_[index]->type_id();
      graph_ouput_device_tensors_[index]->set_type_id(graph_output_data->data_->type_id());
      graph_ouput_device_tensors_[index]->set_host_shape(graph_output_data->data_->host_shape());
      graph_ouput_device_tensors_[index]->set_user_data(graph_output_data->data_->user_data());
    }
    for_each(clear_device_tensors.begin(), clear_device_tensors.end(),
             [](DeviceTensor *device_tensor) { device_tensor->set_ptr(nullptr); });
  }
}

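// Chooses the device tensor attached to an outgoing data arrow. In the send-output stage the data comes from the
// cached model-graph output device tensors; in the send-input stage the output node is a parameter of the real graph
// and the data comes from the corresponding node device tensor of the model graph.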
void AnyTypeKernelActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                                          const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_data);
  MS_EXCEPTION_IF_NULL(data_arrow);
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph());
  if (actor_state_ == AnyTypeKernelActorState::kAnyTypeKernelActorSendOutput) {
    size_t index = IntToSize(data_arrow->from_output_index_);
    const auto &real_output = common::AnfAlgo::GetAllOutputWithOutMonadAndParameter(graph()->output());
    const auto &output_iter = find(real_output.begin(), real_output.end(), std::make_pair(output_node, index));
    if (output_iter == real_output.end()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_node) << "Invalid output node:" << output_node->DebugString()
                                               << " index:" << index << " for graph:" << graph()->ToString();
    }
    size_t real_output_index = LongToSize(output_iter - real_output.begin());
    if (real_output_index >= graph_ouput_device_tensors_.size()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_node)
        << "Invalid input index:" << real_output_index << " by node:" << output_node->DebugString()
        << " for actor:" << GetAID();
    }
    MS_LOG(DEBUG) << "actor:" << GetAID() << " output node:" << output_node->DebugString()
                  << " to actor:" << data_arrow->to_op_id_ << " from index:" << real_output_index;
    MS_EXCEPTION_IF_NULL(graph_ouput_device_tensors_[real_output_index]);
    output_data->data_ = graph_ouput_device_tensors_[real_output_index];
    return;
  }

  const auto &real_graph = real_graphs_[current_data_type_];
  MS_EXCEPTION_IF_NULL(real_graph);
  const auto &front_node = real_graph->GetFrontAnfByBackendAnf(output_node);
  MS_EXCEPTION_IF_NULL(front_node);
  const auto &model_graph = SuperKernelActor::graph();
  MS_EXCEPTION_IF_NULL(model_graph);
  auto &input_nodes = model_graph->input_nodes();
  const auto &iter = find(input_nodes.begin(), input_nodes.end(), front_node);
  if (iter == input_nodes.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, output_node)
      << "Invalid input node:" << output_node->DebugString() << " front node:" << front_node->DebugString();
  }
  size_t index = LongToSize(iter - input_nodes.begin());
  if (index >= node_device_tensors_.size()) {
    MS_LOG_WITH_NODE(EXCEPTION, output_node)
      << "Invalid input index:" << index << " by node:" << output_node->DebugString() << " for actor:" << GetAID();
  }
  if (node_device_tensors_[index] == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get input index:" << index << " for actor:" << GetAID();
  }
  output_data->data_ = node_device_tensors_[index].get();
}

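// Send-input stage output: forwards the prepared graph input data and controls of the current data type to the
// actors of the compiled real graph.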
void AnyTypeKernelActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Any type actor:" << GetAID() << " send output";
  // Must keep the execution order: send data --> send control, to avoid illegal timing problems.
  // 1. Send output data.
  SendOutputData(context, graph_input_data_nodes_[current_data_type_], graph_input_data_arrows_[current_data_type_],
                 graph_input_data_[current_data_type_], data_arrow_to_graph_input_actor_indexs_[current_data_type_],
                 &batch_graph_input_data_[current_data_type_]);

  // 2. Send output control.
  if (graph_input_control_arrows_[current_data_type_].size() > 0) {
    auto from_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : graph_input_control_arrows_[current_data_type_]) {
      MS_EXCEPTION_IF_NULL(output_control);
      if (TEST_FLAG(output_control->flag_, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(output_control->to_op_id_.Name());
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpControl, from_aid, context);
      } else {
        ActorDispatcher::Send(output_control->to_op_id_, &OpActor::RunOpControl, from_aid, context);
      }
    }
  }
}
} // namespace runtime
} // namespace mindspore