• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include <chrono>
#include <memory>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
29 namespace tensorflow {
30 namespace eager {
31 
32 // A TensorFlow Eager Worker runs ops and supports worker to worker
33 // Tensor transfer.
34 //
35 // See eager_service.proto for more details about each method.
36 // This class can be wrapped by specific classes that implement rpc transports
37 // over this (e.g. gRPC).
38 class EagerServiceImpl {
39  public:
EagerServiceImpl(const WorkerEnv * env)40   explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {
41     gc_thread_.reset(
42         env_->env->StartThread({}, "EagerServiceContextGC", [this]() {
43           while (true) {
44             {
45               mutex_lock l(gc_thread_shutdown_mu_);
46               gc_thread_cv_.wait_for(l, std::chrono::seconds(1));
47 
48               if (shutting_down_) {
49                 return;
50               }
51             }
52             {
53               mutex_lock l(contexts_mu_);
54               for (auto it = contexts_.begin(); it != contexts_.end();) {
55                 if (it->second->IsStale()) {
56                   it->second->Unref();
57                   it = contexts_.erase(it);
58                 } else {
59                   it++;
60                 }
61               }
62             }
63           }
64         }));
65   }
~EagerServiceImpl()66   virtual ~EagerServiceImpl() {
67     {
68       mutex_lock l(gc_thread_shutdown_mu_);
69       shutting_down_ = true;
70       gc_thread_cv_.notify_all();
71     }
72     gc_thread_.reset();
73 
74     mutex_lock l(contexts_mu_);
75     for (auto& entry : contexts_) {
76       entry.second->Unref();
77     }
78   }
79 
80   Status CreateContext(const CreateContextRequest* request,
81                        CreateContextResponse* response);
82 
83   Status UpdateContext(const UpdateContextRequest* request,
84                        UpdateContextResponse* response);
85 
86   // Create a ServerContext for master eager context.
87   Status CreateMasterContext(const tensorflow::uint64 context_id,
88                              EagerContext* context);
89 
90   static constexpr uint64 kInvalidStreamId = 0;
91 
92   // Used by both Enqueue and StreamingEnqueue RPCs.
93   Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request,
94                  EnqueueResponse* response,
95                  uint64 stream_id = kInvalidStreamId);
96 
97   Status WaitQueueDone(const WaitQueueDoneRequest* request,
98                        WaitQueueDoneResponse* response);
99 
100   void RunComponentFunction(CallOptions* call_opts,
101                             const RunComponentFunctionRequest* request,
102                             RunComponentFunctionResponse* response,
103                             StatusCallback done);
104 
105   Status KeepAlive(const KeepAliveRequest* request,
106                    KeepAliveResponse* response);
107 
108   Status CloseContext(const CloseContextRequest* request,
109                       CloseContextResponse* response);
110 
111  protected:
112   // This is the server-side execution context. All state regarding execution of
113   // a client's ops is held in this server-side context (all generated tensors,
114   // and the EagerContext).
115   class ServerContext : public core::RefCounted {
116    public:
117     // Create a ServerContext for local master.
CreateMasterContext(tensorflow::EagerContext * ctx,const WorkerEnv * env)118     static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx,
119                                               const WorkerEnv* env) {
120       return new ServerContext(ctx, -1, env, /* is_master= */ true);
121     }
122 
123     explicit ServerContext(tensorflow::EagerContext* ctx,
124                            int64_t destroy_after_secs, const WorkerEnv* env,
125                            const bool is_master = false)
ctx_(ctx)126         : ctx_(ctx), env_(env), is_master_(is_master) {
127       ctx->Ref();
128       destroy_after_micros_ =
129           destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
130       RecordAccess();
131     }
132 
~ServerContext()133     ~ServerContext() override {
134       // TFE_Context is responsible for shutting down master eager context.
135       if (!is_master_) {
136         ctx_->WaitForAndCloseRemoteContexts();
137       }
138       // ctx_->RefCountIsOne() should be true here when is_master_ = false.
139       // TODO(iga): Remove EagerContext refcounting.
140       ctx_->Unref();
141     }
142 
Context()143     tensorflow::EagerContext* Context() const { return ctx_; }
144 
RecordAccess()145     void RecordAccess() {
146       mutex_lock l(last_accessed_mu_);
147       last_accessed_micros_ = env_->env->NowMicros();
148     }
149 
IsStale()150     bool IsStale() {
151       mutex_lock l(last_accessed_mu_);
152       const int64_t time_passed =
153           env_->env->NowMicros() - last_accessed_micros_;
154       return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_);
155     }
156 
157    private:
158     // The context for this execution.
159     tensorflow::EagerContext* ctx_;
160 
161     const WorkerEnv* const env_;  // Not owned.
162 
163     mutex last_accessed_mu_;
164     int64 last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_);
165     int64 destroy_after_micros_;
166 
167     const bool is_master_;
168   };
169   // The returned ServerContext will need to be Unrefed.
170   tensorflow::Status GetServerContext(uint64, ServerContext**);
171 
172   class ClientTensorHandleDeleteNode : public EagerNode {
173    public:
ClientTensorHandleDeleteNode(ServerContext * context,std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)174     ClientTensorHandleDeleteNode(
175         ServerContext* context,
176         std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)
177         : tensorflow::EagerNode(),
178           context_(context),
179           handle_to_delete_(std::move(handle_to_delete)) {
180       context_->Ref();
181     }
182 
~ClientTensorHandleDeleteNode()183     ~ClientTensorHandleDeleteNode() override { context_->Unref(); }
184 
Run()185     Status Run() override {
186       VLOG(3) << "ServerContext: Deleting tensor handle "
187               << handle_to_delete_->op_id << ":"
188               << handle_to_delete_->output_num;
189       return context_->Context()->RemoteMgr()->DeleteTensorHandle(
190           *handle_to_delete_);
191     }
192 
Abort(Status status)193     void Abort(Status status) override {}
194 
195     // Remote node deletions are best effort
Fatal()196     bool Fatal() const override { return false; }
197 
DebugString()198     string DebugString() const override {
199       string out = "[ClientTensorHandleDeleteNode]";
200       strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id);
201       strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num);
202       return out;
203     }
204 
205    private:
206     // Owns one reference.
207     ServerContext* const context_;
208     const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_;
209   };
210 
211  private:
212   Status ExecuteOp(CallOptions* call_opts, const Operation& operation,
213                    EagerContext* eager_context, EagerExecutor* eager_executor,
214                    QueueResponse* queue_response);
215   Status SendTensor(const SendTensorOp& send_tensor,
216                     EagerContext* eager_context);
217   Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
218                           EagerContext* eager_context);
219   Status RegisterFunction(const RegisterFunctionOp& register_function,
220                           EagerContext* eager_context);
221   Status CleanupFunction(const CleanupFunctionOp& cleanup_function);
222   const WorkerEnv* const env_;  // Not owned.
223 
224   mutex contexts_mu_;
225   std::unordered_map<uint64, ServerContext*> contexts_
226       TF_GUARDED_BY(contexts_mu_);
227 
228   std::unique_ptr<Thread> gc_thread_;
229   mutex gc_thread_shutdown_mu_;
230   condition_variable gc_thread_cv_;
231   bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false;
232 
233   TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
234 };
235 
236 }  // namespace eager
237 }  // namespace tensorflow
238 
239 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
240