/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/control_flow/control_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "ops/framework_ops.h"
#include "utils/profile.h"

namespace mindspore {
namespace runtime {
ControlActor::ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
                           const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
    : MemoryAwareActor(name, type, nullptr, memory_manager_aid), formal_parameters_(parameters), node_(node) {
  input_partials_.resize(parameters.size());
  input_device_tensors_.resize(parameters.size());
  backend_parameters_.resize(parameters.size());
  output_data_by_output_index_.resize(parameters.size());
}

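// Init links each output data arrow to the per-output-index cache (output_data_by_output_index_), so that FetchInput
// can refresh all outgoing OpData pointers of one formal parameter in a single pass.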
void ControlActor::Init() {
  InitOutputData();
  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
  }

  size_t output_data_index = 0;
  for (auto &data_arrow : output_data_arrows_) {
    auto data = output_data_[output_data_index].first.get();
    MS_EXCEPTION_IF_NULL(data);
    MS_EXCEPTION_IF_NULL(data_arrow);
    if (IntToSize(data_arrow->from_output_index_) >= output_data_by_output_index_.size()) {
      MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID();
    }
    (void)output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(data);
    ++output_data_index;
  }
}

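// The two GetAllDeviceTensors overloads collect the device tensors held by a partial (or by a real parameter with
// branch id) and recurse into the nested partials, so ref-count maintenance and memory free requests cover the whole
// nested structure.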
void ControlActor::GetAllDeviceTensors(const OpPartialPtr &op_partial, std::vector<DeviceTensor *> *device_tensors) {
  MS_EXCEPTION_IF_NULL(op_partial);
  (void)std::transform(op_partial->device_tensors_.begin(), op_partial->device_tensors_.end(),
                       std::back_inserter(*device_tensors),
                       [](const auto &device_tensor) { return device_tensor.second; });

  // Recursively fetch the device tensors of the nested partials.
  for (auto &partial : op_partial->partials_) {
    GetAllDeviceTensors(partial.second, device_tensors);
  }
}

void ControlActor::GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter,
                                       std::vector<DeviceTensor *> *device_tensors) {
  MS_EXCEPTION_IF_NULL(device_tensors);
  for (auto &device_tensor : op_real_parameter.device_tensors_) {
    (void)device_tensors->emplace_back(device_tensor.second);
  }

  // Recursively fetch the device tensors of the nested partials.
  for (auto &partial : op_real_parameter.partials_) {
    GetAllDeviceTensors(partial.second, device_tensors);
  }
}

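// The IncreaseDynamicRefCount overloads bump the dynamic ref count of every device tensor carried by an output,
// whether it arrives as plain op data, as a partial, or as a real parameter with a branch id.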
void ControlActor::IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) const {
  MS_EXCEPTION_IF_NULL(op_data);
  MS_EXCEPTION_IF_NULL(op_data->data_);
  op_data->data_->IncreaseDynamicRefCount(GetAID().Name());
}

void ControlActor::IncreaseDynamicRefCount(const OpPartialPtr &op_partial) {
  if (op_partial == nullptr) {
    MS_LOG(EXCEPTION) << "Empty op partial for actor:" << GetAID();
  }
  std::vector<DeviceTensor *> partial_device_tensors;
  GetAllDeviceTensors(op_partial, &partial_device_tensors);
  for (auto &partial_device_tensor : partial_device_tensors) {
    MS_EXCEPTION_IF_NULL(partial_device_tensor);
    partial_device_tensor->IncreaseDynamicRefCount(GetAID().Name());
  }
}

void ControlActor::IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter) {
  std::vector<DeviceTensor *> partial_device_tensors;
  GetAllDeviceTensors(op_real_parameter, &partial_device_tensors);
  for (auto &partial_device_tensor : partial_device_tensors) {
    MS_EXCEPTION_IF_NULL(partial_device_tensor);
    partial_device_tensor->IncreaseDynamicRefCount(GetAID().Name());
  }
}

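// FetchNodePosition maps a node (with output index) to its position in formal_parameters_. If the node is not found
// directly, it also matches a Load CNode whose first input is the node, since a formal parameter may be wrapped by
// kPrimLoad.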
size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
  const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
  if (iter == formal_parameters_.end()) {
    const auto &load_iter =
      std::find_if(formal_parameters_.begin(), formal_parameters_.end(), [&node](const KernelWithIndex &pair) {
        return pair.first != nullptr && common::AnfAlgo::CheckPrimitiveType(pair.first, prim::kPrimLoad) &&
               pair.first->cast<CNodePtr>()->input(1) == node.first && node.second == 0;
      });
    if (load_iter != formal_parameters_.end()) {
      return load_iter - formal_parameters_.begin();
    }
    for (const auto &formal_parameter : formal_parameters_) {
      MS_LOG(WARNING) << "Actor:" << GetAID() << " formal parameter:"
                      << (formal_parameter.first != nullptr ? formal_parameter.first->DebugString() : "")
                      << " index:" << formal_parameter.second << " node ptr:" << formal_parameter.first;
    }
    MS_LOG_WITH_NODE(EXCEPTION, node.first)
      << "Invalid formal parameter:" << (node.first != nullptr ? node.first->DebugString() : "")
      << " node ptr:" << node.first << " index:" << node.second << " for actor:" << GetAID();
  }
  return iter - formal_parameters_.begin();
}

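// Run executes one step of the control actor: fetch the inputs, increase the dynamic ref counts of the outputs,
// request memory free for the inputs, erase the per-step input caches, and finally send the outputs downstream.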
void ControlActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    // The exit actor is the output of the kernel graph when node_ is null.
    if (type_ == KernelTransformType::kExitActor && node_ == nullptr) {
      double end_time = GetTime();
      const size_t kSecondsToMilliseconds = 1000;
      MS_LOG(DEBUG) << "Kernel graph group exit actor:" << GetAID()
                    << " cost time:" << (end_time - start_time_) * kSecondsToMilliseconds;
    }

    FetchInput(context);
    if (IsRunningFailed(context)) {
      MS_LOG(INFO) << "Run failed and early stop.";
      return;
    }

    // Note that IncreaseDynamicRefCounts must run before SendMemoryFreeReq, because SendMemoryFreeReq decreases the
    // dynamic ref count. This avoids the illegal timing where the dynamic ref count is decremented and then
    // incremented.
    IncreaseDynamicRefCounts(context);
    SendMemoryFreeReq(context);

    EraseInput(context);
    SendOutput(context);
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Actor fun failed:" + GetAID().Name();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}

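// RunOpPartial and RunBranchID are the entry points for partial and branch-id inputs: they cache the input under the
// current sequential number and trigger Run once all running conditions are satisfied.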
void ControlActor::RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)input_op_partials_[sequential_num].emplace_back(position, partial);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op partial and check running condition:" << is_run;
  if (is_run) {
    Run(context);
  }
}

void ControlActor::RunBranchID(int branch_id, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  input_branch_ids_[sequential_num].push(branch_id);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input branch id and check running condition:" << is_run;
  if (is_run) {
    Run(context);
  }
}

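// CheckRunningCondition extends the base-class check with the partial inputs: the actor is runnable only when exactly
// input_partials_num_ partials have arrived for the current sequential number.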
bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  if (!AbstractActor::CheckRunningCondition(context)) {
    return false;
  }

  if (input_partials_num_ != 0) {
    const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
    if (partial_iter == input_op_partials_.end()) {
      return false;
    }
    if (partial_iter->second.size() < input_partials_num_) {
      return false;
    } else if (partial_iter->second.size() > input_partials_num_) {
      MS_LOG(ERROR) << "Invalid input partial num:" << partial_iter->second.size() << " need:" << input_partials_num_
                    << " for actor:" << GetAID();
      return false;
    }
  }
  return true;
}

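// FetchInput gathers the device tensors for this step from three sources, in order: the op data sent by upstream
// actors, the local device tensors, and the device tensor store. It then refreshes the cached output data pointers,
// fills the input partials from the received and local partials, and fetches the output branch id from the stack.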
void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);

  // Fetch the input device tensors from the input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "Invalid index, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  }

  // Fetch the input device tensors from the local device tensors.
  for (auto &local_device_tensor : local_device_tensors_) {
    MS_EXCEPTION_IF_NULL(local_device_tensor.second);
    if (local_device_tensor.first >= input_device_tensors_.size()) {
      std::string error_info = "Invalid local index:" + std::to_string(local_device_tensor.first) +
                               " current:" + std::to_string(input_device_tensors_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[local_device_tensor.first] = local_device_tensor.second;
  }

  // Fetch the input device tensors from the device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
    if (device_tensors.empty()) {
      auto &device_context = device_contexts_[device_tensor_store_key.first];
      MS_EXCEPTION_IF_NULL(device_context);
      MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
      std::string error_info = GetAID().Name() +
                               " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
                               ", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info =
        "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
        " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(device_tensors[0]);
    input_device_tensors_[device_tensor_store_key.first] = device_tensors[0].get();
  }

  // Refresh the data pointers of the cached output data.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i];
    MS_EXCEPTION_IF_NULL(data);
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }

  // Fetch the input partials from the input data.
  const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
  if (partial_iter != input_op_partials_.end()) {
    for (const auto &input_partial : partial_iter->second) {
      if (input_partial.first >= input_partials_.size()) {
        std::string error_info = "Invalid partial index:" + std::to_string(input_partial.first) +
                                 " vector size:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_partials_[input_partial.first] = input_partial.second;
    }
  }
  // Fetch the input partials from the local partials.
  for (const auto &local_partial : local_partials_) {
    if (local_partial.first >= input_partials_.size()) {
      std::string error_info = "Invalid partial index:" + std::to_string(local_partial.first) +
                               " vector size:" + std::to_string(input_partials_.size()) +
                               " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(local_partial.second);
    input_partials_[local_partial.first] = local_partial.second;
  }
  // Fetch the branch id from the stack.
  auto iter = input_branch_ids_.find(context->sequential_num_);
  if (iter != input_branch_ids_.end() && (!iter->second.empty())) {
    output_branch_id_ = iter->second.top();
  }
}

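// IncreaseDynamicRefCounts raises the dynamic ref count once for every outgoing data arrow and once for every device
// tensor reachable from an outgoing partial, so the underlying memory is not freed before the receivers use it.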
void ControlActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  // Increase dynamic ref count by the output data.
  for (size_t i = 0; i < output_data_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_data_[i].first);
    if (output_data_[i].first->data_ == nullptr) {
      std::string error_info = GetAID().Name() + " fetches data null, data index:" + std::to_string(i) +
                               " to actor:" + output_data_[i].first->op_id_.Name() +
                               " index:" + std::to_string(output_data_[i].first->index_);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    IncreaseDynamicRefCount(output_data_[i].first.get());
  }

  // Increase dynamic ref count by the output partial.
  for (const auto &output_partial_arrow : output_partial_arrows_) {
    MS_EXCEPTION_IF_NULL(output_partial_arrow);
    if (IntToSize(output_partial_arrow->from_output_index_) >= input_partials_.size()) {
      std::string error_info = "Invalid partial input:" + std::to_string(output_partial_arrow->from_output_index_) +
                               " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    auto output_partial = input_partials_[IntToSize(output_partial_arrow->from_output_index_)];
    IncreaseDynamicRefCount(output_partial);
  }
}

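// SendMemoryFreeReq collects every input device tensor (including those nested in the input partials) and asks the
// memory manager actor to free them, synchronously or asynchronously depending on ActorDispatcher::is_memory_free_sync().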
void ControlActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Collect the input device tensors.
  std::vector<DeviceTensor *> memory_free_list;
  if (input_op_datas_.count(sequential_num) > 0) {
    for (auto &input_op_data : input_op_datas_[sequential_num]) {
      MS_EXCEPTION_IF_NULL(input_op_data);
      MS_EXCEPTION_IF_NULL(input_op_data->data_);
      (void)memory_free_list.emplace_back(input_op_data->data_);
    }
  }

  if (input_op_partials_.count(sequential_num) > 0) {
    for (auto &input_op_partial : input_op_partials_[sequential_num]) {
      GetAllDeviceTensors(input_op_partial.second, &memory_free_list);
    }
  }

  if (memory_free_list.size() > 0) {
    memory_free_lists_.push(memory_free_list);
    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }
}

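// EraseInput clears the per-sequential-number caches after a step: the base-class inputs, the received partials, and
// the top of the branch-id stack (the stack entry is removed once it becomes empty).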
void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;
  AbstractActor::EraseInput(context);
  if (input_partials_num_ != 0) {
    auto ret = input_op_partials_.erase(sequential_num);
    if (ret == 0) {
      std::string error_info = "Erase input partial failed: " + GetAID().Name();
      // The sequential num may be invalid, can't set the promise value of context.
      MS_LOG(ERROR) << error_info << ", sequential_num: " << sequential_num;
    }
  }

  if (input_branch_ids_.find(sequential_num) != input_branch_ids_.end()) {
    input_branch_ids_[sequential_num].pop();
    if (input_branch_ids_[sequential_num].empty()) {
      auto ret = input_branch_ids_.erase(sequential_num);
      if (ret == 0) {
        MS_LOG(ERROR) << "Erase input branch id failed: " << GetAID() << ", sequential_num: " << sequential_num;
        return;
      }
    }
  }
}

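// UpdateOutputData keeps the device tensors of ref-node formal parameters in sync with the real parameter before the
// kernel runs: when format and device type match, it only rewrites the pointer; otherwise it allocates memory for the
// formal parameter, copies the data across devices/formats, and records the pair in DeviceTensorCopyStore.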
void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                                    const AnfNodePtr &, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_data);
  MS_EXCEPTION_IF_NULL(data_arrow);
  auto formal_parameter_position = data_arrow->from_output_index_;
  // There is no ref node formal parameter at this position.
  if (ref_node_formal_parameter_device_tensors_.count(formal_parameter_position) == 0) {
    return;
  }

  MS_EXCEPTION_IF_NULL(context);
  const auto &data = output_data->data_;
  MS_EXCEPTION_IF_NULL(data);
  if ((!data->IsPtrValid()) || (data->ref_count() != SIZE_MAX) || (data->dynamic_ref_count() != INT32_MAX)) {
    std::string error_info = "The address of the " + std::to_string(formal_parameter_position) +
                             " position real parameter is nullptr or the ref count is wrong.";
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // Iterate over the device tensors to set the ptr from data. Only the formal parameter device tensors of ref nodes
  // need to be set before kernel running, because they will be used by the ref output node.
  for (auto &device_tensor : ref_node_formal_parameter_device_tensors_[formal_parameter_position]) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
      continue;
    }
    auto formal_parameter = device_tensor->GetNodeIndex();
    MS_EXCEPTION_IF_NULL(formal_parameter.first);
    if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->type_id() != data->type_id())) {
      MS_LOG(WARNING) << "The formal parameter: " << formal_parameter.first->DebugString()
                      << " position:" << formal_parameter_position
                      << ", please check the size and type id, formal parameter size:" << device_tensor->GetSize()
                      << " type id:" << device_tensor->type_id() << ", real parameter size:" << data->GetSize()
                      << " type id:" << data->type_id();
    }

    // Copy from the real parameter to the formal parameter and insert into the device tensor copy store.
    if ((!AnfAlgo::IsEquivalentFormat(device_tensor->format(), data->format())) ||
        (device_tensor->GetDeviceType() != data->GetDeviceType())) {
      MS_LOG(INFO) << GetAID().Name() << " the input position:" << formal_parameter_position
                   << " copy from real parameter address:" << data << ", type:" << data->GetDeviceType()
                   << ", format:" << data->format() << " to formal parameter address:" << device_tensor.get()
                   << ", type:" << device_tensor->GetDeviceType() << ", format:" << device_tensor->format()
                   << ", formal parameter name:" << formal_parameter.first->DebugString();
      const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device_tensor->device_name(), device_tensor->device_id()});
      MS_EXCEPTION_IF_NULL(device_context);
      device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther, 0);
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "UpdateOutputData", "");
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(), device::tracker::MemType::kOther,
                                                     device_tensor->GetSize(), device_tensor.get());
      auto data_stream_id = data->stream_id();
      auto device_tensor_stream_id = device_tensor->stream_id();
      if (device_tensor_stream_id != data_stream_id) {
        MS_LOG(INFO) << "Rewrite device tensor stream id from: " << device_tensor_stream_id
                     << " to data stream id: " << data_stream_id << ".";
        device_tensor->set_stream_id(data_stream_id);
      }
      if ((device_tensor->GetPtr() == nullptr) &&
          (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex))) {
        SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                    formal_parameter.first->DebugString(), device_tensor->GetSize());
      }
      if (!Copy(device_tensor.get(), data)) {
        std::string error_info = "The formal parameter: " + formal_parameter.first->DebugString() +
                                 " position:" + std::to_string(formal_parameter_position) + " copy failed.";
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
      output_data->data_ = device_tensor.get();
      continue;
    }

    // A ref node may use the ptr of the device tensor as its output address, so the ptr needs to be set from data.
    device_tensor->set_ptr(data->GetMutablePtr());
    MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
                  << " for the ref formal parameter: " << formal_parameter.first->DebugString()
                  << " in the actor: " << GetAID().Name();
  }
}

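// SendOutput forwards the branch id, then the data arrows (via the base class), and finally the partials; it also
// refreshes the start time of the matched end actors, presumably so the exit actor's cost-time log in Run covers the
// whole kernel graph group.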
void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
  // Send branch id.
  for (const auto &branch_id_arrow : output_branch_id_arrows_) {
    ActorDispatcher::Send(branch_id_arrow, &ControlActor::RunBranchID, output_branch_id_, context);
  }

  // Send data in base class.
  AbstractActor::SendOutput(context);

  // Send Partial.
  for (const auto &partial_arrow : output_partial_arrows_) {
    MS_EXCEPTION_IF_NULL(partial_arrow);
    if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
      std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) +
                               " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    auto output_partial = input_partials_[IntToSize(partial_arrow->from_output_index_)];
    MS_EXCEPTION_IF_NULL(output_partial);
    ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
                          IntToSize(partial_arrow->to_input_index_), context);
  }

  // Update the start time in end actor.
  for (const auto &actor : end_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    actor->set_start_time(GetTime());
  }
}

namespace {
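// CreateRealMakeTuple builds a RealMakeTuple CNode whose abstract is a dynamic-length tuple of tensor abstracts that
// mirror the merged device addresses; it is only used by MergeDeviceAddress below to give the merged address a node.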
CNodePtr CreateRealMakeTuple(const std::vector<DeviceTensor *> &addr_list, const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimRealMakeTuple)};
  auto new_cnode = func_graph->NewCNode(inputs);
  std::vector<std::string> formats;
  MS_EXCEPTION_IF_NULL(new_cnode);
  std::vector<abstract::AbstractBasePtr> abs_list;
  for (const auto &addr : addr_list) {
    MS_EXCEPTION_IF_NULL(addr);
    auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(addr->type_id()), addr->host_shape());
    abs_list.emplace_back(abs);
    formats.emplace_back(addr->format());
    MS_LOG(DEBUG) << "Create new abstract:" << abs->ToString();
  }
  auto tuple_abs = std::make_shared<abstract::AbstractTuple>(abs_list);
  MS_LOG(DEBUG) << "Create abstract for real make tuple:" << tuple_abs->ToString();
  // Set dynamic len element abstract to check the abstract is dynamic len.
  abstract::AbstractBasePtr element_abs = (abs_list.empty() ? std::make_shared<abstract::AbstractTensor>(
                                                                TypeIdToType(TypeId::kNumberTypeInt64), ShapeVector())
                                                            : abs_list[0]);
  tuple_abs->set_dynamic_len_element_abs(element_abs);
  new_cnode->set_abstract(tuple_abs);

  // Create kernel info for node and set format for it.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  kernel_info->set_select_kernel_build_info(builder->Build());
  new_cnode->set_kernel_info(kernel_info);
  builder->SetOutputsFormat(formats);
  return new_cnode;
}

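// CheckDeviceAddressConsist verifies that all addresses to be merged share the same size and type id (a hard error),
// and only warns when their host shapes differ.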
void CheckDeviceAddressConsist(OpContext<DeviceTensor> *const context, const std::vector<DeviceTensor *> &addr_list,
                               const std::string &actor_name) {
  MS_EXCEPTION_IF_NULL(context);
  if (addr_list.empty() || addr_list[0] == nullptr) {
    return;
  }
  // Check the consistency of the device addresses.
  const auto &shape = addr_list[0]->host_shape();
  const auto &size = addr_list[0]->GetSize();
  const auto &type = addr_list[0]->type_id();
  const auto &device_name = addr_list[0]->device_name();
  for (size_t i = 1; i < addr_list.size(); ++i) {
    MS_EXCEPTION_IF_NULL(addr_list[i]);
    if (size != addr_list[i]->GetSize() || type != addr_list[i]->type_id()) {
      MS_LOG(ERROR) << "Failed to merge two device address, addr1:" << addr_list[0] << " size:" << size
                    << " shape:" << shape << " device name:" << device_name << " type:" << type
                    << " addr2:" << addr_list[i] << " size:" << addr_list[i]->GetSize()
                    << " shape:" << addr_list[i]->host_shape() << " device name:" << addr_list[i]->device_name()
                    << " type:" << addr_list[i]->type_id() << " for actor:" << actor_name;
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Failed to merge two device address");
    }
    if (shape != addr_list[i]->host_shape()) {
      MS_LOG(WARNING) << "Merge two device address with different shape, addr1 shape:" << shape
                      << " addr2 shape:" << addr_list[i]->host_shape() << " for actor:" << actor_name;
    }
  }
}
}  // namespace

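// MergeDeviceAddress packs a list of same-sized device addresses into one contiguous device address that represents a
// dynamic-length tuple: it allocates a buffer of addr_list[0]->GetSize() * addr_list.size() bytes, copies each element
// into its slot through a temporary device tensor whose ptr is advanced by one element size per iteration, and binds
// the result to a newly created RealMakeTuple node.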
void ControlActor::MergeDeviceAddress(OpContext<DeviceTensor> *const context,
                                      const std::vector<DeviceTensor *> &addr_list, DeviceTensor **device_tensor) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(device_tensor);
  if (addr_list.empty()) {
    MergeEmptyAddressDeviceAddress(context, addr_list, device_tensor);
    return;
  }

  CheckDeviceAddressConsist(context, addr_list, GetAID().Name());
  MS_EXCEPTION_IF_NULL(addr_list[0]);
  const auto &total_size = addr_list[0]->GetSize() * addr_list.size();
  ShapeVector total_shape = {SizeToLong(addr_list.size())};
  const auto &shape = addr_list[0]->host_shape();
  total_shape.insert(total_shape.end(), shape.begin(), shape.end());
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {addr_list[0]->device_name(), addr_list[0]->device_id()});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  abstract::BaseShapePtrList shape_list(addr_list.size(), addr_list[0]->kernel_tensor()->GetShape());
  auto tuple_shape = std::make_shared<abstract::TupleShape>(shape_list);
  TypePtrList type_list(addr_list.size(), addr_list[0]->kernel_tensor()->GetType());
  auto tuple_type = std::make_shared<Tuple>(type_list);
  MS_LOG(DEBUG) << "Create kernel tensor by shape:" << tuple_shape->ToString() << " type:" << tuple_type->ToString()
                << " in device address:" << addr_list[0];
  const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
    tuple_shape, tuple_type, nullptr, nullptr, total_size, addr_list[0]->format(), addr_list[0]->type_id(),
    total_shape, device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(addr_list[0]->stream_id());
  const auto &new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_tensor);

  MS_LOG(DEBUG) << "Create device tensor:" << new_device_tensor << " type:" << new_device_tensor->type_id();
  if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                GetAID().Name(), new_device_tensor->GetSize());
  }
  MS_EXCEPTION_IF_NULL(new_device_tensor->GetMutablePtr());

  // Create a new RealMakeTuple node for the new device address.
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto new_cnode = CreateRealMakeTuple(addr_list, fg);
  AnfAlgo::SetOutputAddr(new_device_tensor, 0, new_cnode.get());
  created_new_graphs_.emplace_back(fg);
  created_new_nodes_.emplace_back(new_cnode);
  new_device_tensor->SetNodeIndex(new_cnode, 0);
  new_device_tensor->set_from_persistent_mem(addr_list[0]->from_persistent_mem());
  new_device_tensor->set_dynamic_ref_count(0);
  new_device_tensor->set_original_ref_count(SIZE_MAX);
  new_device_tensor->ResetRefCount();

  // Merge the device address list into the single device address.
  auto tmp_kernel_tensor = std::make_shared<kernel::KernelTensor>(
    new_device_tensor->GetMutablePtr(), addr_list[0]->GetSize(),
    kernel::GetFormatFromStrToEnum(addr_list[0]->format()), addr_list[0]->type_id(), shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  tmp_kernel_tensor->set_stream_id(addr_list[0]->stream_id());
  const auto &tmp_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(tmp_kernel_tensor);
  MS_EXCEPTION_IF_NULL(tmp_device_tensor);
  MS_LOG(DEBUG) << "Create device tensor:" << tmp_device_tensor << " type:" << tmp_device_tensor->type_id();
  std::shared_ptr<int64_t> max_task_id_on_stream = nullptr;
  for (size_t i = 0; i < addr_list.size(); ++i) {
    auto device_tensor_addr = addr_list[i];
    auto task_id_on_stream = device_tensor_addr->kernel_tensor()->task_id_on_stream();
    if (task_id_on_stream != nullptr) {
      if (max_task_id_on_stream == nullptr) {
        max_task_id_on_stream = task_id_on_stream;
      } else {
        if (*max_task_id_on_stream < *task_id_on_stream) {
          max_task_id_on_stream = task_id_on_stream;
        }
      }
    }
    bool ret = false;
    if (addr_list[i]->device_name() == addr_list[0]->device_name()) {
      ret = tmp_device_tensor->SyncDeviceToDevice(addr_list[i]);
    } else if (addr_list[0]->device_name() == kCPUDevice) {
      ret = addr_list[i]->SyncDeviceToHost(addr_list[i]->GetSize(), tmp_device_tensor->GetMutablePtr());
    } else if (addr_list[i]->device_name() == kCPUDevice) {
      ret = tmp_device_tensor->SyncHostToDevice(addr_list[i]->GetSize(), addr_list[i]->GetMutablePtr());
    } else {
      MS_LOG(ERROR) << "Invalid device name for addr1:" << addr_list[0] << " name:" << addr_list[0]->device_name()
                    << " and addr2:" << addr_list[i] << " name:" << addr_list[i]->device_name();
    }
    if (!ret) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
    }
    // Advance the write pointer to the next element slot.
    tmp_device_tensor->set_ptr((reinterpret_cast<char *>(tmp_device_tensor->GetMutablePtr())) +
                               addr_list[0]->GetSize());
  }
  new_device_tensor->kernel_tensor()->set_task_id_on_stream(max_task_id_on_stream);
  tmp_device_tensor->set_ptr(nullptr);
  created_device_tensors_.emplace_back(new_device_tensor);
  MS_LOG(DEBUG) << "actor:" << GetAID() << " create new device address:" << new_device_tensor
                << " for addr list size:" << addr_list.size()
                << " device address shape:" << new_device_tensor->host_shape();
  (*device_tensor) = new_device_tensor.get();
}

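// MergeEmptyAddressDeviceAddress handles the empty address list: it creates a zero-size tuple device address on the
// default device target, allocates memory for it, and registers it in created_device_tensors_, so downstream users
// still receive a valid device tensor.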
void ControlActor::MergeEmptyAddressDeviceAddress(OpContext<DeviceTensor> *const context,
                                                  const std::vector<DeviceTensor *> &addr_list,
                                                  DeviceTensor **device_tensor) {
  // Create device address for empty tuple.
  // Fetch the default device context for empty sequence.
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET), context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto tuple_shape = std::make_shared<abstract::TupleShape>();
  auto tuple_type = std::make_shared<Tuple>();
  const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
    tuple_shape, tuple_type, nullptr, nullptr, 0, kOpFormat_DEFAULT, TypeId::kNumberTypeInt64, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  const auto &new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_tensor);
  new_device_tensor->set_dynamic_ref_count(0);
  new_device_tensor->set_original_ref_count(SIZE_MAX);
  new_device_tensor->ResetRefCount();
  if (!device_context->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
                                                GetAID().Name(), new_device_tensor->GetSize());
  }
  created_device_tensors_.emplace_back(new_device_tensor);
  (*device_tensor) = new_device_tensor.get();
  MS_LOG(DEBUG) << "actor:" << GetAID() << " create new device address:" << new_device_tensor
                << " for empty addr list";
}
}  // namespace runtime
}  // namespace mindspore