/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/common_runtime/copy_tensor.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/strings/numbers.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/types.h"
31
32 namespace tensorflow {
33
IntraProcessRendezvous(const DeviceMgr * device_mgr)34 IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
35 : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {}
36
~IntraProcessRendezvous()37 IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
38
Send(const ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)39 Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
40 const Rendezvous::Args& args,
41 const Tensor& val, const bool is_dead) {
42 VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
43 {
44 mutex_lock l(mu_);
45 if (!status_.ok()) return status_;
46 }
47
48 // Buffers "val" and "device_context" in local_.
49 return local_->Send(parsed, args, val, is_dead);
50 }
51
ParseKey(const string & key,bool is_src,Rendezvous::ParsedKey * parsed)52 Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
53 Rendezvous::ParsedKey* parsed) {
54 {
55 mutex_lock l(mu_);
56 if (!status_.ok()) return status_;
57 }
58 TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
59 return Status::OK();
60 }
61
SameWorkerRecvDone(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & send_args,const Rendezvous::Args & recv_args,const Tensor & in,Tensor * out,StatusCallback done)62 void IntraProcessRendezvous::SameWorkerRecvDone(
63 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
64 const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
65 StatusCallback done) {
66 // Do a quick copy (sharing the underlying buffer) if both tensors
67 // are on host memory.
68 const bool src_host =
69 (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
70 const bool dst_host =
71 (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
72 if (src_host && dst_host) {
73 *out = in;
74 done(Status::OK());
75 return;
76 }
77
78 // This copy must involve a non-CPU device. Hence, "in" must support DMA
79 // (e.g., string tensors do not work on GPU). Variant copy DMA
80 // checks happen inside CopyTensor::ViaDMA.
81 if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT) {
82 done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
83 " tensor may not be copied from/to a GPU."));
84 return;
85 }
86
87 Device* src_device;
88 Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
89 if (!s.ok()) {
90 done(s);
91 return;
92 }
93 Device* dst_device;
94 s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
95 if (!s.ok()) {
96 done(s);
97 return;
98 }
99
100 AllocatorAttributes attr = recv_args.alloc_attrs;
101 attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
102 recv_args.alloc_attrs.gpu_compatible());
103 Allocator* out_allocator = dst_device->GetAllocator(attr);
104 if (in.dtype() != DT_VARIANT) {
105 // Variants are handled by CopyTensor::ViaDMA.
106 Tensor copy(out_allocator, in.dtype(), in.shape());
107 *out = copy;
108 }
109
110 CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
111 recv_args.device_context, src_device, dst_device,
112 send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
113 0 /*dev_to_dev_stream_index*/, std::move(done));
114 }
115
// Receives the tensor for `parsed` from the wrapped local rendezvous and,
// once it arrives, hands it to SameWorkerRecvDone to be copied onto the
// device/memory requested by `recv_args`, finally invoking `done`.
//
// std::bind (rather than lambda capture) is used below to move the
// move-only `done` callback into the nested closures.
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
                                       const Rendezvous::Args& recv_args,
                                       DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();

  // Recv the tensor from local_.
  local_->RecvAsync(
      parsed, recv_args,
      std::bind(
          [this, parsed](DoneCallback done,
                         // Begin unbound arguments.
                         const Status& status,
                         const Rendezvous::Args& send_args,
                         const Rendezvous::Args& recv_args, const Tensor& in,
                         bool is_dead) {
            // If "in" is an uninitialized tensor, do copy-construction to
            // preserve the uninitialized state, along with data type and shape
            // info, which is useful for debugger purposes.
            Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

            // Wrap `done` so that it forwards the final status together with
            // the (possibly copied) tensor, then frees the heap-allocated
            // `out`.
            auto final_callback = std::bind(
                [send_args, recv_args, out, is_dead](DoneCallback done,
                                                     // Begin unbound arguments.
                                                     const Status& s) {
                  done(s, send_args, recv_args, *out, is_dead);
                  delete out;
                },
                std::move(done), std::placeholders::_1);

            // Only run the same-worker copy for an initialized tensor that
            // arrived without error; otherwise report the status (and the
            // possibly dead/uninitialized tensor) as-is.
            if (status.ok() && in.IsInitialized()) {
              SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                                 std::move(final_callback));
            } else {
              final_callback(status);
            }
          },
          std::move(done), std::placeholders::_1, std::placeholders::_2,
          std::placeholders::_3, std::placeholders::_4, std::placeholders::_5));
}
155
StartAbort(const Status & s)156 void IntraProcessRendezvous::StartAbort(const Status& s) {
157 CHECK(!s.ok());
158 local_->StartAbort(s);
159 }
160
161 } // end namespace tensorflow
162