/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/recorder_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "common/trans.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
void DataSourceActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() < device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  // Init output data.
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    (void)output_data_.emplace_back(std::move(data));
  }
}

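// FetchData is the message entry of a data source actor: it discards the buffer of the previous step, builds the
// device tensors of the current step through FillDataBuffer, and then asks the memory manager actor to allocate
// device memory for them. The actual data copy happens in OnMemoryAllocFinish once the allocation callback arrives.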
void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) {
  MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data.";
  MS_EXCEPTION_IF_NULL(context);
  // Pop the data of the previous step.
  if (!buffers_.empty()) {
    buffers_.pop();
  }

  // Construct device tensors and fill them into the buffers from the member nodes.
  FillDataBuffer();
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Allocate memory for the device tensors.
  SendMemoryAllocReq(context);
}

void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // No output.
  if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
      (output_result_arrows_.size() == 0)) {
    SET_OPCONTEXT_SUCCESS_RET((*context));
  }

  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // The execution order must be: send result --> send data --> send control, to avoid illegal timing problems.
  // 1.Send the graph output result.
  SendResult(context);

  // 2.Send the output data.
  const auto &output_device_tensors = buffers_.front();
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    auto &data_arrow = output_data_arrows_[i];
    auto &output_data = output_data_[i];
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(output_data);
    if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is out of range.");
    }
    output_data->data_ = output_device_tensors[data_arrow->from_output_index_];
    Async(data_arrow->to_op_id_, &OpActor::RunOpData, output_data.get(), context);
  }

  // 3.Send the output control.
  if (output_control_arrows_.size() > 0) {
    auto source_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : output_control_arrows_) {
      Async(output_control, &OpActor::RunOpControl, source_aid, context);
    }
  }

  // 4.Send the recorder info.
  if (recorder_aid_ != nullptr) {
    SendRecorderInfo(context);
  }
}

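// Besides the common output data, the device queue data source actor pre-creates one Address per output of its data
// kernel, so that OnMemoryAllocFinish only needs to fill in the allocated pointers before launching the kernel.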
void DeviceQueueDataSourceActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  // Init output data.
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    (void)output_data_.emplace_back(std::move(data));
  }

  // Init kernel launch info.
  MS_EXCEPTION_IF_NULL(kernel_info_);
  for (size_t i = 0; i < kernel_info_->output_address_list().size(); ++i) {
    (void)launch_info_.outputs_.emplace_back(std::make_shared<Address>());
  }
}

void DeviceQueueDataSourceActor::FillDataBuffer() {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  // Construct device tensors.
  std::vector<DeviceTensor *> device_tensors;
  for (auto &device_tensor : kernel_info_->output_address_list()) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    (void)device_tensors.emplace_back(device_tensor.get());
  }

  buffers_.push(device_tensors);
}

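// Allocation targets the newest buffer (the back of the queue), while free targets the oldest one (the front of the
// queue).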
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.back();
  Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
        GetAID());
}

void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.front();
  Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
}

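// After the device memory is ready, the data kernel is launched to copy one step of data from the device queue into
// the freshly allocated output addresses; the debug actor (if any) is notified before the memory is released and the
// outputs are sent downstream.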
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(data_kernel_);
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Construct the outputs of the data kernel launching.
  auto &device_tensors = buffers_.back();
  if (launch_info_.outputs_.size() != device_tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
                                      "The number of outputs is not equal to the number of device tensors.");
  }
  for (size_t i = 0; i < device_tensors.size(); ++i) {
    MS_EXCEPTION_IF_NULL(launch_info_.outputs_[i]);
    MS_EXCEPTION_IF_NULL(device_tensors[i]);
    launch_info_.outputs_[i]->addr = device_tensors[i]->GetMutablePtr();
    launch_info_.outputs_[i]->size = device_tensors[i]->GetSize();
  }

  // Copy data from the device queue by launching the data kernel.
  try {
    auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
                                                 launch_info_.outputs_);
    if (!ret) {
      std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Launch kernel exception: " + data_kernel_->fullname_with_scope();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // The debug actor is blocking, so wait for its callback message before continuing.
  if (debug_aid_ != nullptr) {
    SendDebugReq(context);
    return;
  }

  // Note that SendMemoryFreeReq must be called before SendOutput, because SendOutput triggers SendMemoryAllocReq of
  // the next actor and actors execute asynchronously. Keeping SendMemoryFreeReq of the current actor ahead of
  // SendMemoryAllocReq of the next actor both allows the memory to be reused more fully and preserves the execution
  // order, avoiding illegal memory timing problems.
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
}

void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_kernel_, result_arrow->from_output_index_,
          result_arrow->to_input_index_, context);
  }
}

void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
  if (recorder_aid_ != nullptr) {
    MS_EXCEPTION_IF_NULL(data_kernel_);
    Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
          device_contexts_[0], context);
  }
}

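// The host queue data source actor feeds data from the host side: each of its data nodes provides the device address
// that the corresponding host tensor is copied into in OnMemoryAllocFinish.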
void HostQueueDataSourceActor::FillDataBuffer() {
  // Construct device tensors.
  std::vector<DeviceTensor *> device_tensors;
  for (auto &data_node : data_nodes_) {
    auto device_address = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)device_tensors.emplace_back(device_address.get());
  }

  buffers_.push(device_tensors);
}

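// When the data nodes live on different device contexts, memory has to be allocated (and freed) per tensor on its
// own device, so the batch interfaces that carry the whole device context list are used instead of the
// single-context ones.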
void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.back();
  if (IsSameDeviceType()) {
    Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
          GetAID());
  } else {
    Async(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors, &device_contexts_, context,
          GetAID());
  }
}

void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  auto &device_tensors = buffers_.front();
  if (IsSameDeviceType()) {
    Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
  } else {
    Async(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors, &device_contexts_, context);
  }
}

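// Copies one step of host tensors into the allocated device tensors: if a host tensor already carries a device
// address, a device-to-device copy is used; otherwise the host data is synchronized to the device using the runtime
// padding shape of the corresponding data node.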
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (buffers_.size() == 0) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
  }

  // Get the host tensors from the host queue and the device tensors from the buffers.
  MS_EXCEPTION_IF_NULL(host_queue_);
  if (host_queue_->IsEmpty()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data queue is empty.");
  }
  auto &host_tensors = host_queue_->Pull();
  auto &device_tensors = buffers_.back();
  if (host_tensors.size() != device_tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
                                      "The number of host tensors is not equal to the number of device tensors.");
  }

  // Copy data from the host tensors to the device tensors.
  for (size_t i = 0; i < host_tensors.size(); ++i) {
    auto &host_tensor = host_tensors[i];
    auto &device_tensor = device_tensors[i];
    MS_EXCEPTION_IF_NULL(host_tensor);
    MS_EXCEPTION_IF_NULL(device_tensor);
    auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
    // Sync data from the host tensor's device address to the device tensor.
    if (tensor_device_address != nullptr) {
      if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
      }
      continue;
    }

    // Sync data from the host tensor to the device tensor.
    if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
                                         LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
                                         host_tensor->data_c(), host_tensor->device_info().host_format_)) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
    }
  }
  host_queue_->Pop();

  // Note that SendMemoryFreeReq must be called before SendOutput, because SendOutput triggers SendMemoryAllocReq of
  // the next actor and actors execute asynchronously. Keeping SendMemoryFreeReq of the current actor ahead of
  // SendMemoryAllocReq of the next actor both allows the memory to be reused more fully and preserves the execution
  // order, avoiding illegal memory timing problems.
  SendMemoryFreeReq(context);
  SendOutput(context);
}

void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
  for (const auto &result_arrow : output_result_arrows_) {
    MS_EXCEPTION_IF_NULL(result_arrow);
    if (IntToSize(result_arrow->from_output_index_) >= data_nodes_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is out of range.");
    }
    Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_nodes_[result_arrow->from_output_index_], 0,
          result_arrow->to_input_index_, context);
  }
}

size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node) const {
  MS_EXCEPTION_IF_NULL(data_node);
  const auto &iter = data_node_position_map_.find(data_node);
  if (iter == data_node_position_map_.end()) {
    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " does not exist.";
  }
  return iter->second;
}

AnfNodePtr HostQueueDataSourceActor::FetchNode(size_t node_position) const {
  if (node_position >= data_nodes_.size()) {
    MS_LOG(EXCEPTION) << "The node position is out of range: " << node_position;
  }
  return data_nodes_[node_position];
}

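// All device contexts are considered the same when every context pointer equals the first one; this decides whether
// the single-context or the batch memory interfaces are used.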
bool HostQueueDataSourceActor::IsSameDeviceType() const {
  for (size_t i = 1; i < device_contexts_.size(); i++) {
    if (device_contexts_[i] != device_contexts_[0]) {
      return false;
    }
  }
  return true;
}
}  // namespace runtime
}  // namespace mindspore