/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "include/backend/mem_reuse/mem_tracker.h"

namespace mindspore {
namespace runtime {
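// Init prepares the per-branch output data used when the exit actor sends its outputs: for each output branch data
// arrow it pre-builds an OpData whose device tensor is filled in later by FetchInput, and records whether that data
// is sent to a stack actor.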
void ExitActor::Init() {
  // Init output data in base class.
  ControlActor::Init();

  // Init output data in each output branch.
  for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) {
    auto &output_branch_data_arrows = output_branch_data_arrows_[i];
    for (auto &data_arrow : output_branch_data_arrows) {
      MS_EXCEPTION_IF_NULL(data_arrow);
      auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
      (void)output_branch_data_[i].emplace_back(data_arrow->from_output_index_, std::move(data));

      // Mark the output data that is sent to a stack actor with the kOutputDataFlagToStack flag.
      bool is_to_stack = (data_arrow->to_op_id_.Name().find(kStackActorNameSuffix) != std::string::npos);
      size_t output_data_flag = is_to_stack ? kOutputDataFlagToStack : kOutputDataFlagInit;
      (void)output_branch_data_flag_[i].emplace_back(output_data_flag);
    }
  }

  // Check the number of device contexts.
  if (device_contexts_.size() != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
}

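// FetchInput collects the inputs through the base class, copies the device addresses that need to be taken over,
// and then binds the input device tensors to the output data of the current output branch, or merges them when the
// branch contains dynamic-length outputs.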
void ExitActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (!WaitRuntimePipelineFinish(context)) {
    MS_LOG(INFO) << "Run failed and early stop.";
    return;
  }
  ControlActor::FetchInput(context);

  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  CopyDeviceAddress(context);
  if (output_branch_dynamic_len_index_.find(output_branch_id_) == output_branch_dynamic_len_index_.end()) {
    auto data_iter = output_branch_data_.find(output_branch_id_);
    if (data_iter != output_branch_data_.end()) {
      for (auto &output_data : data_iter->second) {
        MS_EXCEPTION_IF_NULL(output_data.second);
        if (output_data.first >= input_device_tensors_.size()) {
          MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID()
                            << " to actor:" << output_data.second->op_id_
                            << " to index:" << output_data.second->index_;
        }
        MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]);
        output_data.second->data_ = input_device_tensors_[output_data.first];
      }
    }
  } else {
    // The outputs of this branch contain dynamic-length elements, so the device addresses need to be merged.
    MS_LOG(DEBUG) << "Exit actor:" << GetAID() << " merge output";
    MergeDynamiclenDeviceAddress(context);
  }
}

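// SendOutput does not send the outputs directly. It first asks the memory manager actor to wait until all pending
// reference count updates have finished; the real sending happens in OnMemoryAllocFinish.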
void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // Before the exit actor sends its output, all reference count calculations in the graph must be finished.
  // Otherwise the device tensors in the free memory list could be overwritten in the next execution, which would
  // lead to the same ptr being released multiple times.
  ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID());
}

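// OnMemoryAllocFinish performs the real output sending after the memory manager wait has completed: first the base
// class outputs, then the data, control and partial outputs of the current output branch.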
void ExitActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (IsRunningFailed(context)) {
    return;
  }

  // 1.Send output in base class.
  ControlActor::SendOutput(context);

  // 2.Send output data in output branch.
  const auto &branch_data_iter = output_branch_data_.find(output_branch_id_);
  if (branch_data_iter != output_branch_data_.end()) {
    MS_EXCEPTION_IF_CHECK_FAIL((output_branch_data_flag_.count(output_branch_id_) > 0),
                               "The output branch id is invalid.");
    const auto &output_data_flags = output_branch_data_flag_[output_branch_id_];
    MS_EXCEPTION_IF_CHECK_FAIL((output_data_flags.size() == branch_data_iter->second.size()),
                               "The output data flag size is wrong.");
    for (size_t i = 0; i < branch_data_iter->second.size(); ++i) {
      const auto &output_data = branch_data_iter->second[i];
      MS_EXCEPTION_IF_NULL(output_data.second);
      // Create a new op data for stack actor.
      if (TEST_FLAG(output_data_flags[i], kOutputDataFlagToStack)) {
        auto to_stack_data = std::make_unique<OpData<DeviceTensor>>(
          output_data.second->op_id_, output_data.second->data_, output_data.second->index_);
        (void)to_stack_data_.emplace_back(std::move(to_stack_data));
        ActorDispatcher::Send(output_data.second->op_id_, &OpActor::RunOpData, to_stack_data_.back().get(), context);
      } else {
        ActorDispatcher::Send(output_data.second->op_id_, &OpActor::RunOpData, output_data.second.get(), context);
      }
    }
  }

  // 3.Send output control in output branch.
  const auto &control_iter = output_branch_control_arrows_.find(output_branch_id_);
  if (control_iter != output_branch_control_arrows_.end()) {
    auto source_aid = const_cast<AID *>(&GetAID());
    for (const auto &control_arrow : control_iter->second) {
      ActorDispatcher::Send(control_arrow, &OpActor::RunOpControl, source_aid, context);
    }
  }

  // 4.Send output partial in output branch.
  const auto &partial_iter = output_branch_partial_arrows_.find(output_branch_id_);
  if (partial_iter != output_branch_partial_arrows_.end()) {
    for (const auto &arrow : partial_iter->second) {
      MS_EXCEPTION_IF_NULL(arrow);
      if (IntToSize(arrow->from_output_index_) >= input_partials_.size()) {
        std::string error_info = "Invalid partial input:" + std::to_string(arrow->from_output_index_) +
                                 " current:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      auto output_partial = input_partials_[IntToSize(arrow->from_output_index_)];
      ActorDispatcher::Send(arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
                            IntToSize(arrow->to_input_index_), context);
    }
  }
  last_step_created_device_tensors_.clear();
}

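// IncreaseDynamicRefCounts raises the dynamic reference counts of the data and partial outputs of the current
// output branch, and frees any input device tensor that ends up with no user at all.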
void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  ControlActor::IncreaseDynamicRefCounts(context);

  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  // Increase dynamic ref count by the output data in output branch.
  if (output_branch_data_.count(output_branch_id_) > 0) {
    for (auto &output_data : output_branch_data_[output_branch_id_]) {
      MS_EXCEPTION_IF_NULL(output_data.second);
      IncreaseDynamicRefCount(output_data.second.get());
    }
  }

  // Increase dynamic ref count by the output partial in output branch.
  if (output_branch_partial_arrows_.count(output_branch_id_) > 0) {
    for (const auto &partial_arrow : output_branch_partial_arrows_[output_branch_id_]) {
      MS_EXCEPTION_IF_NULL(partial_arrow);
      if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
        std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) +
                                 " current:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      auto output_partial = input_partials_[IntToSize(partial_arrow->from_output_index_)];
      IncreaseDynamicRefCount(output_partial);
    }
  }
  if (input_device_tensors_.size() != device_contexts_.size()) {
    MS_LOG(ERROR) << "Input device tensor size:" << input_device_tensors_.size()
                  << " is not equal to context size:" << device_contexts_.size() << " for actor:" << GetAID();
  }
  // An input device tensor may have no user at all; in that case its memory needs to be freed here.
  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) &&
        (device_contexts_[i] != nullptr)) {
      MS_LOG(INFO) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
      // Update the real used device context by the input data.
      if (device_contexts_[i]->GetDeviceType() != input_device_tensors_[i]->GetDeviceType()) {
        device_contexts_[i] = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
          {input_device_tensors_[i]->device_name(), input_device_tensors_[i]->device_id()});
        MS_LOG(INFO) << "Update device context type to:" << device_contexts_[i]->GetDeviceType();
      }
      device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
    }
  }
}

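// MergeDynamiclenDeviceAddress handles output branches that contain dynamic-length outputs: for every output
// position it looks up the real input indexes, merges the grouped device addresses into a single device tensor (or
// passes the partial/device tensor through unchanged), and binds the merged results to the branch output data.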
void ExitActor::MergeDynamiclenDeviceAddress(OpContext<DeviceTensor> *const context) {
  if (output_branch_dynamic_len_index_.find(output_branch_id_) == output_branch_dynamic_len_index_.end()) {
    return;
  }
  auto real_indexes = output_branch_dynamic_len_index_[output_branch_id_];
  std::vector<OpPartialPtr> new_partials;
  std::vector<DeviceTensor *> new_device_tensors;
  // Collect the new outputs of the actor and merge the device addresses of the dynamic-length elements.
  for (size_t i = 0; i < real_indexes.size(); ++i) {
    const auto &indexes = real_indexes[i].first;
    if (real_indexes[i].second) {
      std::vector<DeviceTensor *> addr_list;
      for (size_t index : indexes) {
        if (index >= input_device_tensors_.size()) {
          std::string error_info = "Invalid real index:" + std::to_string(index) + " for index:" + std::to_string(i) +
                                   " total size:" + std::to_string(input_device_tensors_.size()) +
                                   " for actor:" + GetAID().Name();
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
        }
        if (input_device_tensors_[index] == nullptr) {
          std::string error_info =
            "Invalid input device address index:" + std::to_string(index) + " for index:" + std::to_string(i) +
            " total size:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
        }
        addr_list.emplace_back(input_device_tensors_[index]);
      }
      DeviceTensor *new_device_tensor = nullptr;
      MergeDeviceAddress(context, addr_list, &new_device_tensor);
      new_device_tensors.emplace_back(new_device_tensor);
      new_partials.emplace_back(nullptr);
    } else if (indexes.empty() || indexes[0] >= input_partials_.size()) {
      std::string error_info = "Invalid index num:" + std::to_string(indexes.size()) +
                               " for index:" + std::to_string(i) + " for actor:" + GetAID().Name();
      MS_LOG(WARNING) << error_info;
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    } else if (input_partials_[indexes[0]] != nullptr) {
      new_device_tensors.emplace_back(nullptr);
      new_partials.emplace_back(input_partials_[indexes[0]]);
    } else if (input_device_tensors_[indexes[0]] != nullptr) {
      new_device_tensors.emplace_back(input_device_tensors_[indexes[0]]);
      new_partials.emplace_back(nullptr);
    } else {
      std::string error_info = "Failed to get input for real index:" + std::to_string(indexes[0]) +
                               " for index:" + std::to_string(i) + " for actor:" + GetAID().Name();
      MS_LOG(WARNING) << error_info;
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }
  auto data_iter = output_branch_data_.find(output_branch_id_);
  if (data_iter != output_branch_data_.end()) {
    for (auto &output_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(output_data.second);
      if (output_data.first >= new_device_tensors.size()) {
        MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID()
                          << " to actor:" << output_data.second->op_id_
                          << " to index:" << output_data.second->index_;
      }
      MS_EXCEPTION_IF_NULL(new_device_tensors[output_data.first]);
      output_data.second->data_ = new_device_tensors[output_data.first];
    }
  }
}

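// IsNeedCopyDeviceAddress decides whether an input device tensor has to be taken over by a newly created device
// tensor in CopyDeviceAddress. Inputs that are not flagged for copying keep their original device address; during
// dynamic checks, so do inputs that already use a dynamic ref count or whose source node (or ref origin) is not a
// CNode.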
bool ExitActor::IsNeedCopyDeviceAddress(DeviceTensor *const input_device_tensor, size_t index) {
  if ((input_device_tensor == nullptr) || (!is_need_copy_device_tensors_[index])) {
    return false;
  }

  if (is_need_dynamic_checks_[index]) {
    if (input_device_tensor->dynamic_ref_count() != INT32_MAX) {
      return false;
    }
    const auto &node = input_device_tensor->GetNodeIndex().first;
    if (node != nullptr) {
      if (!node->isa<CNode>()) {
        MS_LOG(DEBUG) << "Input device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
                      << " for node:" << node->DebugString()
                      << " does not need to replace ptr for actor:" << GetAID();
        return false;
      }
      const auto &iter = ref_out_in_map_.find(input_device_tensor->GetNodeIndex());
      if (iter != ref_out_in_map_.end() && iter->second.first != nullptr && (!iter->second.first->isa<CNode>())) {
        MS_LOG(DEBUG) << "Input device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
                      << " for node:" << node->DebugString()
                      << " is a ref node of:" << iter->second.first->DebugString()
                      << " and does not need to replace ptr for actor:" << GetAID();
        return false;
      }
    }
  }
  return true;
}

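// UpdateDeviceOutputData rebinds the output data of every output index to the (possibly replaced) input device
// tensors, so that downstream actors see the device tensors created by CopyDeviceAddress.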
void ExitActor::UpdateDeviceOutputData() {
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }

    const auto &device_tensor = input_device_tensors_[i];
    MS_EXCEPTION_IF_NULL(device_tensor);
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = device_tensor;
    }
  }
}

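// CopyDeviceAddress creates, for every input that needs it, a new device tensor that takes over the input device
// tensor. If the input ptr is persisted, new device memory is allocated and the data is copied; otherwise the
// device ptr is moved into the new device tensor, which then manages it with a dynamic ref count.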
void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // If node_ is not null, this is the exit of a funcgraph and there is no need to create new device addresses.
  if (node_ != nullptr) {
    return;
  }
  if (input_device_tensors_.size() != is_need_copy_device_tensors_.size() ||
      input_device_tensors_.size() != is_dynamic_shapes_.size() ||
      input_device_tensors_.size() != device_contexts_.size() ||
      input_device_tensors_.size() != is_need_dynamic_checks_.size()) {
    std::string error_info = "Invalid input device tensor size:" + std::to_string(input_device_tensors_.size()) +
                             " need tensor size:" + std::to_string(is_need_copy_device_tensors_.size()) +
                             " need dynamic shape size:" + std::to_string(is_dynamic_shapes_.size()) +
                             " need context size:" + std::to_string(device_contexts_.size()) +
                             " for actor:" + GetAID().Name();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  std::vector<DeviceTensor *> new_device_tensors;
  mindspore::HashMap<DeviceTensor *, DeviceTensor *> device_tensor_map;
  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    auto &input_device_tensor = input_device_tensors_[i];
    if (!IsNeedCopyDeviceAddress(input_device_tensor, i)) {
      (void)new_device_tensors.emplace_back(input_device_tensor);
      continue;
    }

    auto iter = device_tensor_map.find(input_device_tensor);
    if (iter != device_tensor_map.end()) {
      (void)new_device_tensors.emplace_back(iter->second);
      continue;
    }

    // Update the real used device context by the input data.
    auto &device_context = device_contexts_[i];
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
    if (device_context->GetDeviceType() != input_device_tensor->GetDeviceType()) {
      device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {input_device_tensor->device_name(), input_device_tensor->device_id()});
      MS_LOG(INFO) << "Update device context type to:" << device_context->GetDeviceType();
    }

    const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
    MS_EXCEPTION_IF_NULL(node_with_index.first);
    // Create a new device tensor to take over the input device tensor, which is an output of a kernel graph.
    const auto &kernel_tensor = input_device_tensor->kernel_tensor();
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
    MS_EXCEPTION_IF_NULL(new_kernel_tensor);
    new_kernel_tensor->set_device_ptr(nullptr);
    DeviceTensorPtr new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
    MS_EXCEPTION_IF_NULL(new_device_tensor);
    MS_LOG(DEBUG) << "Actor:" << GetAID() << " create new device tensor:" << new_device_tensor
                  << " type:" << new_device_tensor->type_id() << " by input device tensor:" << input_device_tensor
                  << " shape:"
                  << (kernel_tensor->GetShape() == nullptr ? "null" : kernel_tensor->GetShape()->ToString())
                  << " type:"
                  << (kernel_tensor->GetType() == nullptr ? "null" : kernel_tensor->GetType()->ToString());
    const auto &swap_manager = device_context->device_res_manager_->swap_manager();
    if (swap_manager != nullptr) {
      swap_manager->AddSwappableTensor(new_device_tensor);
    }
    (void)created_device_tensors_.emplace_back(new_device_tensor);
    (void)new_device_tensors.emplace_back(new_device_tensor.get());
    device_tensor_map[input_device_tensor] = new_device_tensor.get();
    new_device_tensor->set_need_sync_user_data(input_device_tensor->need_sync_user_data());
    new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
    new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
    // The device address created by the actor uses the dynamic ref count.
    new_device_tensor->set_dynamic_ref_count(0);
    new_device_tensor->set_original_ref_count(SIZE_MAX);
    new_device_tensor->ResetRefCount();

    // If the address ptr can't be changed, allocate new device memory and copy the data.
    if (input_device_tensor->is_ptr_persisted()) {
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "CopyDeviceAddress", "");
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(), device::tracker::MemType::kOther,
                                                     new_device_tensor->GetSize(), new_device_tensor.get());
      device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);
      if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
        SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                    GetAID().Name(), new_device_tensor->GetSize());
      }
      if (!new_device_tensor->SyncDeviceToDevice(input_device_tensor)) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
      }
    } else {
      // Move the device ptr from input_device_tensor to new_device_tensor.
      input_device_tensor->Swap(new_device_tensor.get());
      if (new_device_tensor->deleter() == nullptr) {
        new_device_tensor->set_from_mem_pool(true);
      }
    }
    MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
                  << ", ptr:" << new_device_tensor->GetPtr()
                  << ", from node:" << node_with_index.first->fullname_with_scope()
                  << " with index:" << node_with_index.second;
  }
  input_device_tensors_.swap(new_device_tensors);
  UpdateDeviceOutputData();
}
}  // namespace runtime
}  // namespace mindspore