• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
18 
19 #include <string>
20 #include <unordered_map>
21 #include <unordered_set>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
25 #include "tensorflow/core/distributed_runtime/worker_env.h"
26 #include "tensorflow/core/distributed_runtime/worker_session.h"
27 #include "tensorflow/core/framework/control_flow.h"
28 #include "tensorflow/core/framework/rendezvous.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/hash/hash.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 
39 class BaseRemoteRendezvous;
40 class BaseRecvTensorCall;
41 
42 // RendezvousMgr keeps track of a set of local rendezvous instances.
43 // All tensors sent by this worker are buffered in a RendezvousMgr
44 // until the tensor is received.  Each global unique "step_id"
45 // corresponds to one local rendezvous instance managed by a
46 // RendezvousMgr.
47 //
48 // E.g.,
49 //   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
50 //   fork execution of a graph executor using "rendez" on thread 1;
51 //   fork execution of another graph executor using "rendez" on thread 2;
52 //   ...
53 //   join threads 1 and 2;
54 //
55 // In the example above, execution in thread 1 and 2 communicates with
56 // each other by send/recv operations through `rendez`.
57 //
58 // Tensors sent and received through a rendezvous managed by this
59 // RendezvousMgr must have keys generated by Rendezvous::CreateKey().
60 class BaseRendezvousMgr : public RendezvousMgrInterface {
61  public:
62   explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
63 
64   ~BaseRendezvousMgr() override;
65 
66   // Returns Rendezvous supporting send and recv among workers in the
67   // "step_id".  The caller takes ownership of one reference on the
68   // returned Rendezvous instance.
69   //
70   // Note: the caller must guarantee to eventually call Initialize on the
71   // returned RemoteRendezvous
72   RemoteRendezvous* Find(int64 step_id) override;
73 
74   // Finds the local rendezvous instance for the "step_id".  Runs
75   // "done" when the tensor for "key" is produced or an error occurs.
76   //
77   // This method is used by the rpc handler of RecvTensor.
78   void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
79                       Rendezvous::DoneCallback done) override;
80 
81   // Synchronous wrapper for RecvLocalAsync.
82   Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
83                    Tensor* val, bool* is_dead) override;
84 
85   // Removes rendezvous for "step_id".
86   //
87   // TODO(zhifengc): Have a background thread in worker that
88   // periodically calls CleanupAll().
89   void Cleanup(int64 step_id) override;
90 
91  protected:
92   virtual BaseRemoteRendezvous* Create(int64 step_id,
93                                        const WorkerEnv* worker_env) = 0;
94 
95  private:
96   // Maps step_id to rendezvous.
97   typedef absl::flat_hash_map<int64, BaseRemoteRendezvous*> Table;
98 
99   // Not owned.
100   const WorkerEnv* const worker_env_;
101 
102   mutex mu_;
103   Table table_ TF_GUARDED_BY(mu_);
104 
105   BaseRemoteRendezvous* FindOrCreate(int64 step_id);
106 
107   TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
108 };
109 
110 // RemoteRendezvous is a Rendezvous which can handle either
111 // the producer or consumer being in a remote process.
112 //
113 // Buffering of Tensor values is delegated to a "local" Rendezvous
114 // obtained from NewLocalRendezvous().  This class just adds
115 // functionality to coordinate with remote workers.
116 class BaseRemoteRendezvous : public RemoteRendezvous {
117  public:
118   BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id);
119 
120   // Upgrades the BaseRemoteRendezvous to full initialization.
121   Status Initialize(WorkerSession* session) override;
122 
123   // Forwards to local_, where the Tensor "val" will be buffered and
124   // any waiting callback stored.
125   Status Send(const ParsedKey& key, const Rendezvous::Args& args,
126               const Tensor& val, const bool is_dead) override;
127 
128   // This method is called only by the RecvOp.  It tests to see
129   // whether the value will be produced by a local or remote device
130   // and handles accordingly.  In the local case it forwards to
131   // local_, in the remote case it initiates an RPC request.
132   void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
133                  DoneCallback done) override;
134 
135   void StartAbort(const Status& status) override;
136 
137   // This method is called only by the local Worker, forwarded through
138   // the same method on RendezvousMgr.  This occurs when the Worker
139   // has received a RecvTensor request, either locally or over the
140   // network.  In either case it needs to retrieve a locally buffered
141   // value from local_, and give it to its caller.
142   //
143   // Runs "done" as soon as the tensor for "parsed" is available or an error
144   // is detected.
145   //
146   // REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
147   void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
148 
149  protected:
150   virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
151                                    const Rendezvous::Args& args,
152                                    DoneCallback done) = 0;
153 
154   // Returns true if "src" and "dst" are located in the same worker,
155   // and hence may use a local rendezvous.
156   virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
157                             DeviceNameUtils::ParsedName dst);
158 
159   // If aborted, aborts "call". Otherwise, adds "call" into active_.
160   void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);
161 
162   // Removes "call" from active_ if "call" is in active_.
163   void DeregisterCall(BaseRecvTensorCall* call);
164 
165   WorkerSession* session();
166 
167   bool is_initialized();
168 
169   ~BaseRemoteRendezvous() override;
170 
171   const WorkerEnv* const env_;  // Not owned.
172   const int64 step_id_;
173 
174  private:
175   Rendezvous* local_;  // Owns a Ref on this object.
176 
177   mutable mutex mu_;
178 
179   // Status given by StartAbort() if any.
180   Status status_ TF_GUARDED_BY(mu_);
181 
182   WorkerSession* session_ TF_GUARDED_BY(mu_);  // Not owned.
183 
184   // Data structures to handle calls when partially initialized.
185   struct DeferredCall {
186     const ParsedKey parsed;
187     DoneCallback done;
188 
189     DeferredCall(const ParsedKey& parsed, DoneCallback done);
190   };
191   std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);
192 
193   typedef std::function<void()> InactiveCallback;
194 
195   std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
196       TF_GUARDED_BY(mu_);
197 
is_initialized_locked()198   bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
199     return session_ != nullptr;
200   }
201 
202   // If "is_src" is true, checks that the rendezvous key "parsed"'s
203   // source is in this process. If "is_src" is false, checks that the
204   // rendezvous key "parsed"'s destination is in this process.
205   Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
206 
207   // Callback handling the case when a rendezvous has been
208   // accomplished in local_ and the consumer is local to this process.
209   // Tensor "in" will be copied into "out". The key "parsed" encodes
210   // the src and dst devices.
211   void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
212                           const Rendezvous::Args& in_args,
213                           const Rendezvous::Args& out_args, const Tensor& in,
214                           Tensor* out, StatusCallback done);
215 
216   // Must be called only if fully initialized.
217   void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
218 
219   TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
220 };
221 
222 class BaseRecvTensorCall {
223  public:
BaseRecvTensorCall()224   BaseRecvTensorCall() {}
~BaseRecvTensorCall()225   virtual ~BaseRecvTensorCall() {}
226 
227   virtual void Start(std::function<void()> recv_done) = 0;
228 
229   virtual void StartAbort(const Status& s) = 0;
230 
231   virtual Status status() const = 0;
232 
233  private:
234   TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
235 };
236 
237 }  // end namespace tensorflow
238 
239 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
240