/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/control_flow/control_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "ops/framework_ops.h"
#include "utils/profile.h"

namespace mindspore {
namespace runtime {
ControlActor::ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
                           const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
    : MemoryAwareActor(name, type, nullptr, memory_manager_aid), formal_parameters_(parameters), node_(node) {
  input_partials_.resize(parameters.size());
  input_device_tensors_.resize(parameters.size());
  backend_parameters_.resize(parameters.size());
  output_data_by_output_index_.resize(parameters.size());
}

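// Initialize the output data and build the mapping from each output index to its output data,
// which is used to refresh the output data after the inputs are fetched.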
void ControlActor::Init() {
  InitOutputData();
  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
  }

  size_t output_data_index = 0;
  for (auto &data_arrow : output_data_arrows_) {
    auto data = output_data_[output_data_index].first.get();
    MS_EXCEPTION_IF_NULL(data);
    MS_EXCEPTION_IF_NULL(data_arrow);
    if (IntToSize(data_arrow->from_output_index_) >= output_data_by_output_index_.size()) {
      MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID();
    }
    (void)output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(data);
    ++output_data_index;
  }
}

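// Collect all device tensors held by the op partial, including those held by its nested partials.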
void ControlActor::GetAllDeviceTensors(const OpPartialPtr &op_partial, std::vector<DeviceTensor *> *device_tensors) {
  MS_EXCEPTION_IF_NULL(op_partial);
  (void)std::transform(op_partial->device_tensors_.begin(), op_partial->device_tensors_.end(),
                       std::back_inserter(*device_tensors),
                       [](const auto &device_tensor) { return device_tensor.second; });

  // Traverse the nested op partials recursively to fetch their device tensors.
  for (auto &partial : op_partial->partials_) {
    GetAllDeviceTensors(partial.second, device_tensors);
  }
}

void ControlActor::GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter,
                                       std::vector<DeviceTensor *> *device_tensors) {
  MS_EXCEPTION_IF_NULL(device_tensors);
  for (auto &device_tensor : op_real_parameter.device_tensors_) {
    (void)device_tensors->emplace_back(device_tensor.second);
  }

  // Traverse the nested op partials recursively to fetch their device tensors.
  for (auto &partial : op_real_parameter.partials_) {
    GetAllDeviceTensors(partial.second, device_tensors);
  }
}

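// Increase the dynamic ref count of the device tensors held by the output data, the op partial and the real
// parameter with branch id respectively.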
void ControlActor::IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) const {
  MS_EXCEPTION_IF_NULL(op_data);
  MS_EXCEPTION_IF_NULL(op_data->data_);
  op_data->data_->IncreaseDynamicRefCount(GetAID().Name());
}

void ControlActor::IncreaseDynamicRefCount(const OpPartialPtr &op_partial) {
  if (op_partial == nullptr) {
    MS_LOG(EXCEPTION) << "Empty op partial for actor:" << GetAID();
  }
  std::vector<DeviceTensor *> partial_device_tensors;
  GetAllDeviceTensors(op_partial, &partial_device_tensors);
  for (auto &partial_device_tensor : partial_device_tensors) {
    MS_EXCEPTION_IF_NULL(partial_device_tensor);
    partial_device_tensor->IncreaseDynamicRefCount(GetAID().Name());
  }
}

void ControlActor::IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter) {
  std::vector<DeviceTensor *> partial_device_tensors;
  GetAllDeviceTensors(op_real_parameter, &partial_device_tensors);
  for (auto &partial_device_tensor : partial_device_tensors) {
    MS_EXCEPTION_IF_NULL(partial_device_tensor);
    partial_device_tensor->IncreaseDynamicRefCount(GetAID().Name());
  }
}

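// Fetch the position of the formal parameter that matches the input node. If there is no exact match, try to match
// the first input of a Load formal parameter before reporting an error.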
size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
  const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
  if (iter == formal_parameters_.end()) {
    const auto &load_iter =
      std::find_if(formal_parameters_.begin(), formal_parameters_.end(), [&node](const KernelWithIndex &pair) {
        return pair.first != nullptr && common::AnfAlgo::CheckPrimitiveType(pair.first, prim::kPrimLoad) &&
               pair.first->cast<CNodePtr>()->input(1) == node.first && node.second == 0;
      });
    if (load_iter != formal_parameters_.end()) {
      return load_iter - formal_parameters_.begin();
    }
    for (const auto &formal_parameter : formal_parameters_) {
      MS_LOG(WARNING) << "Actor:" << GetAID() << " formal parameter:"
                      << (formal_parameter.first != nullptr ? formal_parameter.first->DebugString() : "")
                      << " index:" << formal_parameter.second << " node ptr:" << formal_parameter.first;
    }
    MS_LOG_WITH_NODE(EXCEPTION, node.first)
      << "Invalid formal parameter:" << (node.first != nullptr ? node.first->DebugString() : "")
      << " node ptr:" << node.first << " index:" << node.second << " for actor:" << GetAID();
  }
  return iter - formal_parameters_.begin();
}

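// The main running process of the control actor: fetch the inputs, increase the dynamic ref counts, free the input
// memory, erase the inputs and send the outputs.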
void ControlActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    // When node_ is null, the exit actor is the output of a kernel graph.
    if (type_ == KernelTransformType::kExitActor && node_ == nullptr) {
      double end_time = GetTime();
      const size_t kSecondsToMilliseconds = 1000;
      MS_LOG(DEBUG) << "Kernel graph group exit actor:" << GetAID()
                    << " cost time:" << (end_time - start_time_) * kSecondsToMilliseconds;
    }

    FetchInput(context);
    if (IsRunningFailed(context)) {
      MS_LOG(INFO) << "Run failed and early stop.";
      return;
    }

    // Note that IncreaseDynamicRefCounts must be called before SendMemoryFreeReq, because SendMemoryFreeReq
    // decreases the dynamic ref count. This avoids the illegal timing problem where the dynamic reference count is
    // decremented before it is incremented.
    IncreaseDynamicRefCounts(context);
    SendMemoryFreeReq(context);

    EraseInput(context);
    SendOutput(context);
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Actor run failed:" + GetAID().Name();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}

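// Receive an input op partial and run when the running condition is satisfied.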
void ControlActor::RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)input_op_partials_[sequential_num].emplace_back(position, partial);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op partial and check running condition:" << is_run;
  if (is_run) {
    Run(context);
  }
}

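// Receive an input branch id and run when the running condition is satisfied.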
void ControlActor::RunBranchID(int branch_id, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  input_branch_ids_[sequential_num].push(branch_id);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input branch id and check running condition:" << is_run;
  if (is_run) {
    Run(context);
  }
}

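// Besides the base class condition, the control actor also requires that all the input op partials have arrived.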
bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  if (!AbstractActor::CheckRunningCondition(context)) {
    return false;
  }

  if (input_partials_num_ != 0) {
    const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
    if (partial_iter == input_op_partials_.end()) {
      return false;
    }
    if (partial_iter->second.size() < input_partials_num_) {
      return false;
    } else if (partial_iter->second.size() > input_partials_num_) {
      MS_LOG(ERROR) << "Invalid input partial num:" << partial_iter->second.size() << " need:" << input_partials_num_
                    << " for actor:" << GetAID();
      return false;
    }
  }
  return true;
}

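// Fetch the input device tensors from the input data, the local device tensors and the device tensor store, fetch
// the input partials from the input data and the local partials, and fetch the branch id from the input stack.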
void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);

  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "Invalid index, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  }

  // Fetch input device tensor from local device tensor.
  for (auto &local_device_tensor : local_device_tensors_) {
    MS_EXCEPTION_IF_NULL(local_device_tensor.second);
    if (local_device_tensor.first >= input_device_tensors_.size()) {
      std::string error_info = "Invalid local index:" + std::to_string(local_device_tensor.first) +
                               " current:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[local_device_tensor.first] = local_device_tensor.second;
  }

  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
    if (device_tensors.empty()) {
      auto &device_context = device_contexts_[device_tensor_store_key.first];
      MS_EXCEPTION_IF_NULL(device_context);
      MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
      std::string error_info = GetAID().Name() +
                               " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
                               ", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info =
        "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
        " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(device_tensors[0]);
    input_device_tensors_[device_tensor_store_key.first] = device_tensors[0].get();
  }

  // Refresh the output data with the fetched input device tensors by output index.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i];
    MS_EXCEPTION_IF_NULL(data);
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }

  // Fetch input partial from input data.
  const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
  if (partial_iter != input_op_partials_.end()) {
    for (const auto &input_partial : partial_iter->second) {
      if (input_partial.first >= input_partials_.size()) {
        std::string error_info = "Invalid partial index:" + std::to_string(input_partial.first) +
                                 " vector size:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_partials_[input_partial.first] = input_partial.second;
    }
  }
  // Fetch input partial from local partial.
  for (const auto &local_partial : local_partials_) {
    if (local_partial.first >= input_partials_.size()) {
      std::string error_info = "Invalid partial index:" + std::to_string(local_partial.first) +
                               " vector size:" + std::to_string(input_partials_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(local_partial.second);
    input_partials_[local_partial.first] = local_partial.second;
  }
  // Fetch branch id in stack.
  auto iter = input_branch_ids_.find(context->sequential_num_);
  if (iter != input_branch_ids_.end() && (!iter->second.empty())) {
    output_branch_id_ = iter->second.top();
  }
}

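// Increase the dynamic ref counts of the device tensors referenced by the output data and output partials, so that
// the following memory free request does not release them before they are sent.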
void ControlActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  // Increase dynamic ref count by the output data.
  for (size_t i = 0; i < output_data_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_data_[i].first);
    if (output_data_[i].first->data_ == nullptr) {
      std::string error_info = GetAID().Name() + " fetches data null, data index:" + std::to_string(i) +
                               " to actor:" + output_data_[i].first->op_id_.Name() +
                               " index:" + std::to_string(output_data_[i].first->index_);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    IncreaseDynamicRefCount(output_data_[i].first.get());
  }

  // Increase dynamic ref count by the output partial.
  for (const auto &output_partial_arrow : output_partial_arrows_) {
    MS_EXCEPTION_IF_NULL(output_partial_arrow);
    if (IntToSize(output_partial_arrow->from_output_index_) >= input_partials_.size()) {
      std::string error_info = "Invalid partial input:" + std::to_string(output_partial_arrow->from_output_index_) +
                               " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    auto output_partial = input_partials_[IntToSize(output_partial_arrow->from_output_index_)];
    IncreaseDynamicRefCount(output_partial);
  }
}

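// Collect the device tensors of the input data and input partials, then send the memory free request to the memory
// manager actor.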
void ControlActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Collect the input device tensors.
  std::vector<DeviceTensor *> memory_free_list;
  if (input_op_datas_.count(sequential_num) > 0) {
    for (auto &input_op_data : input_op_datas_[sequential_num]) {
      MS_EXCEPTION_IF_NULL(input_op_data);
      MS_EXCEPTION_IF_NULL(input_op_data->data_);
      (void)memory_free_list.emplace_back(input_op_data->data_);
    }
  }

  if (input_op_partials_.count(sequential_num) > 0) {
    for (auto &input_op_partial : input_op_partials_[sequential_num]) {
      GetAllDeviceTensors(input_op_partial.second, &memory_free_list);
    }
  }

  if (memory_free_list.size() > 0) {
    memory_free_lists_.push(memory_free_list);
    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }
}

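// Erase the inputs of the current step: the base class inputs, the input op partials and the input branch ids.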
void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;
  AbstractActor::EraseInput(context);
  if (input_partials_num_ != 0) {
    auto ret = input_op_partials_.erase(sequential_num);
    if (ret == 0) {
      std::string error_info = "Erase input partial failed: " + GetAID().Name();
      // The sequential num may be invalid, can't set the promise value of context.
      MS_LOG(ERROR) << error_info << ", sequential_num: " << sequential_num;
    }
  }

  if (input_branch_ids_.find(sequential_num) != input_branch_ids_.end()) {
    input_branch_ids_[sequential_num].pop();
    if (input_branch_ids_[sequential_num].empty()) {
      auto ret = input_branch_ids_.erase(sequential_num);
      if (ret == 0) {
        MS_LOG(ERROR) << "Erase input branch id failed: " << GetAID() << ", sequential_num: " << sequential_num;
        return;
      }
    }
  }
}

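// Update the output data of a ref node formal parameter before sending: copy the real parameter into the formal
// parameter device tensors, or directly set their device ptr from the real parameter.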
void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                                    const AnfNodePtr &, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_data);
  MS_EXCEPTION_IF_NULL(data_arrow);
  auto formal_parameter_position = data_arrow->from_output_index_;
  // There is no ref node formal parameter at this position.
  if (ref_node_formal_parameter_device_tensors_.count(formal_parameter_position) == 0) {
    return;
  }

  MS_EXCEPTION_IF_NULL(context);
  const auto &data = output_data->data_;
  MS_EXCEPTION_IF_NULL(data);
  if ((!data->IsPtrValid()) || (data->ref_count() != SIZE_MAX) || (data->dynamic_ref_count() != INT32_MAX)) {
    std::string error_info = "The address of the " + std::to_string(formal_parameter_position) +
                             " position real parameter is nullptr or ref count is wrong.";
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // Traverse the device tensors to set the ptr from data. Only the formal parameter device tensors of ref nodes
  // need to be set before kernel running, because they will be used by the ref output nodes.
  for (auto &device_tensor : ref_node_formal_parameter_device_tensors_[formal_parameter_position]) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
      continue;
    }
    auto formal_parameter = device_tensor->GetNodeIndex();
    MS_EXCEPTION_IF_NULL(formal_parameter.first);
    if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->type_id() != data->type_id())) {
      MS_LOG(WARNING) << "The formal parameter: " << formal_parameter.first->DebugString()
                      << " position:" << formal_parameter_position
                      << " please check the size and type id, formal parameter size:" << device_tensor->GetSize()
                      << " type id:" << device_tensor->type_id() << ", real parameter size:" << data->GetSize()
                      << " type id:" << data->type_id();
    }

    // Copy from the real parameter to the formal parameter and insert into the device tensor copy store.
    if ((!AnfAlgo::IsEquivalentFormat(device_tensor->format(), data->format())) ||
        (device_tensor->GetDeviceType() != data->GetDeviceType())) {
      MS_LOG(INFO) << GetAID().Name() << " the input position:" << formal_parameter_position
                   << " copy from real parameter address:" << data << ", type:" << data->GetDeviceType()
                   << ", format:" << data->format() << " to formal parameter address:" << device_tensor.get()
                   << ", type:" << device_tensor->GetDeviceType() << ", format:" << device_tensor->format()
                   << ", formal parameter name:" << formal_parameter.first->DebugString();
      const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device_tensor->device_name(), device_tensor->device_id()});
      MS_EXCEPTION_IF_NULL(device_context);
      device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther, 0);
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "UpdateOutputData", "");
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(), device::tracker::MemType::kOther,
                                                     device_tensor->GetSize(), device_tensor.get());
      auto data_stream_id = data->stream_id();
      auto device_tensor_stream_id = device_tensor->stream_id();
      if (device_tensor_stream_id != data_stream_id) {
        MS_LOG(INFO) << "Rewrite device tensor stream id from : " << device_tensor_stream_id
                     << " to data stream id : " << data_stream_id << ".";
        device_tensor->set_stream_id(data_stream_id);
      }
      if ((device_tensor->GetPtr() == nullptr) &&
          (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex))) {
        SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                    formal_parameter.first->DebugString(), device_tensor->GetSize());
      }
      if (!Copy(device_tensor.get(), data)) {
        std::string error_info = "The formal parameter: " + formal_parameter.first->DebugString() +
                                 " position:" + std::to_string(formal_parameter_position) + " copy failed.";
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
      output_data->data_ = device_tensor.get();
      continue;
    }

    // The ref node may use the ptr of the device tensor as the output address, so the ptr needs to be set from data.
    device_tensor->set_ptr(data->GetMutablePtr());
    MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
                  << " for the ref formal parameter: " << formal_parameter.first->DebugString()
                  << " in the actor: " << GetAID().Name();
  }
}

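// Send the branch ids, the output data (in the base class) and the output partials to the downstream actors, then
// refresh the start time of the end actors.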
void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
  // Send branch id.
  for (const auto &branch_id_arrow : output_branch_id_arrows_) {
    ActorDispatcher::Send(branch_id_arrow, &ControlActor::RunBranchID, output_branch_id_, context);
  }

  // Send data in base class.
  AbstractActor::SendOutput(context);

  // Send Partial.
  for (const auto &partial_arrow : output_partial_arrows_) {
    MS_EXCEPTION_IF_NULL(partial_arrow);
    if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
      std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) +
                               " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    auto output_partial = input_partials_[IntToSize(partial_arrow->from_output_index_)];
    MS_EXCEPTION_IF_NULL(output_partial);
    ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
                          IntToSize(partial_arrow->to_input_index_), context);
  }

  // Update the start time in end actor.
  for (const auto &actor : end_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    actor->set_start_time(GetTime());
  }
}
namespace {
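// Create a RealMakeTuple node whose abstract and output formats are built from the device address list; it is used
// as the owner node of the merged device address.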
CNodePtr CreateRealMakeTuple(const std::vector<DeviceTensor *> &addr_list, const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimRealMakeTuple)};
  auto new_cnode = func_graph->NewCNode(inputs);
  std::vector<std::string> formats;
  MS_EXCEPTION_IF_NULL(new_cnode);
  std::vector<abstract::AbstractBasePtr> abs_list;
  for (const auto &addr : addr_list) {
    MS_EXCEPTION_IF_NULL(addr);
    auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(addr->type_id()), addr->host_shape());
    abs_list.emplace_back(abs);
    formats.emplace_back(addr->format());
    MS_LOG(DEBUG) << "Create new abstract:" << abs->ToString();
  }
  auto tuple_abs = std::make_shared<abstract::AbstractTuple>(abs_list);
  MS_LOG(DEBUG) << "Create abstract for real make tuple:" << tuple_abs->ToString();
  // Set the dynamic-len element abstract so that the abstract can be identified as dynamic length.
  abstract::AbstractBasePtr element_abs = (abs_list.empty() ? std::make_shared<abstract::AbstractTensor>(
                                                                TypeIdToType(TypeId::kNumberTypeInt64), ShapeVector())
                                                            : abs_list[0]);
  tuple_abs->set_dynamic_len_element_abs(element_abs);
  new_cnode->set_abstract(tuple_abs);

  // Create kernel info for the node and set the output formats for it.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  kernel_info->set_select_kernel_build_info(builder->Build());
  new_cnode->set_kernel_info(kernel_info);
  builder->SetOutputsFormat(formats);
  return new_cnode;
}

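// Check that the device addresses to be merged have the same size and type id, and warn when the shapes differ.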
void CheckDeviceAddressConsist(OpContext<DeviceTensor> *const context, const std::vector<DeviceTensor *> &addr_list,
                               const std::string &actor_name) {
  MS_EXCEPTION_IF_NULL(context);
  if (addr_list.empty() || addr_list[0] == nullptr) {
    return;
  }
  // Check the consistency of the device addresses.
  const auto &shape = addr_list[0]->host_shape();
  const auto &size = addr_list[0]->GetSize();
  const auto &type = addr_list[0]->type_id();
  const auto &device_name = addr_list[0]->device_name();
  for (size_t i = 1; i < addr_list.size(); ++i) {
    MS_EXCEPTION_IF_NULL(addr_list[i]);
    if (size != addr_list[i]->GetSize() || type != addr_list[i]->type_id()) {
      MS_LOG(ERROR) << "Failed to merge two device address, addr1:" << addr_list[0] << " size:" << size
                    << " shape:" << shape << " device name:" << device_name << " type:" << type
                    << " addr2:" << addr_list[i] << " size:" << addr_list[i]->GetSize()
                    << " shape:" << addr_list[i]->host_shape() << " device name:" << addr_list[i]->device_name()
                    << " type:" << addr_list[i]->type_id() << " for actor:" << actor_name;
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Failed to merge two device address");
    }
    if (shape != addr_list[i]->host_shape()) {
      MS_LOG(WARNING) << "Merge two device address with different shape, addr1 shape:" << shape
                      << " addr2 shape:" << addr_list[i]->host_shape() << " for actor:" << actor_name;
    }
  }
}
}  // namespace

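// Merge the device address list into a single tuple device address: allocate a continuous device memory for the
// whole tuple and copy every input device address into it one by one.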
void ControlActor::MergeDeviceAddress(OpContext<DeviceTensor> *const context,
                                      const std::vector<DeviceTensor *> &addr_list, DeviceTensor **device_tensor) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(device_tensor);
  if (addr_list.empty()) {
    MergeEmptyAddressDeviceAddress(context, addr_list, device_tensor);
    return;
  }

  CheckDeviceAddressConsist(context, addr_list, GetAID().Name());
  MS_EXCEPTION_IF_NULL(addr_list[0]);
  const auto &total_size = addr_list[0]->GetSize() * addr_list.size();
  ShapeVector total_shape = {SizeToLong(addr_list.size())};
  const auto &shape = addr_list[0]->host_shape();
  total_shape.insert(total_shape.end(), shape.begin(), shape.end());
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {addr_list[0]->device_name(), addr_list[0]->device_id()});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  abstract::BaseShapePtrList shape_list(addr_list.size(), addr_list[0]->kernel_tensor()->GetShape());
  auto tuple_shape = std::make_shared<abstract::TupleShape>(shape_list);
  TypePtrList type_list(addr_list.size(), addr_list[0]->kernel_tensor()->GetType());
  auto tuple_type = std::make_shared<Tuple>(type_list);
  MS_LOG(DEBUG) << "Create kernel tensor by shape:" << tuple_shape->ToString() << " type:" << tuple_type->ToString()
                << " in device address:" << addr_list[0];
  const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
    tuple_shape, tuple_type, nullptr, nullptr, total_size, addr_list[0]->format(), addr_list[0]->type_id(),
    total_shape, device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(addr_list[0]->stream_id());
  const auto &new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_tensor);

  MS_LOG(DEBUG) << "Create device tensor:" << new_device_tensor << " type:" << new_device_tensor->type_id();
  if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                GetAID().Name(), new_device_tensor->GetSize());
  }
  MS_EXCEPTION_IF_NULL(new_device_tensor->GetMutablePtr());

  // Create a new RealMakeTuple node for the new device address.
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto new_cnode = CreateRealMakeTuple(addr_list, fg);
  AnfAlgo::SetOutputAddr(new_device_tensor, 0, new_cnode.get());
  created_new_graphs_.emplace_back(fg);
  created_new_nodes_.emplace_back(new_cnode);
  new_device_tensor->SetNodeIndex(new_cnode, 0);
  new_device_tensor->set_from_persistent_mem(addr_list[0]->from_persistent_mem());
  new_device_tensor->set_dynamic_ref_count(0);
  new_device_tensor->set_original_ref_count(SIZE_MAX);
  new_device_tensor->ResetRefCount();

  // Merge device address list into a single device address.
  auto tmp_kernel_tensor = std::make_shared<kernel::KernelTensor>(
    new_device_tensor->GetMutablePtr(), addr_list[0]->GetSize(),
    kernel::GetFormatFromStrToEnum(addr_list[0]->format()), addr_list[0]->type_id(), shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  tmp_kernel_tensor->set_stream_id(addr_list[0]->stream_id());
  const auto &tmp_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(tmp_kernel_tensor);
  MS_EXCEPTION_IF_NULL(tmp_device_tensor);
  MS_LOG(DEBUG) << "Create device tensor:" << tmp_device_tensor << " type:" << tmp_device_tensor->type_id();
  std::shared_ptr<int64_t> max_task_id_on_stream = nullptr;
  for (size_t i = 0; i < addr_list.size(); ++i) {
    auto device_tensor_addr = addr_list[i];
    auto task_id_on_stream = device_tensor_addr->kernel_tensor()->task_id_on_stream();
    if (task_id_on_stream != nullptr) {
      if (max_task_id_on_stream == nullptr) {
        max_task_id_on_stream = task_id_on_stream;
      } else {
        if (*max_task_id_on_stream < *task_id_on_stream) {
          max_task_id_on_stream = task_id_on_stream;
        }
      }
    }
    bool ret = false;
    if (addr_list[i]->device_name() == addr_list[0]->device_name()) {
      ret = tmp_device_tensor->SyncDeviceToDevice(addr_list[i]);
    } else if (addr_list[0]->device_name() == kCPUDevice) {
      ret = addr_list[i]->SyncDeviceToHost(addr_list[i]->GetSize(), tmp_device_tensor->GetMutablePtr());
    } else if (addr_list[i]->device_name() == kCPUDevice) {
      ret = tmp_device_tensor->SyncHostToDevice(addr_list[i]->GetSize(), addr_list[i]->GetMutablePtr());
    } else {
      MS_LOG(ERROR) << "Invalid device name for addr1:" << addr_list[0] << " name:" << addr_list[0]->device_name()
                    << " and addr2:" << addr_list[i] << " name:" << addr_list[i]->device_name();
    }
    if (!ret) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
    }
    tmp_device_tensor->set_ptr((reinterpret_cast<char *>(tmp_device_tensor->GetMutablePtr())) +
                               addr_list[0]->GetSize());
  }
  new_device_tensor->kernel_tensor()->set_task_id_on_stream(max_task_id_on_stream);
  tmp_device_tensor->set_ptr(nullptr);
  created_device_tensors_.emplace_back(new_device_tensor);
  MS_LOG(DEBUG) << "actor:" << GetAID() << " create new device address:" << new_device_tensor
                << " for addr list size:" << addr_list.size()
                << " device address shape:" << new_device_tensor->host_shape();
  (*device_tensor) = new_device_tensor.get();
  return;
}

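// Create a tuple device address for an empty address list on the default device context.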
void ControlActor::MergeEmptyAddressDeviceAddress(OpContext<DeviceTensor> *const context,
                                                  const std::vector<DeviceTensor *> &addr_list,
                                                  DeviceTensor **device_tensor) {
  // Create a device address for the empty tuple.
  // Fetch the default device context for the empty sequence.
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET), context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto tuple_shape = std::make_shared<abstract::TupleShape>();
  auto tuple_type = std::make_shared<Tuple>();
  const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
    tuple_shape, tuple_type, nullptr, nullptr, 0, kOpFormat_DEFAULT, TypeId::kNumberTypeInt64, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  const auto &new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_tensor);
  new_device_tensor->set_dynamic_ref_count(0);
  new_device_tensor->set_original_ref_count(SIZE_MAX);
  new_device_tensor->ResetRefCount();
  if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                GetAID().Name(), new_device_tensor->GetSize());
  }
  created_device_tensors_.emplace_back(new_device_tensor);
  (*device_tensor) = new_device_tensor.get();
  MS_LOG(DEBUG) << "actor:" << GetAID() << " create new device address:" << new_device_tensor
                << " for empty addr list";
}
}  // namespace runtime
}  // namespace mindspore