1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 18 19 #include "tensorflow/core/common_runtime/eager/context.h" 20 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 21 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" 22 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" 23 #include "tensorflow/core/distributed_runtime/worker_env.h" 24 #include "tensorflow/core/lib/core/refcount.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/lib/strings/stringprintf.h" 27 #include "tensorflow/core/protobuf/eager_service.pb.h" 28 29 namespace tensorflow { 30 namespace eager { 31 32 // A TensorFlow Eager Worker runs ops and supports worker to worker 33 // Tensor transfer. 34 // 35 // See eager_service.proto for more details about each method. 36 // This class can be wrapped by specific classes that implement rpc transports 37 // over this (e.g. gRPC). 38 class EagerServiceImpl { 39 public: EagerServiceImpl(const WorkerEnv * env)40 explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { 41 gc_thread_.reset( 42 env_->env->StartThread({}, "EagerServiceContextGC", [this]() { 43 while (true) { 44 { 45 mutex_lock l(gc_thread_shutdown_mu_); 46 gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); 47 48 if (shutting_down_) { 49 return; 50 } 51 } 52 { 53 mutex_lock l(contexts_mu_); 54 for (auto it = contexts_.begin(); it != contexts_.end();) { 55 if (it->second->IsStale()) { 56 it->second->Unref(); 57 it = contexts_.erase(it); 58 } else { 59 it++; 60 } 61 } 62 } 63 } 64 })); 65 } ~EagerServiceImpl()66 virtual ~EagerServiceImpl() { 67 { 68 mutex_lock l(gc_thread_shutdown_mu_); 69 shutting_down_ = true; 70 gc_thread_cv_.notify_all(); 71 } 72 gc_thread_.reset(); 73 74 mutex_lock l(contexts_mu_); 75 for (auto& entry : contexts_) { 76 entry.second->Unref(); 77 } 78 } 79 80 Status CreateContext(const CreateContextRequest* request, 81 CreateContextResponse* response); 82 83 Status UpdateContext(const UpdateContextRequest* request, 84 UpdateContextResponse* response); 85 86 // Create a ServerContext for master eager context. 87 Status CreateMasterContext(const tensorflow::uint64 context_id, 88 EagerContext* context); 89 90 static constexpr uint64 kInvalidStreamId = 0; 91 92 // Used by both Enqueue and StreamingEnqueue RPCs. 93 Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, 94 EnqueueResponse* response, 95 uint64 stream_id = kInvalidStreamId); 96 97 Status WaitQueueDone(const WaitQueueDoneRequest* request, 98 WaitQueueDoneResponse* response); 99 100 void RunComponentFunction(CallOptions* call_opts, 101 const RunComponentFunctionRequest* request, 102 RunComponentFunctionResponse* response, 103 StatusCallback done); 104 105 Status KeepAlive(const KeepAliveRequest* request, 106 KeepAliveResponse* response); 107 108 Status CloseContext(const CloseContextRequest* request, 109 CloseContextResponse* response); 110 111 protected: 112 // This is the server-side execution context. All state regarding execution of 113 // a client's ops is held in this server-side context (all generated tensors, 114 // and the EagerContext). 115 class ServerContext : public core::RefCounted { 116 public: 117 // Create a ServerContext for local master. CreateMasterContext(tensorflow::EagerContext * ctx,const WorkerEnv * env)118 static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx, 119 const WorkerEnv* env) { 120 return new ServerContext(ctx, -1, env, /* is_master= */ true); 121 } 122 123 explicit ServerContext(tensorflow::EagerContext* ctx, 124 int64 destroy_after_secs, const WorkerEnv* env, 125 const bool is_master = false) ctx_(ctx)126 : ctx_(ctx), env_(env), is_master_(is_master) { 127 ctx->Ref(); 128 destroy_after_micros_ = 129 destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; 130 RecordAccess(); 131 } 132 ~ServerContext()133 ~ServerContext() override { 134 // TFE_Context is responsible for shutting down master eager context. 135 if (!is_master_) { 136 ctx_->WaitForAndCloseRemoteContexts(); 137 } 138 // ctx_->RefCountIsOne() should be true here when is_master_ = false. 139 // TODO(iga): Remove EagerContext refcounting. 140 ctx_->Unref(); 141 } 142 Context()143 tensorflow::EagerContext* Context() const { return ctx_; } 144 RecordAccess()145 void RecordAccess() { 146 mutex_lock l(last_accessed_mu_); 147 last_accessed_micros_ = env_->env->NowMicros(); 148 } 149 IsStale()150 bool IsStale() { 151 mutex_lock l(last_accessed_mu_); 152 const int64 time_passed = env_->env->NowMicros() - last_accessed_micros_; 153 return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_); 154 } 155 156 private: 157 // The context for this execution. 158 tensorflow::EagerContext* ctx_; 159 160 const WorkerEnv* const env_; // Not owned. 161 162 mutex last_accessed_mu_; 163 int64 last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_); 164 int64 destroy_after_micros_; 165 166 const bool is_master_; 167 }; 168 // The returned ServerContext will need to be Unrefed. 169 tensorflow::Status GetServerContext(uint64, ServerContext**); 170 171 class ClientTensorHandleDeleteNode : public EagerNode { 172 public: ClientTensorHandleDeleteNode(ServerContext * context,std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)173 ClientTensorHandleDeleteNode( 174 ServerContext* context, 175 std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete) 176 : tensorflow::EagerNode(), 177 context_(context), 178 handle_to_delete_(std::move(handle_to_delete)) { 179 context_->Ref(); 180 } 181 ~ClientTensorHandleDeleteNode()182 ~ClientTensorHandleDeleteNode() override { context_->Unref(); } 183 Run()184 Status Run() override { 185 VLOG(3) << "ServerContext: Deleting tensor handle " 186 << handle_to_delete_->op_id << ":" 187 << handle_to_delete_->output_num; 188 return context_->Context()->RemoteMgr()->DeleteTensorHandle( 189 *handle_to_delete_); 190 } 191 Abort(Status status)192 void Abort(Status status) override {} 193 194 // Remote node deletions are best effort Fatal()195 bool Fatal() const override { return false; } 196 DebugString()197 string DebugString() const override { 198 string out = "[ClientTensorHandleDeleteNode]"; 199 strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id); 200 strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num); 201 return out; 202 } 203 204 private: 205 // Owns one reference. 206 ServerContext* const context_; 207 const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_; 208 }; 209 210 private: 211 Status ExecuteOp(CallOptions* call_opts, const Operation& operation, 212 EagerContext* eager_context, EagerExecutor* eager_executor, 213 QueueResponse* queue_response); 214 Status SendTensor(const SendTensorOp& send_tensor, 215 EagerContext* eager_context); 216 Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, 217 EagerContext* eager_context); 218 Status RegisterFunction(const RegisterFunctionOp& register_function, 219 EagerContext* eager_context); 220 Status CleanupFunction(const CleanupFunctionOp& cleanup_function); 221 const WorkerEnv* const env_; // Not owned. 222 223 mutex contexts_mu_; 224 std::unordered_map<uint64, ServerContext*> contexts_ 225 TF_GUARDED_BY(contexts_mu_); 226 227 std::unique_ptr<Thread> gc_thread_; 228 mutex gc_thread_shutdown_mu_; 229 condition_variable gc_thread_cv_; 230 bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false; 231 232 TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); 233 }; 234 235 } // namespace eager 236 } // namespace tensorflow 237 238 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 239