1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/common_runtime/copy_tensor.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/notification.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace tensorflow {
34
35 namespace {
// Completes the receive half of a same-worker transfer: copies `in` (the
// sent tensor) into `*out` (the tensor handed to the recv op) and invokes
// `done` exactly once with the resulting status.
//
// Fast path: when both endpoints are host memory, `*out` aliases `in`
// (Tensor assignment shares the underlying buffer) — no copy is made.
// Slow path: allocates a destination tensor from the dst device's allocator
// and hands the pair to CopyTensor::ViaDMA, which invokes `done` when the
// (possibly asynchronous) device copy finishes.
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both endpoint devices; failure means the parsed key names a
  // device this worker does not own.
  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", 0, "dynamic",
                                            in.dtype(), &in.shape());
  // The output allocator honors the recv side's attributes, but must be
  // GPU-compatible if *either* side requested it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    // NOTE: `aa.freed_by_func` below stores a *pointer* to this local
    // std::function. That is safe only because the Tensor constructor a few
    // lines down consumes it synchronously, within this scope.
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if (parsed.dst.type == "GPU" && safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    // Tensor construction does not report OOM directly: a non-empty shape
    // paired with a null data pointer signals allocation failure.
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
114
// Shared RecvAsync implementation for both intra-process rendezvous
// flavors (ref-counted and private). Receives the tensor from `local` and,
// when the value is initialized, finishes any required cross-device copy
// via SameWorkerRecvDone before invoking `done`.
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      // `done` is moved into this callback; the callback may run on another
      // thread after RecvAsync returns, so it owns everything it needs.
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        // `out` is heap-allocated so it outlives this callback: the DMA copy
        // inside SameWorkerRecvDone may complete asynchronously. It is
        // deleted in `final_callback` below, the single exit path.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          // Recv failed or produced an uninitialized tensor: report
          // immediately without attempting a copy.
          final_callback(status);
        }
      });
}
149
150 } // namespace
151
// Constructs a ref-counted intra-process rendezvous. `device_mgr` is
// borrowed (never deleted here) and must outlive this object. The embedded
// LocalRendezvous is handed `this` — presumably so it can reference its
// ref-counted owner; confirm against LocalRendezvous's constructor contract.
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}
155
~RefCountedIntraProcessRendezvous()156 RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
157
// Sends `val` (or a dead signal, if `is_dead`) for `key` by forwarding to
// the embedded LocalRendezvous; returns that call's status.
Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              const Tensor& val,
                                              const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}
165
// Receives the tensor for `key`, invoking `done` exactly once when it is
// available (or on error). Delegates to the shared intra-process recv path,
// which also performs any needed same-worker device copy.
void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                                 const Rendezvous::Args& args,
                                                 DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}
172
// Forwards the abort status `s` to the embedded LocalRendezvous, which is
// responsible for failing pending and future operations.
void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}
176
// Constructs a private (non-ref-counted) intra-process rendezvous.
// `device_mgr` is borrowed and must outlive this object. Unlike the
// ref-counted variant, the embedded LocalRendezvous receives no owner
// pointer (nullptr).
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}
180
~PrivateIntraProcessRendezvous()181 PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
182
// Sends `val` (or a dead signal, if `is_dead`) for `key` by forwarding to
// the embedded LocalRendezvous; returns that call's status.
Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
                                           const Rendezvous::Args& args,
                                           const Tensor& val,
                                           const bool is_dead) {
  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}
190
RecvAsync(const ParsedKey & key,const Rendezvous::Args & args,DoneCallback done)191 void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
192 const Rendezvous::Args& args,
193 DoneCallback done) {
194 DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
195 << key.FullKey();
196 IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
197 }
198
// Forwards the abort status `s` to the embedded LocalRendezvous, which is
// responsible for failing pending and future operations.
void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}
202
203 } // end namespace tensorflow
204