/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

#include <unordered_set>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {
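// Delivers the tensor received for `parsed` to the consumer on the same
// worker: if both endpoints live in host memory the underlying buffer is
// shared, otherwise the tensor is copied between devices via DMA.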
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }
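  // (DT_VARIANT and DT_RESOURCE are exempt above: variants are validated
  // inside CopyTensor::ViaDMA, and resource handles reside in host memory.)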

  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  MEMDEBUG_CACHE_OP("SameWorkerRecvDone");
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
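  // If the destination GPU uses a timestamped allocator, freed buffers can
  // be reused once the compute stream passes their freed-by count, so the
  // copy below need not synchronize with the destination's compute stream
  // (sync_dst_compute is cleared in that case).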
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if (parsed.dst.type == "GPU" && safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
  }

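  // CopyTensor::ViaDMA picks the appropriate host-to-device, device-to-host,
  // or device-to-device copy path and invokes `done`, possibly
  // asynchronously, once the copy completes.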
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs,
      &in, out, 0 /*dev_to_dev_stream_index*/, std::move(done),
      sync_dst_compute);
}

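// Shared implementation of RecvAsync for the intra-process rendezvous
// classes below: receives the tensor from `local` and, since both ends are
// on the same worker, copies it to the device requested by `recv_args`.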
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  MEMDEBUG_CACHE_OP("RecvAsync");
  // Recv the tensor from `local`.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);
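        // Note: when `in` is initialized, `out` starts out empty and is
        // filled in by SameWorkerRecvDone below.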

        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          final_callback(status);
        }
      });
}

}  // namespace

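// RefCountedIntraProcessRendezvous is shared across its users via reference
// counting, while PrivateIntraProcessRendezvous (below) is intended to be
// owned by a single caller for the duration of one execution. Both retain
// only a raw pointer to the DeviceMgr, so the caller must ensure that
// `device_mgr` outlives the rendezvous.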
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr) {}

RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}

Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              const Tensor& val,
                                              const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                                 const Rendezvous::Args& args,
                                                 DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr) {}

PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
                                           const Rendezvous::Args& args,
                                           const Tensor& val,
                                           const bool is_dead) {
  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              DoneCallback done) {
  DVLOG(1) << "PrivateIntraProcessRendezvous Recv " << this << " "
           << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

}  // end namespace tensorflow