1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h"
17
18 #include "tensorflow/core/common_runtime/copy_tensor.h"
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/common_runtime/device/device_event_mgr.h"
21 #include "tensorflow/core/common_runtime/device_factory.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h"
24 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_reference.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/refcount.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/stream_executor.h"
36 #include "tensorflow/core/platform/tensor_coding.h"
37 #include "tensorflow/core/util/util.h"
38
39 // IMPLEMENTATION NOTE:
40 //
// 1. Within this module, we intentionally LOG(FATAL) if any stream
// involved in memcpy becomes !stream->ok(), because the TF process
// today (3/2021) cannot properly recover from such an error.
//
// 2. When a 0-size tensor is being copied, we should not schedule a
// ThenMemcpy since there are no bytes to move. However, we must
// ensure causal ordering by arranging for the copy-done callback
// to happen after all activities scheduled on the given stream have
// finished.
50
51 namespace tensorflow {
52
53 using se::DeviceMemoryBase;
54 using se::Stream;
55
PrepareCopy(Device * device,const DeviceContext * ctx,const Tensor & src,const Tensor * dst,const DeviceBase::AcceleratorDeviceInfo ** dev_info,se::Stream ** stream)56 static Status PrepareCopy(Device* device, const DeviceContext* ctx,
57 const Tensor& src, const Tensor* dst,
58 const DeviceBase::AcceleratorDeviceInfo** dev_info,
59 se::Stream** stream) {
60 if (device == nullptr) {
61 return errors::Internal("Unexpected null device.");
62 }
63 auto di = device->tensorflow_accelerator_device_info();
64 if (di == nullptr) {
65 return errors::Internal("Unexpected null device info.");
66 }
67
68 *dev_info = di;
69 if (ctx == nullptr) {
70 return errors::Internal("Unexpected null device context.");
71 }
72 auto device_stream =
73 static_cast<const PluggableDeviceContext*>(ctx)->stream();
74 if (device_stream == nullptr) {
75 return errors::Internal("No PluggableDevice stream is available.");
76 }
77 *stream = device_stream;
78 if (dst != nullptr) {
79 if (src.dtype() != dst->dtype()) {
80 return errors::Internal("Can't copy a tensor of ",
81 DataTypeString(src.dtype()), " into a tensor of ",
82 DataTypeString(dst->dtype()));
83 }
84 if (src.TotalBytes() != dst->TotalBytes()) {
85 return errors::Internal("Can't copy ", src.TotalBytes(),
86 " bytes of a tensor into another with ",
87 dst->TotalBytes(), " bytes buffer.");
88 }
89 if ((src.TotalBytes() > 0) && !src.IsInitialized()) {
90 return errors::Internal("Src tensor is not initialized.");
91 }
92 if ((dst->TotalBytes() > 0) && !dst->IsInitialized()) {
93 return errors::Internal("Dst tensor is not initialized.");
94 }
95 }
96 if (!DMAHelper::CanUseDMA(&src)) {
97 return errors::Internal("PluggableDevice copy from non-DMA",
98 DataTypeString(src.dtype()), " tensor.");
99 }
100 return OkStatus();
101 }
102
GetBase(const Tensor * src)103 static void* GetBase(const Tensor* src) {
104 return const_cast<void*>(DMAHelper::base(src));
105 }
106
GetBase(Tensor * dst)107 static void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
108
109 // static
// Copies `input` on pluggable device `src` into `output` on pluggable
// device `dst`, asynchronously. `done` is invoked (possibly on another
// thread) once the copy has been completed or an error detected. The copy
// itself is enqueued on the sender's device-to-device stream selected by
// `dev_to_dev_stream_index`. NOTE: statement order below is load-bearing;
// the ThenWaitFor calls establish the cross-stream happens-before edges.
void PluggableDeviceUtil::DeviceToDeviceCopy(
    DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
    Device* src, Device* dst, AllocatorAttributes src_alloc_attr,
    AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output,
    int dev_to_dev_stream_index, StatusCallback done) {
  const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
  se::Stream* send_stream = nullptr;
  Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info,
                         &send_stream);
  if (!s.ok()) {
    done(s);
    return;
  }

  auto send_device_to_device_stream =
      static_cast<const PluggableDeviceContext*>(send_dev_context)
          ->device_to_device_stream(dev_to_dev_stream_index);
  if (send_device_to_device_stream == nullptr) {
    done(errors::Internal(
        "No send PluggableDevice copy-out-stream is available."));
    return;
  }
  // Wait for the main stream on the sender to make sure the result is
  // available.
  send_device_to_device_stream->ThenWaitFor(send_stream);

  const int64_t total_bytes = input->TotalBytes();
  // 0-size tensors have no backing buffer; skip the memcpy entirely (see
  // IMPLEMENTATION NOTE 2 at the top of this file).
  if (total_bytes > 0) {
    void* src_ptr = GetBase(input);
    DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
    void* dst_ptr = GetBase(output);
    DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
    auto recv_stream =
        static_cast<const PluggableDeviceContext*>(recv_dev_context)->stream();
    if (recv_stream == nullptr) {
      done(errors::Internal("No recv PluggableDevice stream is available."));
      return;
    }
    // Since we want to use the memory from recv_stream in the
    // send_device_to_host_stream, add a dependency to make sure the memory is
    // truly free.
    send_device_to_device_stream->ThenWaitFor(recv_stream);

    VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
    send_device_to_device_stream->ThenMemcpy(&device_dst_ptr, device_src_ptr,
                                             total_bytes);
  }
  // Use of input may outlive stack scope, so keep a ref.
  TensorReference input_ref(*input);
  // The callback fires once everything previously enqueued on the copy
  // stream (including the memcpy above, if any) has finished, which gives
  // the required causal ordering even for 0-byte copies.
  dev_info->event_mgr->ThenExecute(
      send_device_to_device_stream,
      [done, send_device_to_device_stream, input_ref]() {
        input_ref.Unref();
        // See IMPLEMENTATION NOTE 1: the process cannot recover from a bad
        // stream, so crash rather than report.
        if (!send_device_to_device_stream->ok()) {
          LOG(FATAL) << "PluggableDevice->PluggableDevice Memcpy "  // Crash OK
                     << "failed.";
        }
        done(OkStatus());
      });
  send_dev_context->MaintainLifetimeOnStream(input,
                                             send_device_to_device_stream);
}
172
173 // static
// Copies `device_tensor` on `device` into the host-resident `cpu_tensor`,
// asynchronously; `done` is invoked when the copy completes or an error is
// detected. The copy is enqueued on the context's dedicated device-to-host
// stream, ordered after the device's main compute stream.
void PluggableDeviceUtil::CopyPluggableDeviceTensorToCPU(
    Device* device, const DeviceContext* device_context,
    const Tensor* device_tensor, Tensor* cpu_tensor, StatusCallback done) {
  VLOG(1) << "CopyPluggableDeviceTensorToCPU";
  const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
  se::Stream* send_stream = nullptr;
  Status s = PrepareCopy(device, device_context, *device_tensor, cpu_tensor,
                         &dev_info, &send_stream);
  if (!s.ok()) {
    done(s);
    return;
  }

  auto send_device_to_host_stream =
      static_cast<const PluggableDeviceContext*>(device_context)
          ->device_to_host_stream();
  if (send_device_to_host_stream == nullptr) {
    done(errors::Internal(
        "No send PluggableDevice copy-out-stream is available."));
    return;
  }
  // Wait for the sender's main stream to make sure that the data are available.
  send_device_to_host_stream->ThenWaitFor(send_stream);

  const int64_t total_bytes = device_tensor->TotalBytes();
  // 0-size tensors have no backing buffer; skip the memcpy (the callback
  // below still provides the causal ordering — see IMPLEMENTATION NOTE 2).
  if (total_bytes > 0) {
    void* src_ptr = GetBase(device_tensor);
    DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
    void* dst_ptr = GetBase(cpu_tensor);
    // Device-to-host variant of ThenMemcpy: destination is a raw host ptr.
    send_device_to_host_stream->ThenMemcpy(dst_ptr, device_src_ptr,
                                           total_bytes);
  }

  // Use of the input may outlive stack scope, so keep a ref.
  TensorReference input_ref(*device_tensor);
  dev_info->event_mgr->ThenExecute(
      send_device_to_host_stream,
      [send_device_to_host_stream, done, input_ref]() {
        // See IMPLEMENTATION NOTE 1: a bad stream is unrecoverable.
        if (!send_device_to_host_stream->ok()) {
          LOG(FATAL) << "PluggableDevice->CPU Memcpy failed.";  // Crash OK
        }
        input_ref.Unref();
        done(OkStatus());
      });
}
219
220 // static
CopyCPUTensorToPluggableDevice(const Tensor * cpu_tensor,const DeviceContext * device_context,Device * device,Tensor * device_tensor,StatusCallback done,bool sync_dst_compute)221 void PluggableDeviceUtil::CopyCPUTensorToPluggableDevice(
222 const Tensor* cpu_tensor, const DeviceContext* device_context,
223 Device* device, Tensor* device_tensor, StatusCallback done,
224 bool sync_dst_compute) {
225 VLOG(1) << "CopyCPUTensorToPluggableDevice";
226 const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
227 se::Stream* recv_stream = nullptr;
228 Status s = PrepareCopy(device, device_context, *cpu_tensor, device_tensor,
229 &dev_info, &recv_stream);
230 if (!s.ok()) {
231 done(s);
232 return;
233 }
234
235 auto recv_host_to_device_stream =
236 static_cast<const PluggableDeviceContext*>(device_context)
237 ->host_to_device_stream();
238 if (recv_host_to_device_stream == nullptr) {
239 done(errors::Internal(
240 "No send PluggableDevice copy-out-stream is available."));
241 return;
242 }
243 // Wait for the recv-stream to make sure the buffer is truly available.
244 if (sync_dst_compute) {
245 recv_host_to_device_stream->ThenWaitFor(recv_stream);
246 }
247 const int64_t total_bytes = cpu_tensor->TotalBytes();
248 // Note that 0-size tensors have no backing buffer.
249 if (total_bytes > 0) {
250 void* src_ptr = GetBase(cpu_tensor);
251 void* dst_ptr = GetBase(device_tensor);
252 DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
253 recv_host_to_device_stream->ThenMemcpy(&device_dst_ptr, src_ptr,
254 total_bytes);
255 }
256 // Use of cpu_tensor may outlive stack scope, so keep a ref.
257 TensorReference input_ref(*cpu_tensor);
258 dev_info->event_mgr->ThenExecute(
259 recv_host_to_device_stream,
260 [recv_host_to_device_stream, done, input_ref]() {
261 input_ref.Unref();
262 if (!recv_host_to_device_stream->ok()) {
263 LOG(FATAL) << "CPU->PluggableDevice Memcpy failed."; // Crash OK
264 }
265 done(OkStatus());
266 });
267 }
268
Sync(Device * device)269 Status PluggableDeviceUtil::Sync(Device* device) {
270 VLOG(1) << "PluggableDeviceUtil::Sync";
271 auto* dev_info = device->tensorflow_accelerator_device_info();
272 if (!dev_info) {
273 return errors::Internal("Failed to find dest device GPUDeviceInfo.");
274 }
275 return dev_info->stream->BlockHostUntilDone();
276 }
277
SyncAll(Device * device)278 Status PluggableDeviceUtil::SyncAll(Device* device) {
279 VLOG(1) << "PluggableDeviceUtil::SyncAll";
280 auto* dev_info = device->tensorflow_accelerator_device_info();
281 if (!dev_info) {
282 return errors::Internal("Failed to find dest device GPUDeviceInfo.");
283 }
284 if (!dev_info->stream->parent()->SynchronizeAllActivity() ||
285 !dev_info->stream->ok()) {
286 return errors::Internal("PluggableDevice SyncAll failed.");
287 }
288 return OkStatus();
289 }
290
291 // static
CopyPluggableDeviceTensorToSameDevice(Device * device,const DeviceContext * device_context,const Tensor * src_device_tensor,Tensor * dst_device_tensor,StatusCallback done)292 void PluggableDeviceUtil::CopyPluggableDeviceTensorToSameDevice(
293 Device* device, const DeviceContext* device_context,
294 const Tensor* src_device_tensor, Tensor* dst_device_tensor,
295 StatusCallback done) {
296 VLOG(1) << "CopyPluggableDeviceTensorToSameDevice";
297 const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
298 se::Stream* send_stream = nullptr;
299 Status s = PrepareCopy(device, device_context, *src_device_tensor,
300 dst_device_tensor, &dev_info, &send_stream);
301 if (!s.ok()) {
302 done(s);
303 return;
304 }
305
306 const int64_t total_bytes = src_device_tensor->TotalBytes();
307 if (total_bytes > 0) {
308 void* src_ptr = GetBase(src_device_tensor);
309 DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
310 void* dst_ptr = GetBase(dst_device_tensor);
311 DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
312 send_stream->ThenMemcpy(&device_dst_ptr, device_src_ptr, total_bytes);
313 }
314
315 done(OkStatus());
316 }
317
318 } // namespace tensorflow
319