1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 18 19 #include <string> 20 #include <unordered_set> 21 22 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" 23 #include "tensorflow/core/distributed_runtime/worker_env.h" 24 #include "tensorflow/core/distributed_runtime/worker_session.h" 25 #include "tensorflow/core/framework/control_flow.h" 26 #include "tensorflow/core/framework/rendezvous.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/gtl/flatmap.h" 29 #include "tensorflow/core/lib/gtl/flatset.h" 30 #include "tensorflow/core/lib/hash/hash.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 #include "tensorflow/core/util/device_name_utils.h" 36 37 namespace tensorflow { 38 39 class BaseRemoteRendezvous; 40 class BaseRecvTensorCall; 41 42 // RendezvousMgr keeps track of a set of local rendezvous instances. 43 // All tensors sent by this worker are buffered in a RendezvousMgr 44 // until the tensor is received. Each global unique "step_id" 45 // corresponds to one local rendezvous instance managed by a 46 // RendezvousMgr. 47 // 48 // E.g., 49 // Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); 50 // fork execution of a graph executor using "rendez" on thread 1; 51 // fork execution of another graph executor using "rendez" on thread 2; 52 // ... 53 // join threads 1 and 2; 54 // 55 // In the example above, execution in thread 1 and 2 communicates with 56 // each other by send/recv operations through `rendez`. 57 // 58 // Tensors sent and received through a rendezvous managed by this 59 // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). 60 class BaseRendezvousMgr : public RendezvousMgrInterface { 61 public: 62 explicit BaseRendezvousMgr(const WorkerEnv* worker_env); 63 64 ~BaseRendezvousMgr() override; 65 66 // Returns Rendezvous supporting send and recv among workers in the 67 // "step_id". The caller takes ownership of one reference on the 68 // returned Rendezvous instance. 69 // 70 // Note: the caller must guarantee to eventually call Initialize on the 71 // returned RemoteRendezvous 72 RemoteRendezvous* Find(int64 step_id) override; 73 74 // Finds the local rendezvous instance for the "step_id". Runs 75 // "done" when the tensor for "key" is produced or an error occurs. 76 // 77 // This method is used by the rpc handler of RecvTensor. 78 void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed, 79 Rendezvous::DoneCallback done) override; 80 81 // Synchronous wrapper for RecvLocalAsync. 82 Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed, 83 Tensor* val, bool* is_dead) override; 84 85 // Removes rendezvous for "step_id". 86 // 87 // TODO(zhifengc): Have a background thread in worker that 88 // periodically calls CleanupAll(). 89 void Cleanup(int64 step_id) override; 90 91 // Removed all rendezvous. 92 void CleanupAll() override; 93 94 protected: 95 virtual BaseRemoteRendezvous* Create(int64 step_id, 96 const WorkerEnv* worker_env) = 0; 97 98 private: 99 // Maps step_id to rendezvous. 100 typedef gtl::FlatMap<int64, BaseRemoteRendezvous*> Table; 101 102 // Not owned. 103 const WorkerEnv* const worker_env_; 104 105 mutex mu_; 106 Table table_ GUARDED_BY(mu_); 107 108 BaseRemoteRendezvous* FindOrCreate(int64 step_id); 109 110 TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr); 111 }; 112 113 // RemoteRendezvous is a Rendezvous which can handle either 114 // the producer or consumer being in a remote process. 115 // 116 // Buffering of Tensor values is delegated to a "local" Rendezvous 117 // obtained from NewLocalRendezvous(). This class just adds 118 // functionality to coordinate with remote workers. 119 class BaseRemoteRendezvous : public RemoteRendezvous { 120 public: 121 BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id); 122 123 // Upgrades the BaseRemoteRendezvous to full initialization. 124 Status Initialize(WorkerSession* session) override; 125 126 // Forwards to local_, where the Tensor "val" will be buffered and 127 // any waiting callback stored. 128 Status Send(const ParsedKey& key, const Rendezvous::Args& args, 129 const Tensor& val, const bool is_dead) override; 130 131 // This method is called only by the RecvOp. It tests to see 132 // whether the value will be produced by a local or remote device 133 // and handles accordingly. In the local case it forwards to 134 // local_, in the remote case it initiates an RPC request. 135 void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, 136 DoneCallback done) override; 137 138 void StartAbort(const Status& status) override; 139 140 // This method is called only by the local Worker, forwarded through 141 // the same method on RendezvousMgr. This occurs when the Worker 142 // has received a RecvTensor request, either locally or over the 143 // network. In either case it needs to retrieve a locally buffered 144 // value from local_, and give it to its caller. 145 // 146 // Runs "done" as soon as the tensor for "parsed" is available or an error 147 // is detected. 148 // 149 // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. 150 void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); 151 152 protected: 153 virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, 154 const Rendezvous::Args& args, 155 DoneCallback done) = 0; 156 157 // Returns true if "src" and "dst" are located in the same worker, 158 // and hence may use a local rendezvous. 159 virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, 160 DeviceNameUtils::ParsedName dst); 161 162 // If aborted, aborts "call". Otherwise, adds "call" into active_. 163 void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); 164 165 // Removes "call" from active_ if "call" is in active_. 166 void DeregisterCall(BaseRecvTensorCall* call); 167 168 WorkerSession* session(); 169 170 bool is_initialized(); 171 172 ~BaseRemoteRendezvous() override; 173 174 const WorkerEnv* const env_; // Not owned. 175 const int64 step_id_; 176 177 private: 178 Rendezvous* local_; // Owns a Ref on this object. 179 180 // Guards mutable state that is read-mostly after this rendezvous is 181 // initialized. 182 mutable mutex init_mu_; 183 184 // Status given by StartAbort() if any. 185 Status status_ GUARDED_BY(init_mu_); 186 187 WorkerSession* session_ GUARDED_BY(init_mu_); // Not owned. 188 189 // Data structures to handle calls when partially initialized. 190 struct DeferredCall { 191 const ParsedKey parsed; 192 DoneCallback done; 193 194 DeferredCall(const ParsedKey& parsed, DoneCallback done); 195 }; 196 std::vector<DeferredCall> deferred_calls_ GUARDED_BY(init_mu_); 197 198 typedef std::function<void()> InactiveCallback; 199 200 // Active outstanding RecvTensor calls. 201 mutex active_mu_; 202 std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_ 203 GUARDED_BY(active_mu_); 204 is_initialized_locked()205 bool is_initialized_locked() SHARED_LOCKS_REQUIRED(init_mu_) { 206 return session_ != nullptr; 207 } 208 209 // If "is_src" is true, checks that the rendezvous key "parsed"'s 210 // source is in this process. If "is_src" is false, checks that the 211 // rendezvous key "parsed"'s destination is in this process. 212 Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); 213 214 // Callback handling the case when a rendezvous has been 215 // accomplished in local_ and the consumer is local to this process. 216 // Tensor "in" will be copied into "out". The key "parsed" encodes 217 // the src and dst devices. 218 void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, 219 const Rendezvous::Args& in_args, 220 const Rendezvous::Args& out_args, const Tensor& in, 221 Tensor* out, StatusCallback done); 222 223 // Must be called only if fully initialized. 224 void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); 225 226 TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); 227 }; 228 229 class BaseRecvTensorCall { 230 public: BaseRecvTensorCall()231 BaseRecvTensorCall() {} ~BaseRecvTensorCall()232 virtual ~BaseRecvTensorCall() {} 233 234 virtual void Start(std::function<void()> recv_done) = 0; 235 236 virtual void StartAbort(const Status& s) = 0; 237 238 virtual Status status() const = 0; 239 240 private: 241 TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall); 242 }; 243 244 } // end namespace tensorflow 245 246 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 247