1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 18 19 #include <unordered_map> 20 21 #include "tensorflow/core/common_runtime/eager/context.h" 22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 23 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" 24 #include "tensorflow/core/distributed_runtime/worker_env.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 #include "tensorflow/core/lib/gtl/array_slice.h" 27 #include "tensorflow/core/lib/strings/stringprintf.h" 28 #include "tensorflow/core/protobuf/eager_service.pb.h" 29 30 namespace tensorflow { 31 namespace eager { 32 33 // A TensorFlow Eager Worker runs ops and supports worker to worker 34 // Tensor transfer. 35 // 36 // See eager_service.proto for more details about each method. 37 // This class can be wrapped by specific classes that implement rpc transports 38 // over this (e.g. gRPC). 39 class EagerServiceImpl { 40 public: EagerServiceImpl(const WorkerEnv * env)41 explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { 42 gc_thread_.reset( 43 env_->env->StartThread({}, "EagerServiceContextGC", [this]() { 44 while (true) { 45 { 46 mutex_lock l(gc_thread_shutdown_mu_); 47 gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); 48 49 if (shutting_down_) { 50 return; 51 } 52 } 53 { 54 mutex_lock l(contexts_mu_); 55 for (auto it = contexts_.begin(); it != contexts_.end();) { 56 if (it->second->IsStale()) { 57 it->second->Unref(); 58 it = contexts_.erase(it); 59 } else { 60 it++; 61 } 62 } 63 } 64 } 65 })); 66 } ~EagerServiceImpl()67 virtual ~EagerServiceImpl() { 68 { 69 mutex_lock l(gc_thread_shutdown_mu_); 70 shutting_down_ = true; 71 gc_thread_cv_.notify_all(); 72 } 73 gc_thread_.reset(); 74 75 mutex_lock l(contexts_mu_); 76 for (auto& entry : contexts_) { 77 entry.second->Unref(); 78 } 79 } 80 81 Status CreateContext(const CreateContextRequest* request, 82 CreateContextResponse* response); 83 84 Status Enqueue(const EnqueueRequest* request, EnqueueResponse* response); 85 86 Status WaitQueueDone(const WaitQueueDoneRequest* request, 87 WaitQueueDoneResponse* response); 88 89 Status KeepAlive(const KeepAliveRequest* request, 90 KeepAliveResponse* response); 91 92 Status CloseContext(const CloseContextRequest* request, 93 CloseContextResponse* response); 94 95 Status RegisterFunction(const RegisterFunctionRequest* request, 96 RegisterFunctionResponse* response); 97 98 Status SendTensor(const SendTensorRequest* request, 99 SendTensorResponse* response); 100 101 protected: 102 // This is the server-side execution context. All state regarding execution of 103 // a client's ops is held in this server-side context (all generated tensors, 104 // and the EagerContext). 105 class ServerContext : public core::RefCounted { 106 public: ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,int64 destroy_after_secs,const WorkerEnv * env)107 explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx, 108 int64 destroy_after_secs, const WorkerEnv* env) 109 : ctx_(std::move(ctx)), env_(env) { 110 destroy_after_micros_ = 111 destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; 112 RecordAccess(); 113 } ~ServerContext()114 ~ServerContext() { 115 for (const auto& entry : tensors_) { 116 entry.second->Unref(); 117 } 118 } 119 Context()120 tensorflow::EagerContext* Context() const { return ctx_.get(); } 121 AddOperationOutputs(const gtl::ArraySlice<tensorflow::TensorHandle * > & handles,int64 operation_id)122 void AddOperationOutputs( 123 const gtl::ArraySlice<tensorflow::TensorHandle*>& handles, 124 int64 operation_id) { 125 mutex_lock l(tensors_mu_); 126 for (int i = 0; i < handles.size(); i++) { 127 // TODO(nareshmodi): Correctly handle operation_id not being unique. 128 tensors_.emplace(RemoteTensorHandleInternal(operation_id, i), 129 handles[i]); 130 } 131 } 132 GetTensorHandle(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)133 Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle, 134 tensorflow::TensorHandle** handle) { 135 mutex_lock l(tensors_mu_); 136 auto iter = tensors_.find(remote_handle); 137 if (iter == tensors_.end()) { 138 return errors::InvalidArgument( 139 "Unable to find the relevant tensor remote_handle: Op ID: ", 140 remote_handle.op_id, ", Output num: ", remote_handle.output_num); 141 } 142 143 *handle = iter->second; 144 145 return Status::OK(); 146 } 147 DeleteTensorHandle(const RemoteTensorHandleInternal & remote_handle)148 Status DeleteTensorHandle(const RemoteTensorHandleInternal& remote_handle) { 149 mutex_lock l(tensors_mu_); 150 auto iter = tensors_.find(remote_handle); 151 if (iter == tensors_.end()) { 152 return errors::InvalidArgument( 153 "Unable to find the relevant tensor remote_handle: Op ID: ", 154 remote_handle.op_id, ", Output num: ", remote_handle.output_num); 155 } 156 157 iter->second->Unref(); 158 tensors_.erase(iter); 159 160 return Status::OK(); 161 } 162 RecordAccess()163 void RecordAccess() { 164 mutex_lock l(last_accessed_mu_); 165 last_accessed_micros_ = env_->env->NowMicros(); 166 } 167 IsStale()168 bool IsStale() { 169 mutex_lock l(last_accessed_mu_); 170 return (destroy_after_micros_ > 0 && 171 (env_->env->NowMicros() - last_accessed_micros_) > 172 destroy_after_micros_); 173 } 174 175 private: 176 using RemoteTensorHandleMap = 177 gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*, 178 RemoteTensorHandleInternalHash, 179 RemoteTensorHandleInternalEquals>; 180 181 // The context for this execution. 182 std::unique_ptr<tensorflow::EagerContext> ctx_; 183 184 // The state related to the context for this execution. 185 mutex tensors_mu_; 186 RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_); 187 188 const WorkerEnv* const env_; // Not owned. 189 190 mutex last_accessed_mu_; 191 int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_); 192 int64 destroy_after_micros_; 193 }; 194 // The returned ServerContext will need to be Unrefed. 195 tensorflow::Status GetServerContext(uint64, ServerContext**); 196 197 private: 198 Status ExecuteOp(const Operation& operation, ServerContext* server_context, 199 QueueResponse* queue_response); 200 const WorkerEnv* const env_; // Not owned. 201 202 mutex contexts_mu_; 203 std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_); 204 205 std::unique_ptr<Thread> gc_thread_; 206 mutex gc_thread_shutdown_mu_; 207 condition_variable gc_thread_cv_; 208 bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false; 209 210 TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); 211 }; 212 213 } // namespace eager 214 } // namespace tensorflow 215 216 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 217