1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 18 19 #include <string> 20 #include <unordered_map> 21 #include <unordered_set> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" 25 #include "tensorflow/core/distributed_runtime/worker_env.h" 26 #include "tensorflow/core/distributed_runtime/worker_session.h" 27 #include "tensorflow/core/framework/control_flow.h" 28 #include "tensorflow/core/framework/rendezvous.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/hash/hash.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 #include "tensorflow/core/util/device_name_utils.h" 36 37 namespace tensorflow { 38 39 class BaseRemoteRendezvous; 40 class BaseRecvTensorCall; 41 42 // RendezvousMgr keeps track of a set of local rendezvous instances. 43 // All tensors sent by this worker are buffered in a RendezvousMgr 44 // until the tensor is received. Each global unique "step_id" 45 // corresponds to one local rendezvous instance managed by a 46 // RendezvousMgr. 47 // 48 // E.g., 49 // Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); 50 // fork execution of a graph executor using "rendez" on thread 1; 51 // fork execution of another graph executor using "rendez" on thread 2; 52 // ... 53 // join threads 1 and 2; 54 // 55 // In the example above, execution in thread 1 and 2 communicates with 56 // each other by send/recv operations through `rendez`. 57 // 58 // Tensors sent and received through a rendezvous managed by this 59 // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). 60 class BaseRendezvousMgr : public RendezvousMgrInterface { 61 public: 62 explicit BaseRendezvousMgr(const WorkerEnv* worker_env); 63 64 ~BaseRendezvousMgr() override; 65 66 // Returns Rendezvous supporting send and recv among workers in the 67 // "step_id". The caller takes ownership of one reference on the 68 // returned Rendezvous instance. 69 // 70 // Note: the caller must guarantee to eventually call Initialize on the 71 // returned RemoteRendezvous 72 RemoteRendezvous* Find(int64_t step_id) override; 73 74 // Finds the local rendezvous instance for the "step_id". Runs 75 // "done" when the tensor for "key" is produced or an error occurs. 76 // 77 // This method is used by the rpc handler of RecvTensor. 78 void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed, 79 Rendezvous::DoneCallback done) override; 80 81 // Synchronous wrapper for RecvLocalAsync. 82 Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, 83 Tensor* val, bool* is_dead) override; 84 85 // Removes rendezvous for "step_id". 86 // 87 // TODO(zhifengc): Have a background thread in worker that 88 // periodically calls CleanupAll(). 89 void Cleanup(int64_t step_id) override; 90 91 protected: 92 virtual BaseRemoteRendezvous* Create(int64_t step_id, 93 const WorkerEnv* worker_env) = 0; 94 95 private: 96 // Maps step_id to rendezvous. 97 typedef absl::flat_hash_map<int64, BaseRemoteRendezvous*> Table; 98 99 // Not owned. 100 const WorkerEnv* const worker_env_; 101 102 mutex mu_; 103 Table table_ TF_GUARDED_BY(mu_); 104 105 BaseRemoteRendezvous* FindOrCreate(int64_t step_id); 106 107 TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr); 108 }; 109 110 // RemoteRendezvous is a Rendezvous which can handle either 111 // the producer or consumer being in a remote process. 112 // 113 // Buffering of Tensor values is delegated to a "local" Rendezvous 114 // obtained from NewLocalRendezvous(). This class just adds 115 // functionality to coordinate with remote workers. 116 class BaseRemoteRendezvous : public RemoteRendezvous { 117 public: 118 BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id); 119 120 // Upgrades the BaseRemoteRendezvous to full initialization. 121 Status Initialize(WorkerSession* session) override; 122 123 // Forwards to local_, where the Tensor "val" will be buffered and 124 // any waiting callback stored. 125 Status Send(const ParsedKey& key, const Rendezvous::Args& args, 126 const Tensor& val, const bool is_dead) override; 127 128 // This method is called only by the RecvOp. It tests to see 129 // whether the value will be produced by a local or remote device 130 // and handles accordingly. In the local case it forwards to 131 // local_, in the remote case it initiates an RPC request. 132 void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, 133 DoneCallback done) override; 134 135 void StartAbort(const Status& status) override; 136 137 // This method is called only by the local Worker, forwarded through 138 // the same method on RendezvousMgr. This occurs when the Worker 139 // has received a RecvTensor request, either locally or over the 140 // network. In either case it needs to retrieve a locally buffered 141 // value from local_, and give it to its caller. 142 // 143 // Runs "done" as soon as the tensor for "parsed" is available or an error 144 // is detected. 145 // 146 // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. 147 void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); 148 149 protected: 150 virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, 151 const Rendezvous::Args& args, 152 DoneCallback done) = 0; 153 154 // Returns true if "src" and "dst" are located in the same worker, 155 // and hence may use a local rendezvous. 156 virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, 157 DeviceNameUtils::ParsedName dst); 158 159 // If aborted, aborts "call". Otherwise, adds "call" into active_. 160 void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); 161 162 // Removes "call" from active_ if "call" is in active_. 163 void DeregisterCall(BaseRecvTensorCall* call); 164 165 WorkerSession* session(); 166 167 bool is_initialized(); 168 169 ~BaseRemoteRendezvous() override; 170 171 const WorkerEnv* const env_; // Not owned. 172 const int64 step_id_; 173 174 private: 175 Rendezvous* local_; // Owns a Ref on this object. 176 177 mutable mutex mu_; 178 179 // Status given by StartAbort() if any. 180 Status status_ TF_GUARDED_BY(mu_); 181 182 WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned. 183 184 // Data structures to handle calls when partially initialized. 185 struct DeferredCall { 186 const ParsedKey parsed; 187 DoneCallback done; 188 189 DeferredCall(const ParsedKey& parsed, DoneCallback done); 190 }; 191 std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_); 192 193 typedef std::function<void()> InactiveCallback; 194 195 std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_ 196 TF_GUARDED_BY(mu_); 197 is_initialized_locked()198 bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) { 199 return session_ != nullptr; 200 } 201 202 // If "is_src" is true, checks that the rendezvous key "parsed"'s 203 // source is in this process. If "is_src" is false, checks that the 204 // rendezvous key "parsed"'s destination is in this process. 205 Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); 206 207 // Callback handling the case when a rendezvous has been 208 // accomplished in local_ and the consumer is local to this process. 209 // Tensor "in" will be copied into "out". The key "parsed" encodes 210 // the src and dst devices. 211 void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, 212 const Rendezvous::Args& in_args, 213 const Rendezvous::Args& out_args, const Tensor& in, 214 Tensor* out, StatusCallback done); 215 216 // Must be called only if fully initialized. 217 void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); 218 219 TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); 220 }; 221 222 class BaseRecvTensorCall { 223 public: BaseRecvTensorCall()224 BaseRecvTensorCall() {} ~BaseRecvTensorCall()225 virtual ~BaseRecvTensorCall() {} 226 227 virtual void Start(std::function<void()> recv_done) = 0; 228 229 virtual void StartAbort(const Status& s) = 0; 230 231 virtual Status status() const = 0; 232 233 private: 234 TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall); 235 }; 236 237 } // end namespace tensorflow 238 239 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 240