• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
19 #include "runtime/graph_scheduler/actor/output_actor.h"
20 #include "runtime/hardware/device_context_manager.h"
21 #include "include/backend/mem_reuse/mem_tracker.h"
22 
23 namespace mindspore {
24 namespace runtime {
Init()25 void ExitActor::Init() {
26   // Init output data in base class.
27   ControlActor::Init();
28 
29   // Init output data in each output branch.
30   for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) {
31     auto &output_branch_data_arrows = output_branch_data_arrows_[i];
32     for (auto &data_arrow : output_branch_data_arrows) {
33       MS_EXCEPTION_IF_NULL(data_arrow);
34       auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
35       (void)output_branch_data_[i].emplace_back(data_arrow->from_output_index_, std::move(data));
36 
37       // Identify whether the output data flag is kOutputDataFlagToStack.
38       bool is_to_stack = (data_arrow->to_op_id_.Name().find(kStackActorNameSuffix) != std::string::npos);
39       size_t output_data_flag = is_to_stack ? kOutputDataFlagToStack : kOutputDataFlagInit;
40       (void)output_branch_data_flag_[i].emplace_back(output_data_flag);
41     }
42   }
43 
44   // Check device contexts number.
45   if (device_contexts_.size() != input_device_tensors_.size()) {
46     MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
47   }
48 }
49 
FetchInput(OpContext<DeviceTensor> * const context)50 void ExitActor::FetchInput(OpContext<DeviceTensor> *const context) {
51   MS_EXCEPTION_IF_NULL(context);
52   if (!WaitRuntimePipelineFinish(context)) {
53     MS_LOG(INFO) << "Run failed and early stop.";
54     return;
55   }
56   ControlActor::FetchInput(context);
57 
58   ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
59   CopyDeviceAddress(context);
60   if (output_branch_dynamic_len_index_.find(output_branch_id_) == output_branch_dynamic_len_index_.end()) {
61     auto data_iter = output_branch_data_.find(output_branch_id_);
62     if (data_iter != output_branch_data_.end()) {
63       for (auto &output_data : data_iter->second) {
64         MS_EXCEPTION_IF_NULL(output_data.second);
65         if (output_data.first >= input_device_tensors_.size()) {
66           MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID()
67                             << " to actor:" << output_data.second->op_id_ << " to index:" << output_data.second->index_;
68         }
69         MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]);
70         output_data.second->data_ = input_device_tensors_[output_data.first];
71       }
72     }
73   } else {
74     // The branch id need merge device address.
75     MS_LOG(DEBUG) << "Exit actor:" << GetAID() << " merge output";
76     MergeDynamiclenDeviceAddress(context);
77   }
78 }
79 
SendOutput(OpContext<DeviceTensor> * const context)80 void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
81   MS_EXCEPTION_IF_NULL(context);
82   // Before the exit actor sends output, it is necessary to ensure that all reference count calculations in the
83   // graph are completed, otherwise the device tensor in the free memory list will be overwritten the next time
84   // it is executed, resulting in multiple releases of ptr.
85   ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID());
86 }
87 
void ExitActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  // Callback invoked after MemoryManagerActor::Wait completes (triggered from
  // SendOutput): performs the real output sending for the selected output branch.
  MS_EXCEPTION_IF_NULL(context);
  if (IsRunningFailed(context)) {
    return;
  }

  // 1.Send output in base class.
  ControlActor::SendOutput(context);

  // 2.Send output data in output branch.
  const auto &branch_data_iter = output_branch_data_.find(output_branch_id_);
  if (branch_data_iter != output_branch_data_.end()) {
    MS_EXCEPTION_IF_CHECK_FAIL((output_branch_data_flag_.count(output_branch_id_) > 0),
                               "The output branch id is invalid.");
    const auto &output_data_flags = output_branch_data_flag_[output_branch_id_];
    // Flags were built in lockstep with the op data in Init(), so the sizes must match.
    MS_EXCEPTION_IF_CHECK_FAIL((output_data_flags.size() == branch_data_iter->second.size()),
                               "The output data flag size is wrong.");
    for (size_t i = 0; i < branch_data_iter->second.size(); ++i) {
      const auto &output_data = branch_data_iter->second[i];
      MS_EXCEPTION_IF_NULL(output_data.second);
      // Create a new op data for stack actor. It is kept alive in the to_stack_data_
      // member until delivery; presumably cleared elsewhere per step — TODO confirm.
      if (TEST_FLAG(output_data_flags[i], kOutputDataFlagToStack)) {
        auto to_stack_data = std::make_unique<OpData<DeviceTensor>>(
          output_data.second->op_id_, output_data.second->data_, output_data.second->index_);
        (void)to_stack_data_.emplace_back(std::move(to_stack_data));
        ActorDispatcher::Send(output_data.second->op_id_, &OpActor::RunOpData, to_stack_data_.back().get(), context);
      } else {
        ActorDispatcher::Send(output_data.second->op_id_, &OpActor::RunOpData, output_data.second.get(), context);
      }
    }
  }

  // 3.Send output control in output branch.
  const auto &control_iter = output_branch_control_arrows_.find(output_branch_id_);
  if (control_iter != output_branch_control_arrows_.end()) {
    auto source_aid = const_cast<AID *>(&GetAID());
    for (const auto &control_arrow : control_iter->second) {
      ActorDispatcher::Send(control_arrow, &OpActor::RunOpControl, source_aid, context);
    }
  }

  // 4.Send output partial in output branch.
  const auto &partial_iter = output_branch_partial_arrows_.find(output_branch_id_);
  if (partial_iter != output_branch_partial_arrows_.end()) {
    for (const auto &arrow : partial_iter->second) {
      MS_EXCEPTION_IF_NULL(arrow);
      if (IntToSize(arrow->from_output_index_) >= input_partials_.size()) {
        std::string error_info = "Invalid partial input:" + std::to_string(arrow->from_output_index_) +
                                 " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      auto output_partial = input_partials_[IntToSize(arrow->from_output_index_)];
      ActorDispatcher::Send(arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
                            IntToSize(arrow->to_input_index_), context);
    }
  }
  // All outputs of this step are sent; drop the bookkeeping of device tensors created
  // in the previous step.
  last_step_created_device_tensors_.clear();
}
146 
void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  // Raises the dynamic ref count of every output of the current branch, then frees any
  // input device tensor that ends up with no user at all.
  MS_EXCEPTION_IF_NULL(context);
  ControlActor::IncreaseDynamicRefCounts(context);

  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  // Increase dynamic ref count by the output data in output branch.
  if (output_branch_data_.count(output_branch_id_) > 0) {
    for (auto &output_data : output_branch_data_[output_branch_id_]) {
      MS_EXCEPTION_IF_NULL(output_data.second);
      IncreaseDynamicRefCount(output_data.second.get());
    }
  }

  // Increase dynamic ref count by the output partial in output branch.
  if (output_branch_partial_arrows_.count(output_branch_id_) > 0) {
    for (const auto &partial_arrow : output_branch_partial_arrows_[output_branch_id_]) {
      MS_EXCEPTION_IF_NULL(partial_arrow);
      if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
        std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) +
                                 " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      auto output_partial = input_partials_[IntToSize(partial_arrow->from_output_index_)];
      IncreaseDynamicRefCount(output_partial);
    }
  }
  // NOTE(review): a size mismatch here is only logged, not fatal — the loop below
  // indexes both vectors by i, so a mismatch would be an out-of-bounds risk; confirm
  // the sizes are guaranteed equal by Init()'s check.
  if (input_device_tensors_.size() != device_contexts_.size()) {
    MS_LOG(ERROR) << "Input device tensor size:" << input_device_tensors_.size()
                  << " is not equal to context size:" << device_contexts_.size() << " for actor:" << GetAID();
  }
  // The input device tensor may not have users and needs to free the memory.
  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) &&
        (device_contexts_[i] != nullptr)) {
      MS_LOG(INFO) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
      // Update the real used device context by the input data, so the memory is freed
      // by the device that actually owns it.
      if (device_contexts_[i]->GetDeviceType() != input_device_tensors_[i]->GetDeviceType()) {
        device_contexts_[i] = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
          {input_device_tensors_[i]->device_name(), input_device_tensors_[i]->device_id()});
        MS_LOG(INFO) << "Update device context type to:" << device_contexts_[i]->GetDeviceType();
      }
      device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
    }
  }
}
192 
MergeDynamiclenDeviceAddress(OpContext<DeviceTensor> * const context)193 void ExitActor::MergeDynamiclenDeviceAddress(OpContext<DeviceTensor> *const context) {
194   if (output_branch_dynamic_len_index_.find(output_branch_id_) == output_branch_dynamic_len_index_.end()) {
195     return;
196   }
197   auto real_indexes = output_branch_dynamic_len_index_[output_branch_id_];
198   std::vector<OpPartialPtr> new_partials;
199   std::vector<DeviceTensor *> new_device_tensors;
200   // Collect the new output of actor, merge the device address for dynamic len.
201   for (size_t i = 0; i < real_indexes.size(); ++i) {
202     const auto &indexes = real_indexes[i].first;
203     if (real_indexes[i].second) {
204       std::vector<DeviceTensor *> addr_list;
205       for (size_t index : indexes) {
206         if (index >= input_device_tensors_.size()) {
207           std::string error_info = "Invalid real index:" + std::to_string(index) + " for index:" + std::to_string(i) +
208                                    " total size:" + std::to_string(input_device_tensors_.size()) +
209                                    " for actor:" + GetAID().Name();
210           SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
211         }
212         if (input_device_tensors_[index] == nullptr) {
213           std::string error_info =
214             "Invalid input device address index:" + std::to_string(index) + " for index:" + std::to_string(i) +
215             " total size:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
216           SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
217         }
218         addr_list.emplace_back(input_device_tensors_[index]);
219       }
220       DeviceTensor *new_device_tensor = nullptr;
221       MergeDeviceAddress(context, addr_list, &new_device_tensor);
222       new_device_tensors.emplace_back(new_device_tensor);
223       new_partials.emplace_back(nullptr);
224     } else if (indexes.empty() || indexes[0] >= input_partials_.size()) {
225       std::string error_info = "Invalid index num:" + std::to_string(indexes.size()) +
226                                " for index:" + std::to_string(i) + " for actor:" + GetAID().Name();
227       MS_LOG(WARNING) << error_info;
228       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
229     } else if (input_partials_[indexes[0]] != nullptr) {
230       new_device_tensors.emplace_back(nullptr);
231       new_partials.emplace_back(input_partials_[indexes[0]]);
232     } else if (input_device_tensors_[indexes[0]] != nullptr) {
233       new_device_tensors.emplace_back(input_device_tensors_[indexes[0]]);
234       new_partials.emplace_back(nullptr);
235     } else {
236       std::string error_info = "Failed to get input for real index:" + std::to_string(indexes[0]) +
237                                " for index:" + std::to_string(i) + " for actor:" + GetAID().Name();
238       MS_LOG(WARNING) << error_info;
239       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
240     }
241   }
242   auto data_iter = output_branch_data_.find(output_branch_id_);
243   if (data_iter != output_branch_data_.end()) {
244     for (auto &output_data : data_iter->second) {
245       MS_EXCEPTION_IF_NULL(output_data.second);
246       if (output_data.first >= new_device_tensors.size()) {
247         MS_EXCEPTION_IF_NULL(output_data.second);
248         MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID()
249                           << " to actor:" << output_data.second->op_id_ << " to index:" << output_data.second->index_;
250       }
251       MS_EXCEPTION_IF_NULL(new_device_tensors[output_data.first]);
252       output_data.second->data_ = new_device_tensors[output_data.first];
253     }
254   }
255 }
256 
IsNeedCopyDeviceAddress(DeviceTensor * const input_device_tensor,size_t index)257 bool ExitActor::IsNeedCopyDeviceAddress(DeviceTensor *const input_device_tensor, size_t index) {
258   if ((input_device_tensor == nullptr) || (!is_need_copy_device_tensors_[index])) {
259     return false;
260   }
261 
262   if (is_need_dynamic_checks_[index]) {
263     if (input_device_tensor->dynamic_ref_count() != INT32_MAX) {
264       return false;
265     }
266     const auto &node = input_device_tensor->GetNodeIndex().first;
267     if (node != nullptr) {
268       if (!node->isa<CNode>()) {
269         MS_LOG(DEBUG) << "Input device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
270                       << " for node:" << node->DebugString() << " is not need replace ptr for actor:" << GetAID();
271         return false;
272       }
273       const auto &iter = ref_out_in_map_.find(input_device_tensor->GetNodeIndex());
274       if (iter != ref_out_in_map_.end() && iter->second.first != nullptr && (!iter->second.first->isa<CNode>())) {
275         MS_LOG(DEBUG) << "Input device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
276                       << " for node:" << node->DebugString()
277                       << " is a ref node of:" << iter->second.first->DebugString()
278                       << " not need replace ptr for actor:" << GetAID();
279         return false;
280       }
281     }
282   }
283   return true;
284 }
285 
UpdateDeviceOutputData()286 void ExitActor::UpdateDeviceOutputData() {
287   for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
288     if (output_data_by_output_index_[i].empty()) {
289       continue;
290     }
291 
292     const auto &device_tensor = input_device_tensors_[i];
293     MS_EXCEPTION_IF_NULL(device_tensor);
294     for (auto &output_data : output_data_by_output_index_[i]) {
295       MS_EXCEPTION_IF_NULL(output_data);
296       output_data->data_ = device_tensor;
297     }
298   }
299 }
300 
CopyDeviceAddress(OpContext<DeviceTensor> * const context)301 void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
302   MS_EXCEPTION_IF_NULL(context);
303   // If node is not empty, it is the exit of funcgraph, no need to create device address.
304   if (node_ != nullptr) {
305     return;
306   }
307   if (input_device_tensors_.size() != is_need_copy_device_tensors_.size() ||
308       input_device_tensors_.size() != is_dynamic_shapes_.size() ||
309       input_device_tensors_.size() != device_contexts_.size() ||
310       input_device_tensors_.size() != is_need_dynamic_checks_.size()) {
311     std::string error_info = "Invalid input device tensor size:" + std::to_string(input_device_tensors_.size()) +
312                              " need tensor size:" + std::to_string(is_need_copy_device_tensors_.size()) +
313                              " need dynamic shape size:" + std::to_string(is_dynamic_shapes_.size()) +
314                              " need context size:" + std::to_string(device_contexts_.size()) +
315                              " for actor:" + GetAID().Name();
316     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
317   }
318 
319   std::vector<DeviceTensor *> new_device_tensors;
320   mindspore::HashMap<DeviceTensor *, DeviceTensor *> device_tensor_map;
321   for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
322     auto &input_device_tensor = input_device_tensors_[i];
323     if (!IsNeedCopyDeviceAddress(input_device_tensor, i)) {
324       (void)new_device_tensors.emplace_back(input_device_tensor);
325       continue;
326     }
327 
328     auto iter = device_tensor_map.find(input_device_tensor);
329     if (iter != device_tensor_map.end()) {
330       (void)new_device_tensors.emplace_back(iter->second);
331       continue;
332     }
333 
334     // Update the real used device context by the input data.
335     auto &device_context = device_contexts_[i];
336     MS_EXCEPTION_IF_NULL(device_context);
337     MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
338     if (device_context->GetDeviceType() != input_device_tensor->GetDeviceType()) {
339       device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
340         {input_device_tensor->device_name(), input_device_tensor->device_id()});
341       MS_LOG(INFO) << "Update device context type to:" << device_context->GetDeviceType();
342     }
343 
344     const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
345     MS_EXCEPTION_IF_NULL(node_with_index.first);
346     // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
347     const auto &kernel_tensor = input_device_tensor->kernel_tensor();
348     MS_EXCEPTION_IF_NULL(kernel_tensor);
349     auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
350     MS_EXCEPTION_IF_NULL(new_kernel_tensor);
351     new_kernel_tensor->set_device_ptr(nullptr);
352     DeviceTensorPtr new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
353     MS_EXCEPTION_IF_NULL(new_device_tensor);
354     MS_LOG(DEBUG) << "Actor:" << GetAID() << " create new device tensor:" << new_device_tensor
355                   << " type:" << new_device_tensor->type_id() << " by input device tensor:" << input_device_tensor
356                   << " shape:"
357                   << (kernel_tensor->GetShape() == nullptr ? "null" : kernel_tensor->GetShape()->ToString())
358                   << (kernel_tensor->GetType() == nullptr ? "null" : kernel_tensor->GetType()->ToString());
359     const auto &swap_manager = device_context->device_res_manager_->swap_manager();
360     if (swap_manager != nullptr) {
361       swap_manager->AddSwappableTensor(new_device_tensor);
362     }
363     (void)created_device_tensors_.emplace_back(new_device_tensor);
364     (void)new_device_tensors.emplace_back(new_device_tensor.get());
365     device_tensor_map[input_device_tensor] = new_device_tensor.get();
366     new_device_tensor->set_need_sync_user_data(input_device_tensor->need_sync_user_data());
367     new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
368     new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
369     // The device address which is created by actor uses the dynamic ref count.
370     new_device_tensor->set_dynamic_ref_count(0);
371     new_device_tensor->set_original_ref_count(SIZE_MAX);
372     new_device_tensor->ResetRefCount();
373 
374     // If the address ptr can't be changed, then alloc the new device memory and copy the data.
375     if (input_device_tensor->is_ptr_persisted()) {
376       device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "CopyDeviceAddress", "");
377       device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(), device::tracker::MemType::kOther,
378                                                      new_device_tensor->GetSize(), new_device_tensor.get());
379       device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);
380       if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
381         SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
382                                                     GetAID().Name(), new_device_tensor->GetSize());
383       }
384       if (!new_device_tensor->SyncDeviceToDevice(input_device_tensor)) {
385         SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
386       }
387     } else {
388       // Move the device ptr from input_device_tensor to new_device_tensor.
389       input_device_tensor->Swap(new_device_tensor.get());
390       if (new_device_tensor->deleter() == nullptr) {
391         new_device_tensor->set_from_mem_pool(true);
392       }
393     }
394     MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
395                   << ", ptr:" << new_device_tensor->GetPtr()
396                   << ", from node:" << node_with_index.first->fullname_with_scope()
397                   << " with index:" << node_with_index.second;
398   }
399   input_device_tensors_.swap(new_device_tensors);
400   UpdateDeviceOutputData();
401 }
402 }  // namespace runtime
403 }  // namespace mindspore
404