/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_

#include <string>
#include <unordered_set>

#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class BaseRemoteRendezvous;
class BaseRecvTensorCall;

// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received.  Each globally unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
//   fork execution of a graph executor using "rendez" on thread 1;
//   fork execution of another graph executor using "rendez" on thread 2;
//   ...
//   join threads 1 and 2;
//
// In the example above, the executions on threads 1 and 2 communicate with
// each other by send/recv operations through `rendez`.
//
// Tensors sent and received through a rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
 public:
  explicit BaseRendezvousMgr(const WorkerEnv* worker_env);

  ~BaseRendezvousMgr() override;

  // Returns Rendezvous supporting send and recv among workers in the
  // "step_id".  The caller takes ownership of one reference on the
  // returned Rendezvous instance.
  //
  // Note: the caller must guarantee to eventually call Initialize on the
  // returned RemoteRendezvous.  (A usage sketch follows this class.)
  RemoteRendezvous* Find(int64 step_id) override;

  // Finds the local rendezvous instance for the "step_id".  Runs
  // "done" when the tensor for "key" is produced or an error occurs.
  //
  // This method is used by the rpc handler of RecvTensor.
  void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
                      Rendezvous::DoneCallback done) override;

  // Synchronous wrapper for RecvLocalAsync.
  Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
                   Tensor* val, bool* is_dead) override;

  // Removes rendezvous for "step_id".
  //
  // TODO(zhifengc): Have a background thread in worker that
  // periodically calls CleanupAll().
  void Cleanup(int64 step_id) override;

  // Removes all rendezvous.
  void CleanupAll() override;

 protected:
  virtual BaseRemoteRendezvous* Create(int64 step_id,
                                       const WorkerEnv* worker_env) = 0;

 private:
  // Maps step_id to rendezvous.
  typedef gtl::FlatMap<int64, BaseRemoteRendezvous*> Table;

  // Not owned.
  const WorkerEnv* const worker_env_;

  mutex mu_;
  Table table_ GUARDED_BY(mu_);

  BaseRemoteRendezvous* FindOrCreate(int64 step_id);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
};
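
// Illustrative only: a hedged sketch of the per-step lifecycle, assuming a
// BaseRendezvousMgr subclass is installed as worker_env->rendezvous_mgr and
// that "step_id" and "worker_session" are placeholders supplied by the
// caller.  Error handling is omitted.
//
//   RemoteRendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
//   Status s = rendez->Initialize(worker_session);   // Required eventually.
//   ... graph executors Send()/RecvAsync() through "rendez" ...
//   rendez->Unref();                               // Drop the ref from Find().
//   worker_env->rendezvous_mgr->Cleanup(step_id);  // Remove the instance.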

// RemoteRendezvous is a Rendezvous which can handle either
// the producer or consumer being in a remote process.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().  This class just adds
// functionality to coordinate with remote workers.  (A send/recv
// usage sketch follows the class declaration.)
class BaseRemoteRendezvous : public RemoteRendezvous {
 public:
  BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id);

  // Upgrades the BaseRemoteRendezvous to full initialization.
  Status Initialize(WorkerSession* session) override;

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // This method is called only by the RecvOp.  It tests to see
  // whether the value will be produced by a local or remote device
  // and handles accordingly.  In the local case it forwards to
  // local_, in the remote case it initiates an RPC request.
  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  void StartAbort(const Status& status) override;

  // This method is called only by the local Worker, forwarded through
  // the same method on RendezvousMgr.  This occurs when the Worker
  // has received a RecvTensor request, either locally or over the
  // network.  In either case it needs to retrieve a locally buffered
  // value from local_, and give it to its caller.
  //
  // Runs "done" as soon as the tensor for "parsed" is available or an error
  // is detected.
  //
  // REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
  void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);

 protected:
  virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                   const Rendezvous::Args& args,
                                   DoneCallback done) = 0;

  // Returns true if "src" and "dst" are located in the same worker,
  // and hence may use a local rendezvous.
  virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
                            DeviceNameUtils::ParsedName dst);

  // If aborted, aborts "call". Otherwise, adds "call" into active_.
  void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);

  // Removes "call" from active_ if "call" is in active_.
  void DeregisterCall(BaseRecvTensorCall* call);

  WorkerSession* session();

  bool is_initialized();

  ~BaseRemoteRendezvous() override;

  const WorkerEnv* const env_;  // Not owned.
  const int64 step_id_;

 private:
  Rendezvous* local_;  // Owns a Ref on this object.

  // Guards mutable state that is read-mostly after this rendezvous is
  // initialized.
  mutable mutex init_mu_;

  // Status given by StartAbort() if any.
  Status status_ GUARDED_BY(init_mu_);

  WorkerSession* session_ GUARDED_BY(init_mu_);  // Not owned.

  // Data structures to handle calls when partially initialized.
  struct DeferredCall {
    const ParsedKey parsed;
    DoneCallback done;

    DeferredCall(const ParsedKey& parsed, DoneCallback done);
  };
  std::vector<DeferredCall> deferred_calls_ GUARDED_BY(init_mu_);

  typedef std::function<void()> InactiveCallback;

  // Active outstanding RecvTensor calls.
  mutex active_mu_;
  std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
      GUARDED_BY(active_mu_);

  bool is_initialized_locked() SHARED_LOCKS_REQUIRED(init_mu_) {
    return session_ != nullptr;
  }

  // If "is_src" is true, checks that the rendezvous key "parsed"'s
  // source is in this process. If "is_src" is false, checks that the
  // rendezvous key "parsed"'s destination is in this process.
  Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& in_args,
                          const Rendezvous::Args& out_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  // Must be called only if fully initialized.
  void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};
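
// Illustrative only: a hedged sketch of a producer-side Send() and a
// consumer-side RecvAsync() on an initialized RemoteRendezvous "rendez".
// The device names, incarnation value, edge name, and the absence of error
// handling are placeholders, not part of this API; keys must come from
// Rendezvous::CreateKey(), as noted above.
//
//   string key = Rendezvous::CreateKey(
//       "/job:worker/replica:0/task:0/cpu:0", 1 /* src_incarnation */,
//       "/job:worker/replica:0/task:1/cpu:0", "edge_1", FrameAndIter(0, 0));
//   Rendezvous::ParsedKey parsed;
//   TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
//
//   // Producer: the value is buffered in the underlying local rendezvous.
//   TF_CHECK_OK(rendez->Send(parsed, Rendezvous::Args(), val,
//                            false /* is_dead */));
//
//   // Consumer: a local recv is served from local_; a remote recv goes
//   // through RecvFromRemoteAsync() in the concrete subclass.
//   rendez->RecvAsync(parsed, Rendezvous::Args(),
//                     [](const Status& s, const Rendezvous::Args& send_args,
//                        const Rendezvous::Args& recv_args, const Tensor& t,
//                        bool is_dead) { /* consume "t" if s.ok() */ });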

class BaseRecvTensorCall {
 public:
  BaseRecvTensorCall() {}
  virtual ~BaseRecvTensorCall() {}

  virtual void Start(std::function<void()> recv_done) = 0;

  virtual void StartAbort(const Status& s) = 0;

  virtual Status status() const = 0;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
};
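
// Illustrative only: a hedged sketch of how a concrete BaseRemoteRendezvous
// subclass might drive a BaseRecvTensorCall from RecvFromRemoteAsync(),
// pairing RegisterCall()/DeregisterCall() around the call's lifetime.
// "MyRemoteRendezvous", "MyRecvTensorCall", and its tensor()/is_dead()
// accessors are hypothetical.
//
//   void MyRemoteRendezvous::RecvFromRemoteAsync(
//       const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args,
//       DoneCallback done) {
//     MyRecvTensorCall* call = new MyRecvTensorCall(parsed, args);
//     // Aborts "call" immediately if this rendezvous was already aborted;
//     // otherwise tracks it so StartAbort() can cancel it later.
//     RegisterCall(call, args);
//     call->Start([this, call, done]() {
//       Status s = call->status();
//       DeregisterCall(call);
//       done(s, Rendezvous::Args(), Rendezvous::Args(), call->tensor(),
//            call->is_dead());
//       delete call;
//     });
//   }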

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_