/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"

namespace tensorflow {
namespace eager {

// A TensorFlow Eager Worker runs ops and supports worker to worker
// Tensor transfer.
//
// See eager_service.proto for more details about each method.
// This class can be wrapped by specific classes that implement rpc transports
// over this (e.g. gRPC).
38 class EagerServiceImpl { 39 public: EagerServiceImpl(const WorkerEnv * env)40 explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { 41 gc_thread_.reset( 42 env_->env->StartThread({}, "EagerServiceContextGC", [this]() { 43 while (true) { 44 { 45 mutex_lock l(gc_thread_shutdown_mu_); 46 gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); 47 48 if (shutting_down_) { 49 return; 50 } 51 } 52 { 53 mutex_lock l(contexts_mu_); 54 for (auto it = contexts_.begin(); it != contexts_.end();) { 55 if (it->second->IsStale()) { 56 it->second->Unref(); 57 it = contexts_.erase(it); 58 } else { 59 it++; 60 } 61 } 62 } 63 } 64 })); 65 } ~EagerServiceImpl()66 virtual ~EagerServiceImpl() { 67 { 68 mutex_lock l(gc_thread_shutdown_mu_); 69 shutting_down_ = true; 70 gc_thread_cv_.notify_all(); 71 } 72 gc_thread_.reset(); 73 74 mutex_lock l(contexts_mu_); 75 for (auto& entry : contexts_) { 76 entry.second->Unref(); 77 } 78 } 79 80 Status CreateContext(const CreateContextRequest* request, 81 CreateContextResponse* response); 82 83 Status UpdateContext(const UpdateContextRequest* request, 84 UpdateContextResponse* response); 85 86 // Create a ServerContext for master eager context. 87 Status CreateMasterContext(const tensorflow::uint64 context_id, 88 EagerContext* context); 89 90 static constexpr uint64 kInvalidStreamId = 0; 91 92 // Used by both Enqueue and StreamingEnqueue RPCs. 
93 Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, 94 EnqueueResponse* response, 95 uint64 stream_id = kInvalidStreamId); 96 97 Status WaitQueueDone(const WaitQueueDoneRequest* request, 98 WaitQueueDoneResponse* response); 99 100 void RunComponentFunction(CallOptions* call_opts, 101 const RunComponentFunctionRequest* request, 102 RunComponentFunctionResponse* response, 103 StatusCallback done); 104 105 Status KeepAlive(const KeepAliveRequest* request, 106 KeepAliveResponse* response); 107 108 Status CloseContext(const CloseContextRequest* request, 109 CloseContextResponse* response); 110 111 protected: 112 // This is the server-side execution context. All state regarding execution of 113 // a client's ops is held in this server-side context (all generated tensors, 114 // and the EagerContext). 115 class ServerContext : public core::RefCounted { 116 public: 117 // Create a ServerContext for local master. CreateMasterContext(tensorflow::EagerContext * ctx,const WorkerEnv * env)118 static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx, 119 const WorkerEnv* env) { 120 return new ServerContext(ctx, -1, env, /* is_master= */ true); 121 } 122 123 explicit ServerContext(tensorflow::EagerContext* ctx, 124 int64_t destroy_after_secs, const WorkerEnv* env, 125 const bool is_master = false) ctx_(ctx)126 : ctx_(ctx), env_(env), is_master_(is_master) { 127 ctx->Ref(); 128 destroy_after_micros_ = 129 destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; 130 RecordAccess(); 131 } 132 ~ServerContext()133 ~ServerContext() override { 134 // TFE_Context is responsible for shutting down master eager context. 135 if (!is_master_) { 136 ctx_->WaitForAndCloseRemoteContexts(); 137 } 138 // ctx_->RefCountIsOne() should be true here when is_master_ = false. 139 // TODO(iga): Remove EagerContext refcounting. 
140 ctx_->Unref(); 141 } 142 Context()143 tensorflow::EagerContext* Context() const { return ctx_; } 144 RecordAccess()145 void RecordAccess() { 146 mutex_lock l(last_accessed_mu_); 147 last_accessed_micros_ = env_->env->NowMicros(); 148 } 149 IsStale()150 bool IsStale() { 151 mutex_lock l(last_accessed_mu_); 152 const int64_t time_passed = 153 env_->env->NowMicros() - last_accessed_micros_; 154 return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_); 155 } 156 157 private: 158 // The context for this execution. 159 tensorflow::EagerContext* ctx_; 160 161 const WorkerEnv* const env_; // Not owned. 162 163 mutex last_accessed_mu_; 164 int64 last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_); 165 int64 destroy_after_micros_; 166 167 const bool is_master_; 168 }; 169 // The returned ServerContext will need to be Unrefed. 170 tensorflow::Status GetServerContext(uint64, ServerContext**); 171 172 class ClientTensorHandleDeleteNode : public EagerNode { 173 public: ClientTensorHandleDeleteNode(ServerContext * context,std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)174 ClientTensorHandleDeleteNode( 175 ServerContext* context, 176 std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete) 177 : tensorflow::EagerNode(), 178 context_(context), 179 handle_to_delete_(std::move(handle_to_delete)) { 180 context_->Ref(); 181 } 182 ~ClientTensorHandleDeleteNode()183 ~ClientTensorHandleDeleteNode() override { context_->Unref(); } 184 Run()185 Status Run() override { 186 VLOG(3) << "ServerContext: Deleting tensor handle " 187 << handle_to_delete_->op_id << ":" 188 << handle_to_delete_->output_num; 189 return context_->Context()->RemoteMgr()->DeleteTensorHandle( 190 *handle_to_delete_); 191 } 192 Abort(Status status)193 void Abort(Status status) override {} 194 195 // Remote node deletions are best effort Fatal()196 bool Fatal() const override { return false; } 197 DebugString()198 string DebugString() const override { 199 string out = 
"[ClientTensorHandleDeleteNode]"; 200 strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id); 201 strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num); 202 return out; 203 } 204 205 private: 206 // Owns one reference. 207 ServerContext* const context_; 208 const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_; 209 }; 210 211 private: 212 Status ExecuteOp(CallOptions* call_opts, const Operation& operation, 213 EagerContext* eager_context, EagerExecutor* eager_executor, 214 QueueResponse* queue_response); 215 Status SendTensor(const SendTensorOp& send_tensor, 216 EagerContext* eager_context); 217 Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, 218 EagerContext* eager_context); 219 Status RegisterFunction(const RegisterFunctionOp& register_function, 220 EagerContext* eager_context); 221 Status CleanupFunction(const CleanupFunctionOp& cleanup_function); 222 const WorkerEnv* const env_; // Not owned. 223 224 mutex contexts_mu_; 225 std::unordered_map<uint64, ServerContext*> contexts_ 226 TF_GUARDED_BY(contexts_mu_); 227 228 std::unique_ptr<Thread> gc_thread_; 229 mutex gc_thread_shutdown_mu_; 230 condition_variable gc_thread_cv_; 231 bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false; 232 233 TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); 234 }; 235 236 } // namespace eager 237 } // namespace tensorflow 238 239 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 240