/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/copy_tensor.h"

#include <atomic>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {
namespace {

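// Records the copy function registered for a particular
// (sender, receiver) device-type pair.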
struct RegistrationInfo {
  RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
      : sender_device_type(std::move(s)),
        receiver_device_type(std::move(r)),
        copy_function(cf) {}
  DeviceType sender_device_type;
  DeviceType receiver_device_type;
  CopyTensor::CopyFunction copy_function;
};

// We use a vector instead of a map since we expect there to be very
// few registrations.
std::vector<RegistrationInfo>* MutableRegistry() {
  static std::vector<RegistrationInfo>* registry =
      new std::vector<RegistrationInfo>;
  return registry;
}

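// Copies `input` (on the host) into `output` on device `dst`, calling `done`
// when the copy completes. DT_VARIANT elements are copied element-wise via
// VariantDeviceCopy; DT_RESOURCE tensors are forwarded by assignment.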
void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* dst, Tensor* output,
                      DeviceContext* recv_dev_context, StatusCallback done,
                      bool sync_dst_compute) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier =
        [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
         edge_name, sync_dst_compute, wrapped_done = std::move(wrapped_done)](
            const Tensor& from, Tensor* to) {
          if (from.dtype() == DT_VARIANT) {
            status_cb->Ref();
            CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
                             dst, to, recv_dev_context, wrapped_done,
                             sync_dst_compute);
            return Status::OK();
          } else {
            if (!DMAHelper::CanUseDMA(&from)) {
              Status err = errors::InvalidArgument(
                  "During Variant Host->Device Copy: "
                  "non-DMA-copy attempted of tensor type: ",
                  DataTypeString(from.dtype()));
              status_cb->UpdateStatus(err);
              return err;
            }
            if (status_cb->ok()) {
              status_cb->Ref();
              *to = Tensor(out_allocator, from.dtype(), from.shape());
              recv_dev_context->CopyCPUTensorToDevice(
                  &from, dst, to, wrapped_done, sync_dst_compute);
              return Status::OK();
            } else {
              return status_cb->status();
            }
          }
        };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64 i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::HOST_TO_DEVICE, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(Status::OK());
  } else {
    recv_dev_context->CopyCPUTensorToDevice(input, dst, output, std::move(done),
                                            sync_dst_compute);
  }
}

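// Copies `input` on device `src` to `output` on device `dst` using the given
// `copy_function`, calling `done` when the copy completes. DT_VARIANT elements
// are copied element-wise via VariantDeviceCopy; DT_RESOURCE tensors are
// forwarded by assignment.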
void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
                        Allocator* cpu_allocator, Allocator* out_allocator,
                        DeviceContext* send_dev_context,
                        DeviceContext* recv_dev_context, Device* src,
                        Device* dst, const AllocatorAttributes src_alloc_attr,
                        const AllocatorAttributes dst_alloc_attr,
                        const Tensor* input, Tensor* output,
                        int dev_to_dev_stream_index, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier =
        [copy_function, cpu_allocator, src, dst, src_alloc_attr, dst_alloc_attr,
         recv_dev_context, send_dev_context, out_allocator, status_cb,
         dev_to_dev_stream_index, wrapped_done = std::move(wrapped_done)](
            // Begin unbound arguments
            const Tensor& from, Tensor* to) {
          if (from.dtype() == DT_VARIANT) {
            status_cb->Ref();
            CopyDeviceToDevice(copy_function, cpu_allocator, out_allocator,
                               send_dev_context, recv_dev_context, src, dst,
                               src_alloc_attr, dst_alloc_attr, &from, to,
                               dev_to_dev_stream_index, wrapped_done);
            return Status::OK();
          } else {
            if (!DMAHelper::CanUseDMA(&from)) {
              Status err = errors::InvalidArgument(
                  "During Variant Device->Device Copy: ", src->name(), " to ",
                  dst->name(), " non-DMA-copy attempted of tensor type: ",
                  DataTypeString(from.dtype()));
              status_cb->UpdateStatus(err);
              return err;
            }
            if (status_cb->ok()) {
              status_cb->Ref();
              *to = Tensor(out_allocator, from.dtype(), from.shape());
              copy_function(send_dev_context, recv_dev_context, src, dst,
                            src_alloc_attr, dst_alloc_attr, &from, to,
                            dev_to_dev_stream_index, wrapped_done);
              return Status::OK();
            } else {
              return status_cb->status();
            }
          }
        };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64 i = 0; i < input->NumElements(); ++i) {
      s_copy_init =
          VariantDeviceCopy(VariantDeviceCopyDirection::DEVICE_TO_DEVICE, v[i],
                            &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(Status::OK());
  } else {
    copy_function(send_dev_context, recv_dev_context, src, dst, src_alloc_attr,
                  dst_alloc_attr, input, output, dev_to_dev_stream_index,
                  std::move(done));
  }
}

}  // namespace

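// Copies `input` from `src` to `output` on `dst`, dispatching on the source
// and destination device types. Device-to-device copies use a registered
// CopyFunction when one exists and otherwise stage through the host;
// host<->device copies go through the corresponding DeviceContext; a
// host-to-host copy is a simple assignment.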
// static
void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
                        DeviceContext* recv_dev_context, Device* src,
                        Device* dst, const AllocatorAttributes src_alloc_attr,
                        const AllocatorAttributes dst_alloc_attr,
                        const Tensor* input, Tensor* output,
                        int dev_to_dev_stream_index, StatusCallback done,
                        bool sync_dst_compute) {
  profiler::ScopedAnnotation annotation(edge_name);
  VLOG(1) << "Copy " << edge_name;

  const DeviceType src_device_type(
      src_alloc_attr.on_host() ? DEVICE_CPU : src->attributes().device_type());
  const DeviceType dst_device_type(
      dst_alloc_attr.on_host() ? DEVICE_CPU : dst->attributes().device_type());
  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);

  // TODO(phawkins): choose an allocator optimal for both the src and dst
  // devices, not just the src device.
  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = src->GetAllocator(host_alloc_attrs);
  Allocator* out_allocator = dst->GetAllocator(dst_alloc_attr);

  // E.g., gpu -> gpu
  if (non_cpu_src && non_cpu_dst) {
    // Device to device copy.  Look through registry for an appropriate
    // CopyFunction.
    std::vector<RegistrationInfo>* registry = MutableRegistry();
    for (const RegistrationInfo& ri : *registry) {
      if (ri.sender_device_type == src_device_type &&
          ri.receiver_device_type == dst_device_type) {
        CopyDeviceToDevice(ri.copy_function, cpu_allocator, out_allocator,
                           send_dev_context, recv_dev_context, src, dst,
                           src_alloc_attr, dst_alloc_attr, input, output,
                           dev_to_dev_stream_index, std::move(done));
        return;
      }
    }

    // Fall back to copying via the host.
    VLOG(1) << "No function registered to copy from devices of type "
            << src_device_type.type() << " to devices of type "
            << dst_device_type.type()
            << ". Falling back to copying via the host.";

    Tensor* cpu_tensor =
        new Tensor(cpu_allocator, input->dtype(), input->shape());
    auto delete_and_done = [cpu_tensor,
                            done = std::move(done)](const Status& status) {
      delete cpu_tensor;
      done(status);
    };
    auto then_copy_to_other_device =
        [delete_and_done = std::move(delete_and_done), recv_dev_context,
         cpu_tensor, cpu_allocator, out_allocator, edge_name, dst, output,
         sync_dst_compute](Status status) {
          if (!status.ok()) {
            delete_and_done(status);
            return;
          }
          CopyHostToDevice(cpu_tensor, cpu_allocator, out_allocator, edge_name,
                           dst, output, recv_dev_context,
                           std::move(delete_and_done), sync_dst_compute);
        };
    CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
                     cpu_tensor, send_dev_context,
                     std::move(then_copy_to_other_device));
    return;
  }

  // E.g., gpu -> cpu
  if (non_cpu_src && !non_cpu_dst) {
    // Device to host copy.
    CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
                     output, send_dev_context, std::move(done));
    return;
  }

  // E.g., cpu -> gpu
  if (!non_cpu_src && non_cpu_dst) {
    // Host to Device copy.
    CopyHostToDevice(input, cpu_allocator, out_allocator, edge_name, dst,
                     output, recv_dev_context, std::move(done),
                     sync_dst_compute);
    return;
  }

  // cpu -> cpu
  CHECK(!non_cpu_src && !non_cpu_dst);
  *output = *input;
  done(Status::OK());
}

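// Adds `copy_function` for the (sender, receiver) device-type pair to the
// registry consulted by ViaDMA for device-to-device copies.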
// static
Status CopyTensor::Register(DeviceType sender_device_type,
                            DeviceType receiver_device_type,
                            CopyFunction copy_function) {
  std::vector<RegistrationInfo>* registry = MutableRegistry();
  registry->emplace_back(sender_device_type, receiver_device_type,
                         copy_function);
  return Status::OK();
}

namespace {

// The following registrations enable a DT_VARIANT tensor element that contains
// a wrapped `tensorflow::Tensor` to be copied between devices.
static Status WrappedTensorDeviceCopy(
    const Tensor& from, Tensor* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (DMAHelper::CanUseDMA(&from)) {
    TF_RETURN_IF_ERROR(copy(from, to));
  } else {
    *to = from;
  }

  return Status::OK();
}

#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION)         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      Tensor, DIRECTION, WrappedTensorDeviceCopy)

REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

}  // namespace

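// Copies `input` on device `src` to `output` on the host, calling `done` when
// the copy completes. DT_VARIANT elements are copied element-wise via
// VariantDeviceCopy; DT_RESOURCE tensors are forwarded by assignment.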
void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* src, Tensor* output,
                      DeviceContext* send_dev_context, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier =
        [edge_name, src, send_dev_context, out_allocator, status_cb,
         cpu_allocator, wrapped_done = std::move(wrapped_done)](
            const Tensor& from, Tensor* to) {
          if (from.dtype() == DT_VARIANT) {
            status_cb->Ref();
            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
                             src, to, send_dev_context, wrapped_done);
            return Status::OK();
          } else {
            if (!DMAHelper::CanUseDMA(&from)) {
              Status err = errors::InvalidArgument(
                  "During Variant Device->Host Copy: "
                  "non-DMA-copy attempted of tensor type: ",
                  DataTypeString(from.dtype()));
              status_cb->UpdateStatus(err);
              return err;
            }
            if (status_cb->ok()) {
              status_cb->Ref();
              *to = Tensor(out_allocator, from.dtype(), from.shape());
              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
                                                      wrapped_done);
              return Status::OK();
            } else {
              return status_cb->status();
            }
          }
        };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64 i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(Status::OK());
  } else {
    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
                                            std::move(done));
  }
}

}  // namespace tensorflow