• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/core/common_runtime/copy_tensor.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/notification.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 
35 namespace {
// Finishes a same-process recv: produces "out" (destined for the device
// named by parsed.dst) from "in" (produced on the device named by
// parsed.src), then invokes "done" exactly once with the resulting status.
//
// device_mgr: used to resolve the src/dst device names in "parsed".
// send_args/recv_args: allocator attributes and device contexts of the
//   sending and receiving sides of the rendezvous.
// done: may be invoked synchronously (host-to-host fast path and all error
//   paths) or handed off to CopyTensor::ViaDMA.
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    // Tensor assignment shares the underlying buffer; no data movement.
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both endpoint devices; a failed lookup aborts the recv.
  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  MEMDEBUG_CACHE_OP("SameWorkerRecvDone");
  // The destination allocation must be GPU-compatible if either side
  // requested it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      // Re-query the device each call, caching the latest frontier so the
      // next query can start from it.
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if (parsed.dst.type == "GPU" && safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    // NOTE(review): freed_by_func is stack-allocated and "aa" holds a raw
    // pointer to it; this assumes the allocator only invokes freed_by_func
    // during this Tensor construction — confirm against the allocator.
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
  }

  // Hand off the (possibly cross-device) copy; ViaDMA takes ownership of
  // "done" and invokes it when the copy completes.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
107 
// Shared implementation behind both intra-process rendezvous classes:
// receives the tensor for "parsed" from "local", and — when the recv
// succeeded with an initialized tensor — routes it through
// SameWorkerRecvDone so it lands on the receiver's device before "done"
// is invoked.
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  MEMDEBUG_CACHE_OP("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      // "mutable" is required so "done" can be moved out of the closure.
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        // final_callback owns "out" (heap-allocated so it survives an async
        // device copy) and deletes it after forwarding to "done".
        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          // Same process, but possibly different devices: copy into "out".
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          // Error or uninitialized tensor: complete immediately.
          final_callback(status);
        }
      });
}
142 
143 }  // namespace
144 
// Stores the (unowned) DeviceMgr used later to resolve device names in
// RecvAsync; no other setup is required.
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr) {}
148 
~RefCountedIntraProcessRendezvous()149 RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
150 
Send(const ParsedKey & key,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)151 Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
152                                               const Rendezvous::Args& args,
153                                               const Tensor& val,
154                                               const bool is_dead) {
155   VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
156   return local_.Send(key, args, val, is_dead);
157 }
158 
RecvAsync(const ParsedKey & key,const Rendezvous::Args & args,DoneCallback done)159 void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
160                                                  const Rendezvous::Args& args,
161                                                  DoneCallback done) {
162   VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
163   IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
164 }
165 
StartAbort(const Status & s)166 void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
167   local_.StartAbort(s);
168 }
169 
// Stores the (unowned) DeviceMgr used later to resolve device names in
// RecvAsync; no other setup is required.
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr) {}
173 
~PrivateIntraProcessRendezvous()174 PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
175 
Send(const ParsedKey & key,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)176 Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
177                                            const Rendezvous::Args& args,
178                                            const Tensor& val,
179                                            const bool is_dead) {
180   DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
181   return local_.Send(key, args, val, is_dead);
182 }
183 
RecvAsync(const ParsedKey & key,const Rendezvous::Args & args,DoneCallback done)184 void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
185                                               const Rendezvous::Args& args,
186                                               DoneCallback done) {
187   DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
188            << key.FullKey();
189   IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
190 }
191 
StartAbort(const Status & s)192 void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
193   local_.StartAbort(s);
194 }
195 
196 }  // end namespace tensorflow
197