/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"

namespace tensorflow {
namespace eager {

// A TensorFlow Eager Worker runs ops and supports worker-to-worker tensor
// transfer.
//
// See eager_service.proto for more details about each method.
// This class is transport-agnostic: it can be wrapped by classes that
// implement a specific RPC transport (e.g. gRPC) on top of it, as sketched
// below.
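//
// A minimal, hypothetical sketch of such a wrapper (the wrapper class name
// and the single handler shown here are illustrative assumptions, not the
// actual gRPC binding shipped with TensorFlow): each transport-level handler
// unwraps its request/response pair and delegates to the shared
// implementation.
//
//   class MyRpcEagerService {
//    public:
//     explicit MyRpcEagerService(const WorkerEnv* env) : impl_(env) {}
//
//     // Transport handler that forwards to the transport-independent core.
//     Status CreateContextHandler(const CreateContextRequest* request,
//                                 CreateContextResponse* response) {
//       return impl_.CreateContext(request, response);
//     }
//
//    private:
//     EagerServiceImpl impl_;  // Shared, transport-independent logic.
//   };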
class EagerServiceImpl {
 public:
  explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {
    gc_thread_.reset(
        env_->env->StartThread({}, "EagerServiceContextGC", [this]() {
          while (true) {
            {
              mutex_lock l(gc_thread_shutdown_mu_);
              gc_thread_cv_.wait_for(l, std::chrono::seconds(1));

              if (shutting_down_) {
                return;
              }
            }
            {
              mutex_lock l(contexts_mu_);
              for (auto it = contexts_.begin(); it != contexts_.end();) {
                if (it->second->IsStale()) {
                  it->second->Unref();
                  it = contexts_.erase(it);
                } else {
                  it++;
                }
              }
            }
          }
        }));
  }
  virtual ~EagerServiceImpl() {
    {
      mutex_lock l(gc_thread_shutdown_mu_);
      shutting_down_ = true;
      gc_thread_cv_.notify_all();
    }
    gc_thread_.reset();

    mutex_lock l(contexts_mu_);
    for (auto& entry : contexts_) {
      entry.second->Unref();
    }
  }

  Status CreateContext(const CreateContextRequest* request,
                       CreateContextResponse* response);

  Status UpdateContext(const UpdateContextRequest* request,
                       UpdateContextResponse* response);

  // Create a ServerContext for master eager context.
  Status CreateMasterContext(const tensorflow::uint64 context_id,
                             EagerContext* context);

  static constexpr uint64 kInvalidStreamId = 0;

  // Used by both Enqueue and StreamingEnqueue RPCs.
  Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request,
                 EnqueueResponse* response,
                 uint64 stream_id = kInvalidStreamId);

  Status WaitQueueDone(const WaitQueueDoneRequest* request,
                       WaitQueueDoneResponse* response);

  void RunComponentFunction(CallOptions* call_opts,
                            const RunComponentFunctionRequest* request,
                            RunComponentFunctionResponse* response,
                            StatusCallback done);

  Status KeepAlive(const KeepAliveRequest* request,
                   KeepAliveResponse* response);

  Status CloseContext(const CloseContextRequest* request,
                      CloseContextResponse* response);

 protected:
  // This is the server-side execution context. All state regarding execution of
  // a client's ops is held in this server-side context (all generated tensors,
  // and the EagerContext).
  class ServerContext : public core::RefCounted {
   public:
    // Create a ServerContext for local master.
    static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx,
                                              const WorkerEnv* env) {
      return new ServerContext(ctx, -1, env, /* is_master= */ true);
    }

    explicit ServerContext(tensorflow::EagerContext* ctx,
                           int64 destroy_after_secs, const WorkerEnv* env,
                           const bool is_master = false)
        : ctx_(ctx), env_(env), is_master_(is_master) {
      ctx->Ref();
      destroy_after_micros_ =
          destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
      RecordAccess();
    }

    ~ServerContext() override {
      // TFE_Context is responsible for shutting down master eager context.
      if (!is_master_) {
        ctx_->WaitForAndCloseRemoteContexts();
      }
      // ctx_->RefCountIsOne() should be true here when is_master_ = false.
      // TODO(iga): Remove EagerContext refcounting.
      ctx_->Unref();
    }

    tensorflow::EagerContext* Context() const { return ctx_; }

    void RecordAccess() {
      mutex_lock l(last_accessed_mu_);
      last_accessed_micros_ = env_->env->NowMicros();
    }

    bool IsStale() {
      mutex_lock l(last_accessed_mu_);
      const int64 time_passed = env_->env->NowMicros() - last_accessed_micros_;
      return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_);
    }

   private:
    // The context for this execution.
    tensorflow::EagerContext* ctx_;

    const WorkerEnv* const env_;  // Not owned.

    mutex last_accessed_mu_;
    int64 last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_);
    int64 destroy_after_micros_;

    const bool is_master_;
  };
  // The returned ServerContext will need to be Unrefed.
  tensorflow::Status GetServerContext(uint64, ServerContext**);
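  //
  // A hedged usage sketch of the lookup/release protocol (the calling code is
  // illustrative, not part of this file): the looked-up context is pinned with
  // a core::ScopedUnref so the reference is released on every exit path.
  //
  //   ServerContext* context = nullptr;
  //   TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
  //   core::ScopedUnref context_unref(context);  // Unrefs when leaving scope.
  //   tensorflow::EagerContext* eager_context = context->Context();
  //   // ... use eager_context to serve the client's request ...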

  class ClientTensorHandleDeleteNode : public EagerNode {
   public:
    ClientTensorHandleDeleteNode(
        ServerContext* context,
        std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)
        : tensorflow::EagerNode(),
          context_(context),
          handle_to_delete_(std::move(handle_to_delete)) {
      context_->Ref();
    }

    ~ClientTensorHandleDeleteNode() override { context_->Unref(); }

    Status Run() override {
      VLOG(3) << "ServerContext: Deleting tensor handle "
              << handle_to_delete_->op_id << ":"
              << handle_to_delete_->output_num;
      return context_->Context()->RemoteMgr()->DeleteTensorHandle(
          *handle_to_delete_);
    }

    void Abort(Status status) override {}

    // Remote node deletions are best effort
    bool Fatal() const override { return false; }

    string DebugString() const override {
      string out = "[ClientTensorHandleDeleteNode]";
      strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id);
      strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num);
      return out;
    }

   private:
    // Owns one reference.
    ServerContext* const context_;
    const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_;
  };

 private:
  Status ExecuteOp(CallOptions* call_opts, const Operation& operation,
                   EagerContext* eager_context, EagerExecutor* eager_executor,
                   QueueResponse* queue_response);
  Status SendTensor(const SendTensorOp& send_tensor,
                    EagerContext* eager_context);
  Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
                          EagerContext* eager_context);
  Status RegisterFunction(const RegisterFunctionOp& register_function,
                          EagerContext* eager_context);
  Status CleanupFunction(const CleanupFunctionOp& cleanup_function);
  const WorkerEnv* const env_;  // Not owned.

  mutex contexts_mu_;
  std::unordered_map<uint64, ServerContext*> contexts_
      TF_GUARDED_BY(contexts_mu_);

  std::unique_ptr<Thread> gc_thread_;
  mutex gc_thread_shutdown_mu_;
  condition_variable gc_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false;

  TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
};

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_