• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h"
17 
18 #include "tensorflow/core/common_runtime/copy_tensor.h"
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/common_runtime/device/device_event_mgr.h"
21 #include "tensorflow/core/common_runtime/device_factory.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h"
24 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_reference.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/refcount.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/stream_executor.h"
36 #include "tensorflow/core/platform/tensor_coding.h"
37 #include "tensorflow/core/util/util.h"
38 
39 // IMPLEMENTATION NOTE:
40 //
// 1. Within this module, we intentionally LOG(FATAL) if any stream
//    involved in a memcpy becomes !stream->ok(), because the TF process
//    (as of 3/2021) cannot properly recover from such an error.
44 //
// 2. When a 0-size tensor is being copied, we must not schedule a
//    ThenMemcpy, since there are no bytes to move. However, we must
//    still preserve causal ordering by arranging for the copy-done
//    callback to run only after all activities already scheduled on
//    the given stream have finished.
50 
51 namespace tensorflow {
52 
53 using se::DeviceMemoryBase;
54 using se::Stream;
55 
PrepareCopy(Device * device,const DeviceContext * ctx,const Tensor & src,const Tensor * dst,const DeviceBase::AcceleratorDeviceInfo ** dev_info,se::Stream ** stream)56 static Status PrepareCopy(Device* device, const DeviceContext* ctx,
57                           const Tensor& src, const Tensor* dst,
58                           const DeviceBase::AcceleratorDeviceInfo** dev_info,
59                           se::Stream** stream) {
60   if (device == nullptr) {
61     return errors::Internal("Unexpected null device.");
62   }
63   auto di = device->tensorflow_accelerator_device_info();
64   if (di == nullptr) {
65     return errors::Internal("Unexpected null device info.");
66   }
67 
68   *dev_info = di;
69   if (ctx == nullptr) {
70     return errors::Internal("Unexpected null device context.");
71   }
72   auto device_stream =
73       static_cast<const PluggableDeviceContext*>(ctx)->stream();
74   if (device_stream == nullptr) {
75     return errors::Internal("No PluggableDevice stream is available.");
76   }
77   *stream = device_stream;
78   if (dst != nullptr) {
79     if (src.dtype() != dst->dtype()) {
80       return errors::Internal("Can't copy a tensor of ",
81                               DataTypeString(src.dtype()), " into a tensor of ",
82                               DataTypeString(dst->dtype()));
83     }
84     if (src.TotalBytes() != dst->TotalBytes()) {
85       return errors::Internal("Can't copy ", src.TotalBytes(),
86                               " bytes of a tensor into another with ",
87                               dst->TotalBytes(), " bytes buffer.");
88     }
89     if ((src.TotalBytes() > 0) && !src.IsInitialized()) {
90       return errors::Internal("Src tensor is not initialized.");
91     }
92     if ((dst->TotalBytes() > 0) && !dst->IsInitialized()) {
93       return errors::Internal("Dst tensor is not initialized.");
94     }
95   }
96   if (!DMAHelper::CanUseDMA(&src)) {
97     return errors::Internal("PluggableDevice copy from non-DMA",
98                             DataTypeString(src.dtype()), " tensor.");
99   }
100   return OkStatus();
101 }
102 
GetBase(const Tensor * src)103 static void* GetBase(const Tensor* src) {
104   return const_cast<void*>(DMAHelper::base(src));
105 }
106 
GetBase(Tensor * dst)107 static void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
108 
109 // static
DeviceToDeviceCopy(DeviceContext * send_dev_context,DeviceContext * recv_dev_context,Device * src,Device * dst,AllocatorAttributes src_alloc_attr,AllocatorAttributes dst_alloc_attr,const Tensor * input,Tensor * output,int dev_to_dev_stream_index,StatusCallback done)110 void PluggableDeviceUtil::DeviceToDeviceCopy(
111     DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
112     Device* src, Device* dst, AllocatorAttributes src_alloc_attr,
113     AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output,
114     int dev_to_dev_stream_index, StatusCallback done) {
115   const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
116   se::Stream* send_stream = nullptr;
117   Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info,
118                          &send_stream);
119   if (!s.ok()) {
120     done(s);
121     return;
122   }
123 
124   auto send_device_to_device_stream =
125       static_cast<const PluggableDeviceContext*>(send_dev_context)
126           ->device_to_device_stream(dev_to_dev_stream_index);
127   if (send_device_to_device_stream == nullptr) {
128     done(errors::Internal(
129         "No send PluggableDevice copy-out-stream is available."));
130     return;
131   }
132   // Wait for the main stream on the sender to make sure the result is
133   // available.
134   send_device_to_device_stream->ThenWaitFor(send_stream);
135 
136   const int64_t total_bytes = input->TotalBytes();
137   if (total_bytes > 0) {
138     void* src_ptr = GetBase(input);
139     DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
140     void* dst_ptr = GetBase(output);
141     DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
142     auto recv_stream =
143         static_cast<const PluggableDeviceContext*>(recv_dev_context)->stream();
144     if (recv_stream == nullptr) {
145       done(errors::Internal("No recv PluggableDevice stream is available."));
146       return;
147     }
148     // Since we want to use the memory from recv_stream in the
149     // send_device_to_host_stream, add a dependency to make sure the memory is
150     // truly free.
151     send_device_to_device_stream->ThenWaitFor(recv_stream);
152 
153     VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
154     send_device_to_device_stream->ThenMemcpy(&device_dst_ptr, device_src_ptr,
155                                              total_bytes);
156   }
157   // Use of input may outlive stack scope, so keep a ref.
158   TensorReference input_ref(*input);
159   dev_info->event_mgr->ThenExecute(
160       send_device_to_device_stream,
161       [done, send_device_to_device_stream, input_ref]() {
162         input_ref.Unref();
163         if (!send_device_to_device_stream->ok()) {
164           LOG(FATAL) << "PluggableDevice->PluggableDevice Memcpy "  // Crash OK
165                      << "failed.";
166         }
167         done(OkStatus());
168       });
169   send_dev_context->MaintainLifetimeOnStream(input,
170                                              send_device_to_device_stream);
171 }
172 
173 // static
CopyPluggableDeviceTensorToCPU(Device * device,const DeviceContext * device_context,const Tensor * device_tensor,Tensor * cpu_tensor,StatusCallback done)174 void PluggableDeviceUtil::CopyPluggableDeviceTensorToCPU(
175     Device* device, const DeviceContext* device_context,
176     const Tensor* device_tensor, Tensor* cpu_tensor, StatusCallback done) {
177   VLOG(1) << "CopyPluggableDeviceTensorToCPU";
178   const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
179   se::Stream* send_stream = nullptr;
180   Status s = PrepareCopy(device, device_context, *device_tensor, cpu_tensor,
181                          &dev_info, &send_stream);
182   if (!s.ok()) {
183     done(s);
184     return;
185   }
186 
187   auto send_device_to_host_stream =
188       static_cast<const PluggableDeviceContext*>(device_context)
189           ->device_to_host_stream();
190   if (send_device_to_host_stream == nullptr) {
191     done(errors::Internal(
192         "No send PluggableDevice copy-out-stream is available."));
193     return;
194   }
195   // Wait for the sender's main stream to make sure that the data are available.
196   send_device_to_host_stream->ThenWaitFor(send_stream);
197 
198   const int64_t total_bytes = device_tensor->TotalBytes();
199   if (total_bytes > 0) {
200     void* src_ptr = GetBase(device_tensor);
201     DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
202     void* dst_ptr = GetBase(cpu_tensor);
203     send_device_to_host_stream->ThenMemcpy(dst_ptr, device_src_ptr,
204                                            total_bytes);
205   }
206 
207   // Use of the input may outlive stack scope, so keep a ref.
208   TensorReference input_ref(*device_tensor);
209   dev_info->event_mgr->ThenExecute(
210       send_device_to_host_stream,
211       [send_device_to_host_stream, done, input_ref]() {
212         if (!send_device_to_host_stream->ok()) {
213           LOG(FATAL) << "PluggableDevice->CPU Memcpy failed.";  // Crash OK
214         }
215         input_ref.Unref();
216         done(OkStatus());
217       });
218 }
219 
220 // static
CopyCPUTensorToPluggableDevice(const Tensor * cpu_tensor,const DeviceContext * device_context,Device * device,Tensor * device_tensor,StatusCallback done,bool sync_dst_compute)221 void PluggableDeviceUtil::CopyCPUTensorToPluggableDevice(
222     const Tensor* cpu_tensor, const DeviceContext* device_context,
223     Device* device, Tensor* device_tensor, StatusCallback done,
224     bool sync_dst_compute) {
225   VLOG(1) << "CopyCPUTensorToPluggableDevice";
226   const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
227   se::Stream* recv_stream = nullptr;
228   Status s = PrepareCopy(device, device_context, *cpu_tensor, device_tensor,
229                          &dev_info, &recv_stream);
230   if (!s.ok()) {
231     done(s);
232     return;
233   }
234 
235   auto recv_host_to_device_stream =
236       static_cast<const PluggableDeviceContext*>(device_context)
237           ->host_to_device_stream();
238   if (recv_host_to_device_stream == nullptr) {
239     done(errors::Internal(
240         "No send PluggableDevice copy-out-stream is available."));
241     return;
242   }
243   // Wait for the recv-stream to make sure the buffer is truly available.
244   if (sync_dst_compute) {
245     recv_host_to_device_stream->ThenWaitFor(recv_stream);
246   }
247   const int64_t total_bytes = cpu_tensor->TotalBytes();
248   // Note that 0-size tensors have no backing buffer.
249   if (total_bytes > 0) {
250     void* src_ptr = GetBase(cpu_tensor);
251     void* dst_ptr = GetBase(device_tensor);
252     DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
253     recv_host_to_device_stream->ThenMemcpy(&device_dst_ptr, src_ptr,
254                                            total_bytes);
255   }
256   // Use of cpu_tensor may outlive stack scope, so keep a ref.
257   TensorReference input_ref(*cpu_tensor);
258   dev_info->event_mgr->ThenExecute(
259       recv_host_to_device_stream,
260       [recv_host_to_device_stream, done, input_ref]() {
261         input_ref.Unref();
262         if (!recv_host_to_device_stream->ok()) {
263           LOG(FATAL) << "CPU->PluggableDevice Memcpy failed.";  // Crash OK
264         }
265         done(OkStatus());
266       });
267 }
268 
Sync(Device * device)269 Status PluggableDeviceUtil::Sync(Device* device) {
270   VLOG(1) << "PluggableDeviceUtil::Sync";
271   auto* dev_info = device->tensorflow_accelerator_device_info();
272   if (!dev_info) {
273     return errors::Internal("Failed to find dest device GPUDeviceInfo.");
274   }
275   return dev_info->stream->BlockHostUntilDone();
276 }
277 
SyncAll(Device * device)278 Status PluggableDeviceUtil::SyncAll(Device* device) {
279   VLOG(1) << "PluggableDeviceUtil::SyncAll";
280   auto* dev_info = device->tensorflow_accelerator_device_info();
281   if (!dev_info) {
282     return errors::Internal("Failed to find dest device GPUDeviceInfo.");
283   }
284   if (!dev_info->stream->parent()->SynchronizeAllActivity() ||
285       !dev_info->stream->ok()) {
286     return errors::Internal("PluggableDevice SyncAll failed.");
287   }
288   return OkStatus();
289 }
290 
291 // static
CopyPluggableDeviceTensorToSameDevice(Device * device,const DeviceContext * device_context,const Tensor * src_device_tensor,Tensor * dst_device_tensor,StatusCallback done)292 void PluggableDeviceUtil::CopyPluggableDeviceTensorToSameDevice(
293     Device* device, const DeviceContext* device_context,
294     const Tensor* src_device_tensor, Tensor* dst_device_tensor,
295     StatusCallback done) {
296   VLOG(1) << "CopyPluggableDeviceTensorToSameDevice";
297   const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr;
298   se::Stream* send_stream = nullptr;
299   Status s = PrepareCopy(device, device_context, *src_device_tensor,
300                          dst_device_tensor, &dev_info, &send_stream);
301   if (!s.ok()) {
302     done(s);
303     return;
304   }
305 
306   const int64_t total_bytes = src_device_tensor->TotalBytes();
307   if (total_bytes > 0) {
308     void* src_ptr = GetBase(src_device_tensor);
309     DeviceMemoryBase device_src_ptr(src_ptr, total_bytes);
310     void* dst_ptr = GetBase(dst_device_tensor);
311     DeviceMemoryBase device_dst_ptr(dst_ptr, total_bytes);
312     send_stream->ThenMemcpy(&device_dst_ptr, device_src_ptr, total_bytes);
313   }
314 
315   done(OkStatus());
316 }
317 
318 }  // namespace tensorflow
319