• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include <unordered_map>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"

namespace tensorflow {
namespace eager {

33 // A TensorFlow Eager Worker runs ops and supports worker to worker
34 // Tensor transfer.
35 //
36 // See eager_service.proto for more details about each method.
37 // This class can be wrapped by specific classes that implement rpc transports
38 // over this (e.g. gRPC).
39 class EagerServiceImpl {
40  public:
EagerServiceImpl(const WorkerEnv * env)41   explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {
42     gc_thread_.reset(
43         env_->env->StartThread({}, "EagerServiceContextGC", [this]() {
44           while (true) {
45             {
46               mutex_lock l(gc_thread_shutdown_mu_);
47               gc_thread_cv_.wait_for(l, std::chrono::seconds(1));
48 
49               if (shutting_down_) {
50                 return;
51               }
52             }
53             {
54               mutex_lock l(contexts_mu_);
55               for (auto it = contexts_.begin(); it != contexts_.end();) {
56                 if (it->second->IsStale()) {
57                   it->second->Unref();
58                   it = contexts_.erase(it);
59                 } else {
60                   it++;
61                 }
62               }
63             }
64           }
65         }));
66   }
~EagerServiceImpl()67   virtual ~EagerServiceImpl() {
68     {
69       mutex_lock l(gc_thread_shutdown_mu_);
70       shutting_down_ = true;
71       gc_thread_cv_.notify_all();
72     }
73     gc_thread_.reset();
74 
75     mutex_lock l(contexts_mu_);
76     for (auto& entry : contexts_) {
77       entry.second->Unref();
78     }
79   }
80 
81   Status CreateContext(const CreateContextRequest* request,
82                        CreateContextResponse* response);
83 
84   Status Enqueue(const EnqueueRequest* request, EnqueueResponse* response);
85 
86   Status WaitQueueDone(const WaitQueueDoneRequest* request,
87                        WaitQueueDoneResponse* response);
88 
89   Status KeepAlive(const KeepAliveRequest* request,
90                    KeepAliveResponse* response);
91 
92   Status CloseContext(const CloseContextRequest* request,
93                       CloseContextResponse* response);
94 
95   Status RegisterFunction(const RegisterFunctionRequest* request,
96                           RegisterFunctionResponse* response);
97 
98   Status SendTensor(const SendTensorRequest* request,
99                     SendTensorResponse* response);
100 
101  protected:
102   // This is the server-side execution context. All state regarding execution of
103   // a client's ops is held in this server-side context (all generated tensors,
104   // and the EagerContext).
105   class ServerContext : public core::RefCounted {
106    public:
ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,int64 destroy_after_secs,const WorkerEnv * env)107     explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,
108                            int64 destroy_after_secs, const WorkerEnv* env)
109         : ctx_(std::move(ctx)), env_(env) {
110       destroy_after_micros_ =
111           destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
112       RecordAccess();
113     }
~ServerContext()114     ~ServerContext() {
115       for (const auto& entry : tensors_) {
116         entry.second->Unref();
117       }
118     }
119 
Context()120     tensorflow::EagerContext* Context() const { return ctx_.get(); }
121 
AddOperationOutputs(const gtl::ArraySlice<tensorflow::TensorHandle * > & handles,int64 operation_id)122     void AddOperationOutputs(
123         const gtl::ArraySlice<tensorflow::TensorHandle*>& handles,
124         int64 operation_id) {
125       mutex_lock l(tensors_mu_);
126       for (int i = 0; i < handles.size(); i++) {
127         // TODO(nareshmodi): Correctly handle operation_id not being unique.
128         tensors_.emplace(RemoteTensorHandleInternal(operation_id, i),
129                          handles[i]);
130       }
131     }
132 
GetTensorHandle(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)133     Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle,
134                            tensorflow::TensorHandle** handle) {
135       mutex_lock l(tensors_mu_);
136       auto iter = tensors_.find(remote_handle);
137       if (iter == tensors_.end()) {
138         return errors::InvalidArgument(
139             "Unable to find the relevant tensor remote_handle: Op ID: ",
140             remote_handle.op_id, ", Output num: ", remote_handle.output_num);
141       }
142 
143       *handle = iter->second;
144 
145       return Status::OK();
146     }
147 
DeleteTensorHandle(const RemoteTensorHandleInternal & remote_handle)148     Status DeleteTensorHandle(const RemoteTensorHandleInternal& remote_handle) {
149       mutex_lock l(tensors_mu_);
150       auto iter = tensors_.find(remote_handle);
151       if (iter == tensors_.end()) {
152         return errors::InvalidArgument(
153             "Unable to find the relevant tensor remote_handle: Op ID: ",
154             remote_handle.op_id, ", Output num: ", remote_handle.output_num);
155       }
156 
157       iter->second->Unref();
158       tensors_.erase(iter);
159 
160       return Status::OK();
161     }
162 
RecordAccess()163     void RecordAccess() {
164       mutex_lock l(last_accessed_mu_);
165       last_accessed_micros_ = env_->env->NowMicros();
166     }
167 
IsStale()168     bool IsStale() {
169       mutex_lock l(last_accessed_mu_);
170       return (destroy_after_micros_ > 0 &&
171               (env_->env->NowMicros() - last_accessed_micros_) >
172                   destroy_after_micros_);
173     }
174 
175    private:
176     using RemoteTensorHandleMap =
177         gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
178                      RemoteTensorHandleInternalHash,
179                      RemoteTensorHandleInternalEquals>;
180 
181     // The context for this execution.
182     std::unique_ptr<tensorflow::EagerContext> ctx_;
183 
184     // The state related to the context for this execution.
185     mutex tensors_mu_;
186     RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_);
187 
188     const WorkerEnv* const env_;  // Not owned.
189 
190     mutex last_accessed_mu_;
191     int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_);
192     int64 destroy_after_micros_;
193   };
194   // The returned ServerContext will need to be Unrefed.
195   tensorflow::Status GetServerContext(uint64, ServerContext**);
196 
197  private:
198   Status ExecuteOp(const Operation& operation, ServerContext* server_context,
199                    QueueResponse* queue_response);
200   const WorkerEnv* const env_;  // Not owned.
201 
202   mutex contexts_mu_;
203   std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_);
204 
205   std::unique_ptr<Thread> gc_thread_;
206   mutex gc_thread_shutdown_mu_;
207   condition_variable gc_thread_cv_;
208   bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false;
209 
210   TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
211 };

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_