/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/recorder_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "common/trans.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
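// Overview of the control flow defined in this file: FetchData() pops the buffer of the previous step, calls
// FillDataBuffer() to stage device tensors, and asks the memory manager actor to allocate memory for them. Once
// allocation finishes (presumably via a callback to OnMemoryAllocFinish()), the actor fills the device tensors
// with the fetched data, releases the memory of the consumed buffer via SendMemoryFreeReq(), and forwards results,
// data and control messages to downstream actors through SendOutput().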
void DataSourceActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() < device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  // Init output data.
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    (void)output_data_.emplace_back(std::move(data));
  }
}

void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) {
  MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data.";
  MS_EXCEPTION_IF_NULL(context);
  // Pop the data of the previous step.
  if (!buffers_.empty()) {
    buffers_.pop();
  }

  // Construct device tensors from the member nodes and fill them into the buffers.
  FillDataBuffer();
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Allocate memory for the device tensors.
  SendMemoryAllocReq(context);
}

void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // No output.
  if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
      (output_result_arrows_.size() == 0)) {
    SET_OPCONTEXT_SUCCESS_RET((*context));
  }

  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // The execution order must be: send result --> send data --> send control, to avoid illegal timing problems.
  // 1.Send graph output result.
  SendResult(context);

  // 2.Send output data.
  const auto &output_device_tensors = buffers_.front();
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    auto &data_arrow = output_data_arrows_[i];
    auto &output_data = output_data_[i];
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(output_data);
    if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is out of range.");
    }
    output_data->data_ = output_device_tensors[data_arrow->from_output_index_];
    Async(data_arrow->to_op_id_, &OpActor::RunOpData, output_data.get(), context);
  }

  // 3.Send output control.
  if (output_control_arrows_.size() > 0) {
    auto source_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : output_control_arrows_) {
      Async(output_control, &OpActor::RunOpControl, source_aid, context);
    }
  }

  // 4.Send recorder info.
  if (recorder_aid_ != nullptr) {
    SendRecorderInfo(context);
  }
}

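// DeviceQueueDataSourceActor fetches data by launching a data kernel that reads from the device-side data queue:
// its buffers are the kernel's output addresses, and the actual copy happens in OnMemoryAllocFinish() below.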
void DeviceQueueDataSourceActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  // Init output data.
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    (void)output_data_.emplace_back(std::move(data));
  }

  // Init kernel launch info.
  MS_EXCEPTION_IF_NULL(kernel_info_);
  for (size_t i = 0; i < kernel_info_->output_address_list().size(); ++i) {
    (void)launch_info_.outputs_.emplace_back(std::make_shared<Address>());
  }
}

void DeviceQueueDataSourceActor::FillDataBuffer() {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  // Construct device tensors.
  std::vector<DeviceTensor *> device_tensors;
  for (auto &device_tensor : kernel_info_->output_address_list()) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    (void)device_tensors.emplace_back(device_tensor.get());
  }

  buffers_.push(device_tensors);
}

void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.back();
  Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
        GetAID());
}

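// Note: the allocation request above targets buffers_.back(), the buffer most recently pushed by FillDataBuffer(),
// while the free request below targets buffers_.front(), the oldest pending buffer; when only one buffer is in
// flight the two refer to the same entry.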
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.front();
  Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
}

void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(data_kernel_);
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Construct outputs of data kernel launching.
  auto &device_tensors = buffers_.back();
  if (launch_info_.outputs_.size() != device_tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
                                      "The number of outputs is not equal to the number of device tensors.");
  }
  for (size_t i = 0; i < device_tensors.size(); ++i) {
    MS_EXCEPTION_IF_NULL(launch_info_.outputs_[i]);
    MS_EXCEPTION_IF_NULL(device_tensors[i]);
    launch_info_.outputs_[i]->addr = device_tensors[i]->GetMutablePtr();
    launch_info_.outputs_[i]->size = device_tensors[i]->GetSize();
  }

  // Copy data from the device queue by launching the data kernel.
  try {
    auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
                                                 launch_info_.outputs_);
    if (!ret) {
      std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Launch kernel exception: " + data_kernel_->fullname_with_scope();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // The debug actor is blocking, so we must wait for its callback message before continuing.
  if (debug_aid_ != nullptr) {
    SendDebugReq(context);
    return;
  }

  // Note that SendMemoryFreeReq must come before SendOutput, because SendOutput triggers SendMemoryAllocReq of the
  // next actor and actors execute asynchronously. Ensuring that the current actor's SendMemoryFreeReq precedes the
  // next actor's SendMemoryAllocReq both reuses memory more fully and preserves the execution order, avoiding
  // illegal memory timing problems.
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
}

void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_kernel_, result_arrow->from_output_index_,
          result_arrow->to_input_index_, context);
  }
}

void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
  if (recorder_aid_ != nullptr) {
    MS_EXCEPTION_IF_NULL(data_kernel_);
    Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
          device_contexts_[0], context);
  }
}

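// HostQueueDataSourceActor feeds data from the host-side queue: its buffers are the device addresses of the host
// data nodes, and OnMemoryAllocFinish() copies (or syncs) each host tensor into the corresponding device tensor.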
void HostQueueDataSourceActor::FillDataBuffer() {
  // Construct device tensors.
  std::vector<DeviceTensor *> device_tensors;
  for (auto &data_node : data_nodes_) {
    auto device_address = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)device_tensors.emplace_back(device_address.get());
  }

  buffers_.push(device_tensors);
}

void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.back();
  if (IsSameDeviceType()) {
    Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
          GetAID());
  } else {
    Async(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors, &device_contexts_, context,
          GetAID());
  }
}

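// Note: when the data nodes span more than one device context (IsSameDeviceType() returns false), the batch
// variants AllocateBatchMemory/FreeBatchMemory are used by SendMemoryAllocReq above and SendMemoryFreeReq below,
// presumably so the memory manager can handle each tensor on its own device; otherwise a single-context request is
// sufficient.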
void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.front();
  if (IsSameDeviceType()) {
    Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
  } else {
    Async(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors, &device_contexts_, context);
  }
}

void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Get host tensors from the host queue and device tensors from the buffers.
  MS_EXCEPTION_IF_NULL(host_queue_);
  if (host_queue_->IsEmpty()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data queue is empty.");
  }
  auto &host_tensors = host_queue_->Pull();
  auto &device_tensors = buffers_.back();
  if (host_tensors.size() != device_tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
                                      "The number of host tensors is not equal to the number of device tensors.");
  }

  // Copy data from host tensor to device tensor.
  for (size_t i = 0; i < host_tensors.size(); ++i) {
    auto &host_tensor = host_tensors[i];
    auto &device_tensor = device_tensors[i];
    MS_EXCEPTION_IF_NULL(host_tensor);
    MS_EXCEPTION_IF_NULL(device_tensor);
    auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
    // Sync data from the host tensor's device address to device_tensor.
    if (tensor_device_address != nullptr) {
      if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
      }
      continue;
    }

    // Sync data from host_tensor to device_tensor.
    if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
                                         LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
                                         host_tensor->data_c(), host_tensor->device_info().host_format_)) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
    }
  }
  host_queue_->Pop();

  // Note that SendMemoryFreeReq must come before SendOutput, because SendOutput triggers SendMemoryAllocReq of the
  // next actor and actors execute asynchronously. Ensuring that the current actor's SendMemoryFreeReq precedes the
  // next actor's SendMemoryAllocReq both reuses memory more fully and preserves the execution order, avoiding
  // illegal memory timing problems.
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    if (IntToSize(result_arrow->from_output_index_) >= data_nodes_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is out of range.");
    }
    Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_nodes_[result_arrow->from_output_index_], 0,
          result_arrow->to_input_index_, context);
  }
}

size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node) const {
  MS_EXCEPTION_IF_NULL(data_node);
  const auto &iter = data_node_position_map_.find(data_node);
  if (iter == data_node_position_map_.end()) {
    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " does not exist.";
  }
  return iter->second;
}

AnfNodePtr HostQueueDataSourceActor::FetchNode(size_t node_position) const {
  if (node_position >= data_nodes_.size()) {
    MS_LOG(EXCEPTION) << "The position of node is out of range: " << node_position;
  }
  return data_nodes_[node_position];
}

bool HostQueueDataSourceActor::IsSameDeviceType() const {
  for (size_t i = 1; i < device_contexts_.size(); i++) {
    if (device_contexts_[i] != device_contexts_[0]) {
      return false;
    }
  }
  return true;
}
}  // namespace runtime
}  // namespace mindspore