• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
17 #define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
18 
19 #include <list>
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/refcount.h"
25 #include "tensorflow/core/util/reffed_status_callback.h"
26 
27 namespace tensorflow {
28 
29 namespace internal {
30 // The following class is used for coordination between a `CallContainer`
31 // instance and a cancellation callback to make sure that the `CallContainer`
32 // instance waits for the cancellation callback to be destroyed (either because
33 // a cancellation occurred or because the callback was deregistered) before
34 // deleting itself. Without this coordination the cancellation callback could
35 // attempt to access a `CallContainer` instance that is no longer valid.
36 class NotifyWhenDestroyed {
37  public:
NotifyWhenDestroyed(std::shared_ptr<Notification> notification)38   explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
39       : notification_(std::move(notification)) {}
40 
~NotifyWhenDestroyed()41   ~NotifyWhenDestroyed() { notification_->Notify(); }
42 
43  private:
44   std::shared_ptr<Notification> notification_;
45 };
46 }  // namespace internal
47 
48 // The following class is responsible for the life cycle management of a set of
49 // RPC calls. The calls are started when an instance of the class is created and
50 // the class contract guarantees to invoke a "done" callback provided by the
51 // caller when all RPC calls have either completed or been cancelled.
52 //
53 // The caller should not make any assumptions about the validity of an instance
54 // of this class after the provided callback has been invoked, which may be
55 // immediately after the instance was created.
56 template <class Call>
57 class CallContainer {
58  public:
59   typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
60   typedef std::function<void(Call*)> StartCallFn;
61 
62   // Uses the provided `create_call_fn` and `start_call_fn` functions to create
63   // and start a set of RPC calls. When all RPC calls have either completed or
64   // been cancelled, the `done` callback is invoked. The caller should not make
65   // any assumptions about the validity of the created instance as the instance
66   // will delete itself after invoking the `done` callback.
67   explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
68                          bool try_rpc, AsyncOpKernel::DoneCallback done,
69                          CreateCallFn create_call_fn,
70                          StartCallFn start_call_fn);
71 
72   // Registers a call with this container. This method expects its arguments to
73   // match those of a `Call` constructor as it forwards them to an underlying
74   // collection, which creates a `Call` instance in place.
75   template <class... Args>
76   void RegisterCall(Args&&... args);
77 
78   // Starts the cancellation of all RPC calls managed by this container.
79   void StartCancel();
80 
81   // Indicates that the `index`-th RPC call has finished.
82   void Done(const Status& s, int index);
83 
84  private:
85   OpKernelContext* ctx_;
86   std::list<Call> calls_;
87   const AsyncOpKernel::DoneCallback done_;
88   const CancellationToken token_;
89   const bool fail_fast_;
90   const bool try_rpc_;
91   std::shared_ptr<Notification> callback_destroyed_;
92 
93   // Performs its own reference counting.
94   ReffedStatusCallback* reffed_status_callback_;
95 };
96 
97 template <class Call>
CallContainer(OpKernelContext * ctx,int num_calls,bool fail_fast,bool try_rpc,AsyncOpKernel::DoneCallback done,typename CallContainer<Call>::CreateCallFn create_call_fn,typename CallContainer<Call>::StartCallFn start_call_fn)98 CallContainer<Call>::CallContainer(
99     OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
100     AsyncOpKernel::DoneCallback done,
101     typename CallContainer<Call>::CreateCallFn create_call_fn,
102     typename CallContainer<Call>::StartCallFn start_call_fn)
103     : ctx_(ctx),
104       done_(std::move(done)),
105       token_(ctx->cancellation_manager() != nullptr
106                  ? ctx->cancellation_manager()->get_cancellation_token()
107                  : CancellationManager::kInvalidToken),
108       fail_fast_(fail_fast),
109       try_rpc_(try_rpc),
110       callback_destroyed_(new Notification) {
111   CHECK_GT(num_calls, 0);
112 
113   // This will run when all RPCs are finished.
114   reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
115     if (token_ != CancellationManager::kInvalidToken) {
116       ctx_->cancellation_manager()->DeregisterCallback(token_);
117     }
118     ctx_->SetStatus(s);
119     done_();
120     callback_destroyed_->WaitForNotification();
121     delete this;
122   });
123 
124   // The cancellation callback needs to be registered before the RPC calls are
125   // started to make sure that the callback is properly cleaned up by the
126   // `reffed_status_callback` when all calls complete. At the same time, the
127   // cancellation callback should wait for the RPC calls to be started for the
128   // cancellation to take effect.
129   std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
130       new internal::NotifyWhenDestroyed(callback_destroyed_));
131   std::shared_ptr<Notification> calls_started(new Notification);
132   bool is_cancelled = false;
133   if (token_ != CancellationManager::kInvalidToken) {
134     is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
135         token_, [this, calls_started, notify_when_destroyed]() {
136           calls_started->WaitForNotification();
137           StartCancel();
138         });
139   }
140 
141   for (int i = 0; i < num_calls; ++i) {
142     create_call_fn(this, i);
143     // Increase the reference on the callback for each new RPC.
144     reffed_status_callback_->Ref();
145   }
146   for (Call& call : calls_) {
147     start_call_fn(&call);
148   }
149   calls_started->Notify();
150 
151   if (is_cancelled) {
152     ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
153     StartCancel();
154   }
155 
156   // Subtract reference count from the initial creation.
157   reffed_status_callback_->Unref();
158 }
159 
160 template <class Call>
161 template <class... Args>
RegisterCall(Args &&...args)162 void CallContainer<Call>::RegisterCall(Args&&... args) {
163   calls_.emplace_back(std::forward<Args>(args)...);
164 }
165 
166 template <class Call>
StartCancel()167 void CallContainer<Call>::StartCancel() {
168   for (auto& call : calls_) {
169     call.StartCancel();
170   }
171 }
172 
173 template <class Call>
Done(const Status & s,int index)174 void CallContainer<Call>::Done(const Status& s, int index) {
175   if (!try_rpc_) {
176     reffed_status_callback_->UpdateStatus(s);
177   }
178   reffed_status_callback_->Unref();
179 }
180 
181 }  // namespace tensorflow
182 #endif  // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
183