/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

#include <unordered_set>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {
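// Completes a same-worker recv: delivers the sent tensor "in" to the
// receiver as "out". If both sides live in host memory the buffer is
// aliased; any transfer involving a non-CPU device goes through
// CopyTensor::ViaDMA.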
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", 0, "dynamic",
                                            in.dtype(), &in.shape());
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if ((parsed.dst.type == "GPU" ||
         DeviceFactory::IsPluggableDevice(parsed.dst.type)) &&
        safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}

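// Shared RecvAsync implementation for both intra-process rendezvous
// variants: receives the tensor from the LocalRendezvous "local", then
// hands it to SameWorkerRecvDone for the (possibly cross-device) copy.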
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          final_callback(status);
        }
      });
}

}  // namespace

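// RefCountedIntraProcessRendezvous: reference-counted rendezvous for
// Send/Recv between devices owned by the same worker process. Note that
// local_ is constructed with "this" as its owner, tying the
// LocalRendezvous's lifetime to this object's reference count.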
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}

RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}

Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              const Tensor& val,
                                              const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                                 const Rendezvous::Args& args,
                                                 DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

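// PrivateIntraProcessRendezvous: the non-ref-counted variant, fully owned
// by a single caller (e.g. allocated on the stack); its LocalRendezvous is
// therefore constructed without an owning Rendezvous to ref-count.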
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}

PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
                                           const Rendezvous::Args& args,
                                           const Tensor& val,
                                           const bool is_dead) {
  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              DoneCallback done) {
  DVLOG(1) << "PrivateIntraProcessRendezvous Recv " << this << " "
           << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

}  // end namespace tensorflow