1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
17 #define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
18
19 #include <list>
20
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/refcount.h"
25 #include "tensorflow/core/util/reffed_status_callback.h"
26
27 namespace tensorflow {
28
29 namespace internal {
30 // The following class is used for coordination between a `CallContainer`
31 // instance and a cancellation callback to make sure that the `CallContainer`
32 // instance waits for the cancellation callback to be destroyed (either because
33 // a cancellation occurred or because the callback was deregistered) before
34 // deleting itself. Without this coordination the cancellation callback could
35 // attempt to access a `CallContainer` instance that is no longer valid.
36 class NotifyWhenDestroyed {
37 public:
NotifyWhenDestroyed(std::shared_ptr<Notification> notification)38 explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
39 : notification_(std::move(notification)) {}
40
~NotifyWhenDestroyed()41 ~NotifyWhenDestroyed() { notification_->Notify(); }
42
43 private:
44 std::shared_ptr<Notification> notification_;
45 };
46 } // namespace internal
47
48 // The following class is responsible for the life cycle management of a set of
49 // RPC calls. The calls are started when an instance of the class is created and
50 // the class contract guarantees to invoke a "done" callback provided by the
51 // caller when all RPC calls have either completed or been cancelled.
52 //
53 // The caller should not make any assumptions about the validity of an instance
54 // of this class after the provided callback has been invoked, which may be
55 // immediately after the instance was created.
56 template <class Call>
57 class CallContainer {
58 public:
59 typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
60 typedef std::function<void(Call*)> StartCallFn;
61
62 // Uses the provided `create_call_fn` and `start_call_fn` functions to create
63 // and start a set of RPC calls. When all RPC calls have either completed or
64 // been cancelled, the `done` callback is invoked. The caller should not make
65 // any assumptions about the validity of the created instance as the instance
66 // will delete itself after invoking the `done` callback.
67 explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
68 bool try_rpc, AsyncOpKernel::DoneCallback done,
69 CreateCallFn create_call_fn,
70 StartCallFn start_call_fn);
71
72 // Registers a call with this container. This method expects its arguments to
73 // match those of a `Call` constructor as it forwards them to an underlying
74 // collection, which creates a `Call` instance in place.
75 template <class... Args>
76 void RegisterCall(Args&&... args);
77
78 // Starts the cancellation of all RPC calls managed by this container.
79 void StartCancel();
80
81 // Indicates that the `index`-th RPC call has finished.
82 void Done(const Status& s, int index);
83
84 private:
85 OpKernelContext* ctx_;
86 std::list<Call> calls_;
87 const AsyncOpKernel::DoneCallback done_;
88 const CancellationToken token_;
89 const bool fail_fast_;
90 const bool try_rpc_;
91 std::shared_ptr<Notification> callback_destroyed_;
92
93 // Performs its own reference counting.
94 ReffedStatusCallback* reffed_status_callback_;
95 };
96
97 template <class Call>
CallContainer(OpKernelContext * ctx,int num_calls,bool fail_fast,bool try_rpc,AsyncOpKernel::DoneCallback done,typename CallContainer<Call>::CreateCallFn create_call_fn,typename CallContainer<Call>::StartCallFn start_call_fn)98 CallContainer<Call>::CallContainer(
99 OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
100 AsyncOpKernel::DoneCallback done,
101 typename CallContainer<Call>::CreateCallFn create_call_fn,
102 typename CallContainer<Call>::StartCallFn start_call_fn)
103 : ctx_(ctx),
104 done_(std::move(done)),
105 token_(ctx->cancellation_manager() != nullptr
106 ? ctx->cancellation_manager()->get_cancellation_token()
107 : CancellationManager::kInvalidToken),
108 fail_fast_(fail_fast),
109 try_rpc_(try_rpc),
110 callback_destroyed_(new Notification) {
111 CHECK_GT(num_calls, 0);
112
113 // This will run when all RPCs are finished.
114 reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
115 if (token_ != CancellationManager::kInvalidToken) {
116 ctx_->cancellation_manager()->DeregisterCallback(token_);
117 }
118 ctx_->SetStatus(s);
119 done_();
120 callback_destroyed_->WaitForNotification();
121 delete this;
122 });
123
124 // The cancellation callback needs to be registered before the RPC calls are
125 // started to make sure that the callback is properly cleaned up by the
126 // `reffed_status_callback` when all calls complete. At the same time, the
127 // cancellation callback should wait for the RPC calls to be started for the
128 // cancellation to take effect.
129 std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
130 new internal::NotifyWhenDestroyed(callback_destroyed_));
131 std::shared_ptr<Notification> calls_started(new Notification);
132 bool is_cancelled = false;
133 if (token_ != CancellationManager::kInvalidToken) {
134 is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
135 token_, [this, calls_started, notify_when_destroyed]() {
136 calls_started->WaitForNotification();
137 StartCancel();
138 });
139 }
140
141 for (int i = 0; i < num_calls; ++i) {
142 create_call_fn(this, i);
143 // Increase the reference on the callback for each new RPC.
144 reffed_status_callback_->Ref();
145 }
146 for (Call& call : calls_) {
147 start_call_fn(&call);
148 }
149 calls_started->Notify();
150
151 if (is_cancelled) {
152 ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
153 StartCancel();
154 }
155
156 // Subtract reference count from the initial creation.
157 reffed_status_callback_->Unref();
158 }
159
160 template <class Call>
161 template <class... Args>
RegisterCall(Args &&...args)162 void CallContainer<Call>::RegisterCall(Args&&... args) {
163 calls_.emplace_back(std::forward<Args>(args)...);
164 }
165
166 template <class Call>
StartCancel()167 void CallContainer<Call>::StartCancel() {
168 for (auto& call : calls_) {
169 call.StartCancel();
170 }
171 }
172
173 template <class Call>
Done(const Status & s,int index)174 void CallContainer<Call>::Done(const Status& s, int index) {
175 if (!try_rpc_) {
176 reffed_status_callback_->UpdateStatus(s);
177 }
178 reffed_status_callback_->Unref();
179 }
180
181 } // namespace tensorflow
182 #endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
183