/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

#include <unordered_set>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {}

IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }

Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
                                    const Rendezvous::Args& args,
                                    const Tensor& val, const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
  {
    mutex_lock l(mu_);
    if (!status_.ok()) return status_;
  }

  // Buffers "val" and "device_context" in local_.
  return local_->Send(parsed, args, val, is_dead);
}

Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
                                        Rendezvous::ParsedKey* parsed) {
  {
    mutex_lock l(mu_);
    if (!status_.ok()) return status_;
  }
  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
  return Status::OK();
}

void IntraProcessRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT) {
    done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
                                 " tensor may not be copied from/to a GPU."));
    return;
  }

  Device* src_device;
  Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape());
    *out = copy;
  }

  CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
                     recv_args.device_context, src_device, dst_device,
                     send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
                     0 /*dev_to_dev_stream_index*/, std::move(done));
}

void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
                                       const Rendezvous::Args& recv_args,
                                       DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();

  // Recv the tensor from local_.
  local_->RecvAsync(
      parsed, recv_args,
      std::bind(
          [this, parsed](DoneCallback done,
                         // Begin unbound arguments.
                         const Status& status,
                         const Rendezvous::Args& send_args,
                         const Rendezvous::Args& recv_args, const Tensor& in,
                         bool is_dead) {
            // If "in" is an uninitialized tensor, do copy-construction to
            // preserve the uninitialized state, along with data type and shape
            // info, which is useful for debugger purposes.
            Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

            auto final_callback = std::bind(
                [send_args, recv_args, out, is_dead](DoneCallback done,
                                                     // Begin unbound arguments.
                                                     const Status& s) {
                  done(s, send_args, recv_args, *out, is_dead);
                  delete out;
                },
                std::move(done), std::placeholders::_1);

            if (status.ok() && in.IsInitialized()) {
              SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                                 std::move(final_callback));
            } else {
              final_callback(status);
            }
          },
          std::move(done), std::placeholders::_1, std::placeholders::_2,
          std::placeholders::_3, std::placeholders::_4, std::placeholders::_5));
}

void IntraProcessRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  local_->StartAbort(s);
}

}  // end namespace tensorflow
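
// -----------------------------------------------------------------------------
// Illustrative usage (a hedged sketch, kept in comments so it is not compiled
// into this translation unit). It shows one producer/consumer round trip
// through an IntraProcessRendezvous: the `device_mgr` pointer, device name,
// incarnation number `1`, and edge name "edge_0" are assumptions made purely
// for illustration and do not appear in this file.
//
//   IntraProcessRendezvous* rendez = new IntraProcessRendezvous(device_mgr);
//
//   // Producer and consumer agree on a key that names the tensor's edge.
//   const string key = Rendezvous::CreateKey(
//       "/job:localhost/replica:0/task:0/cpu:0", 1 /* src_incarnation */,
//       "/job:localhost/replica:0/task:0/cpu:0", "edge_0", FrameAndIter(0, 0));
//   Rendezvous::ParsedKey parsed;
//   TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
//
//   // Send() buffers the value in the local rendezvous; the blocking Recv()
//   // wrapper around RecvAsync() retrieves it, copying across devices via
//   // SameWorkerRecvDone() when source and destination are not both on host.
//   Rendezvous::Args args;
//   Tensor sent(DT_FLOAT, TensorShape({2}));
//   TF_CHECK_OK(rendez->Send(parsed, args, sent, false /* is_dead */));
//
//   Tensor received;
//   bool is_dead = false;
//   TF_CHECK_OK(rendez->Recv(parsed, args, &received, &is_dead));
//   rendez->Unref();
// -----------------------------------------------------------------------------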