• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
6 #define MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
7 
8 #include <memory>
9 
10 #include "base/macros.h"
11 #include "base/memory/ptr_util.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/stl_util.h"
14 #include "base/synchronization/waitable_event.h"
15 #include "base/task_runner.h"
16 #include "base/threading/sequenced_task_runner_handle.h"
17 #include "mojo/public/cpp/bindings/associated_group.h"
18 #include "mojo/public/cpp/bindings/associated_interface_ptr.h"
19 #include "mojo/public/cpp/bindings/interface_ptr.h"
20 #include "mojo/public/cpp/bindings/message.h"
21 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
22 #include "mojo/public/cpp/bindings/sync_event_watcher.h"
23 
24 // ThreadSafeInterfacePtr wraps a non-thread-safe InterfacePtr and proxies
25 // messages to it. Async calls are posted to the sequence that the InterfacePtr
26 // is bound to, and the responses are posted back. Sync calls are dispatched
27 // directly if the call is made on the sequence that the wrapped InterfacePtr is
28 // bound to, or posted otherwise. It's important to be aware that sync calls
29 // block both the calling sequence and the InterfacePtr sequence. That means
30 // that you cannot make sync calls through a ThreadSafeInterfacePtr if the
31 // underlying InterfacePtr is bound to a sequence that cannot block, like the IO
32 // thread.
33 
34 namespace mojo {
35 
36 // Instances of this class may be used from any sequence to serialize
37 // |Interface| messages and forward them elsewhere. In general you should use
38 // one of the ThreadSafeInterfacePtrBase helper aliases defined below, but this
39 // type may be useful if you need/want to manually manage the lifetime of the
40 // underlying proxy object which will be used to ultimately send messages.
41 template <typename Interface>
42 class ThreadSafeForwarder : public MessageReceiverWithResponder {
43  public:
44   using ProxyType = typename Interface::Proxy_;
45   using ForwardMessageCallback = base::Callback<void(Message)>;
46   using ForwardMessageWithResponderCallback =
47       base::Callback<void(Message, std::unique_ptr<MessageReceiver>)>;
48 
49   // Constructs a ThreadSafeForwarder through which Messages are forwarded to
50   // |forward| or |forward_with_responder| by posting to |task_runner|.
51   //
52   // Any message sent through this forwarding interface will dispatch its reply,
53   // if any, back to the sequence which called the corresponding interface
54   // method.
ThreadSafeForwarder(const scoped_refptr<base::SequencedTaskRunner> & task_runner,const ForwardMessageCallback & forward,const ForwardMessageWithResponderCallback & forward_with_responder,const AssociatedGroup & associated_group)55   ThreadSafeForwarder(
56       const scoped_refptr<base::SequencedTaskRunner>& task_runner,
57       const ForwardMessageCallback& forward,
58       const ForwardMessageWithResponderCallback& forward_with_responder,
59       const AssociatedGroup& associated_group)
60       : proxy_(this),
61         task_runner_(task_runner),
62         forward_(forward),
63         forward_with_responder_(forward_with_responder),
64         associated_group_(associated_group),
65         sync_calls_(new InProgressSyncCalls()) {}
66 
~ThreadSafeForwarder()67   ~ThreadSafeForwarder() override {
68     // If there are ongoing sync calls signal their completion now.
69     base::AutoLock l(sync_calls_->lock);
70     for (const auto& pending_response : sync_calls_->pending_responses)
71       pending_response->event.Signal();
72   }
73 
proxy()74   ProxyType& proxy() { return proxy_; }
75 
76  private:
77   // MessageReceiverWithResponder implementation:
PrefersSerializedMessages()78   bool PrefersSerializedMessages() override {
79     // TSIP is primarily used because it emulates legacy IPC threading behavior.
80     // In practice this means it's only for cross-process messaging and we can
81     // just always assume messages should be serialized.
82     return true;
83   }
84 
Accept(Message * message)85   bool Accept(Message* message) override {
86     if (!message->associated_endpoint_handles()->empty()) {
87       // If this DCHECK fails, it is likely because:
88       // - This is a non-associated interface pointer setup using
89       //     PtrWrapper::BindOnTaskRunner(
90       //         InterfacePtrInfo<InterfaceType> ptr_info);
91       //   Please see the TODO in that method.
92       // - This is an associated interface which hasn't been associated with a
93       //   message pipe. In other words, the corresponding
94       //   AssociatedInterfaceRequest hasn't been sent.
95       DCHECK(associated_group_.GetController());
96       message->SerializeAssociatedEndpointHandles(
97           associated_group_.GetController());
98     }
99     task_runner_->PostTask(FROM_HERE,
100                            base::Bind(forward_, base::Passed(message)));
101     return true;
102   }
103 
AcceptWithResponder(Message * message,std::unique_ptr<MessageReceiver> responder)104   bool AcceptWithResponder(
105       Message* message,
106       std::unique_ptr<MessageReceiver> responder) override {
107     if (!message->associated_endpoint_handles()->empty()) {
108       // Please see comment for the DCHECK in the previous method.
109       DCHECK(associated_group_.GetController());
110       message->SerializeAssociatedEndpointHandles(
111           associated_group_.GetController());
112     }
113 
114     // Async messages are always posted (even if |task_runner_| runs tasks on
115     // this sequence) to guarantee that two async calls can't be reordered.
116     if (!message->has_flag(Message::kFlagIsSync)) {
117       auto reply_forwarder =
118           std::make_unique<ForwardToCallingThread>(std::move(responder));
119       task_runner_->PostTask(
120           FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message),
121                                 base::Passed(&reply_forwarder)));
122       return true;
123     }
124 
125     SyncCallRestrictions::AssertSyncCallAllowed();
126 
127     // If the InterfacePtr is bound to this sequence, dispatch it directly.
128     if (task_runner_->RunsTasksInCurrentSequence()) {
129       forward_with_responder_.Run(std::move(*message), std::move(responder));
130       return true;
131     }
132 
133     // If the InterfacePtr is bound on another sequence, post the call.
134     // TODO(yzshen, watk): We block both this sequence and the InterfacePtr
135     // sequence. Ideally only this sequence would block.
136     auto response = base::MakeRefCounted<SyncResponseInfo>();
137     auto response_signaler = std::make_unique<SyncResponseSignaler>(response);
138     task_runner_->PostTask(
139         FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message),
140                               base::Passed(&response_signaler)));
141 
142     // Save the pending SyncResponseInfo so that if the sync call deletes
143     // |this|, we can signal the completion of the call to return from
144     // SyncWatch().
145     auto sync_calls = sync_calls_;
146     {
147       base::AutoLock l(sync_calls->lock);
148       sync_calls->pending_responses.push_back(response.get());
149     }
150 
151     auto assign_true = [](bool* b) { *b = true; };
152     bool event_signaled = false;
153     SyncEventWatcher watcher(&response->event,
154                              base::Bind(assign_true, &event_signaled));
155     const bool* stop_flags[] = {&event_signaled};
156     watcher.SyncWatch(stop_flags, 1);
157 
158     {
159       base::AutoLock l(sync_calls->lock);
160       base::Erase(sync_calls->pending_responses, response.get());
161     }
162 
163     if (response->received)
164       ignore_result(responder->Accept(&response->message));
165 
166     return true;
167   }
168 
169   // Data that we need to share between the sequences involved in a sync call.
170   struct SyncResponseInfo
171       : public base::RefCountedThreadSafe<SyncResponseInfo> {
172     Message message;
173     bool received = false;
174     base::WaitableEvent event{base::WaitableEvent::ResetPolicy::MANUAL,
175                               base::WaitableEvent::InitialState::NOT_SIGNALED};
176 
177    private:
178     friend class base::RefCountedThreadSafe<SyncResponseInfo>;
179   };
180 
181   // A MessageReceiver that signals |response| when it either accepts the
182   // response message, or is destructed.
183   class SyncResponseSignaler : public MessageReceiver {
184    public:
SyncResponseSignaler(scoped_refptr<SyncResponseInfo> response)185     explicit SyncResponseSignaler(scoped_refptr<SyncResponseInfo> response)
186         : response_(response) {}
187 
~SyncResponseSignaler()188     ~SyncResponseSignaler() override {
189       // If Accept() was not called we must still notify the waiter that the
190       // sync call is finished.
191       if (response_)
192         response_->event.Signal();
193     }
194 
Accept(Message * message)195     bool Accept(Message* message) override {
196       response_->message = std::move(*message);
197       response_->received = true;
198       response_->event.Signal();
199       response_ = nullptr;
200       return true;
201     }
202 
203    private:
204     scoped_refptr<SyncResponseInfo> response_;
205   };
206 
207   // A record of the pending sync responses for canceling pending sync calls
208   // when the owning ThreadSafeForwarder is destructed.
209   struct InProgressSyncCalls
210       : public base::RefCountedThreadSafe<InProgressSyncCalls> {
211     // |lock| protects access to |pending_responses|.
212     base::Lock lock;
213     std::vector<SyncResponseInfo*> pending_responses;
214   };
215 
216   class ForwardToCallingThread : public MessageReceiver {
217    public:
ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder)218     explicit ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder)
219         : responder_(std::move(responder)),
220           caller_task_runner_(base::SequencedTaskRunnerHandle::Get()) {}
~ForwardToCallingThread()221     ~ForwardToCallingThread() override {
222       caller_task_runner_->DeleteSoon(FROM_HERE, std::move(responder_));
223     }
224 
225    private:
Accept(Message * message)226     bool Accept(Message* message) override {
227       // The current instance will be deleted when this method returns, so we
228       // have to relinquish the responder's ownership so it does not get
229       // deleted.
230       caller_task_runner_->PostTask(
231           FROM_HERE,
232           base::Bind(&ForwardToCallingThread::CallAcceptAndDeleteResponder,
233                      base::Passed(std::move(responder_)),
234                      base::Passed(std::move(*message))));
235       return true;
236     }
237 
CallAcceptAndDeleteResponder(std::unique_ptr<MessageReceiver> responder,Message message)238     static void CallAcceptAndDeleteResponder(
239         std::unique_ptr<MessageReceiver> responder,
240         Message message) {
241       ignore_result(responder->Accept(&message));
242     }
243 
244     std::unique_ptr<MessageReceiver> responder_;
245     scoped_refptr<base::SequencedTaskRunner> caller_task_runner_;
246   };
247 
248   ProxyType proxy_;
249   const scoped_refptr<base::SequencedTaskRunner> task_runner_;
250   const ForwardMessageCallback forward_;
251   const ForwardMessageWithResponderCallback forward_with_responder_;
252   AssociatedGroup associated_group_;
253   scoped_refptr<InProgressSyncCalls> sync_calls_;
254 
255   DISALLOW_COPY_AND_ASSIGN(ThreadSafeForwarder);
256 };
257 
258 template <typename InterfacePtrType>
259 class ThreadSafeInterfacePtrBase
260     : public base::RefCountedThreadSafe<
261           ThreadSafeInterfacePtrBase<InterfacePtrType>> {
262  public:
263   using InterfaceType = typename InterfacePtrType::InterfaceType;
264   using PtrInfoType = typename InterfacePtrType::PtrInfoType;
265 
ThreadSafeInterfacePtrBase(std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder)266   explicit ThreadSafeInterfacePtrBase(
267       std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder)
268       : forwarder_(std::move(forwarder)) {}
269 
270   // Creates a ThreadSafeInterfacePtrBase wrapping an underlying non-thread-safe
271   // InterfacePtrType which is bound to the calling sequence. All messages sent
272   // via this thread-safe proxy will internally be sent by first posting to this
273   // (the calling) sequence's TaskRunner.
Create(InterfacePtrType interface_ptr)274   static scoped_refptr<ThreadSafeInterfacePtrBase> Create(
275       InterfacePtrType interface_ptr) {
276     scoped_refptr<PtrWrapper> wrapper =
277         new PtrWrapper(std::move(interface_ptr));
278     return new ThreadSafeInterfacePtrBase(wrapper->CreateForwarder());
279   }
280 
281   // Creates a ThreadSafeInterfacePtrBase which binds the underlying
282   // non-thread-safe InterfacePtrType on the specified TaskRunner. All messages
283   // sent via this thread-safe proxy will internally be sent by first posting to
284   // that TaskRunner.
Create(PtrInfoType ptr_info,const scoped_refptr<base::SequencedTaskRunner> & bind_task_runner)285   static scoped_refptr<ThreadSafeInterfacePtrBase> Create(
286       PtrInfoType ptr_info,
287       const scoped_refptr<base::SequencedTaskRunner>& bind_task_runner) {
288     scoped_refptr<PtrWrapper> wrapper = new PtrWrapper(bind_task_runner);
289     wrapper->BindOnTaskRunner(std::move(ptr_info));
290     return new ThreadSafeInterfacePtrBase(wrapper->CreateForwarder());
291   }
292 
get()293   InterfaceType* get() { return &forwarder_->proxy(); }
294   InterfaceType* operator->() { return get(); }
295   InterfaceType& operator*() { return *get(); }
296 
297  private:
298   friend class base::RefCountedThreadSafe<
299       ThreadSafeInterfacePtrBase<InterfacePtrType>>;
300 
301   struct PtrWrapperDeleter;
302 
303   // Helper class which owns an |InterfacePtrType| instance on an appropriate
304   // sequence. This is kept alive as long its bound within some
305   // ThreadSafeForwarder's callbacks.
306   class PtrWrapper
307       : public base::RefCountedThreadSafe<PtrWrapper, PtrWrapperDeleter> {
308    public:
PtrWrapper(InterfacePtrType ptr)309     explicit PtrWrapper(InterfacePtrType ptr)
310         : PtrWrapper(base::SequencedTaskRunnerHandle::Get()) {
311       ptr_ = std::move(ptr);
312       associated_group_ = *ptr_.internal_state()->associated_group();
313     }
314 
PtrWrapper(const scoped_refptr<base::SequencedTaskRunner> & task_runner)315     explicit PtrWrapper(
316         const scoped_refptr<base::SequencedTaskRunner>& task_runner)
317         : task_runner_(task_runner) {}
318 
BindOnTaskRunner(AssociatedInterfacePtrInfo<InterfaceType> ptr_info)319     void BindOnTaskRunner(AssociatedInterfacePtrInfo<InterfaceType> ptr_info) {
320       associated_group_ = AssociatedGroup(ptr_info.handle());
321       task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::Bind, this,
322                                                    base::Passed(&ptr_info)));
323     }
324 
BindOnTaskRunner(InterfacePtrInfo<InterfaceType> ptr_info)325     void BindOnTaskRunner(InterfacePtrInfo<InterfaceType> ptr_info) {
326       // TODO(yzhsen): At the momment we don't have a group controller
327       // available. That means the user won't be able to pass associated
328       // endpoints on this interface (at least not immediately). In order to fix
329       // this, we need to create a MultiplexRouter immediately and bind it to
330       // the interface pointer on the |task_runner_|. Therefore, MultiplexRouter
331       // should be able to be created on a sequence different than the one that
332       // it is supposed to listen on. crbug.com/682334
333       task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::Bind, this,
334                                                    base::Passed(&ptr_info)));
335     }
336 
CreateForwarder()337     std::unique_ptr<ThreadSafeForwarder<InterfaceType>> CreateForwarder() {
338       return std::make_unique<ThreadSafeForwarder<InterfaceType>>(
339           task_runner_, base::Bind(&PtrWrapper::Accept, this),
340           base::Bind(&PtrWrapper::AcceptWithResponder, this),
341           associated_group_);
342     }
343 
344    private:
345     friend struct PtrWrapperDeleter;
346 
~PtrWrapper()347     ~PtrWrapper() {}
348 
Bind(PtrInfoType ptr_info)349     void Bind(PtrInfoType ptr_info) {
350       DCHECK(task_runner_->RunsTasksInCurrentSequence());
351       ptr_.Bind(std::move(ptr_info));
352     }
353 
Accept(Message message)354     void Accept(Message message) {
355       ptr_.internal_state()->ForwardMessage(std::move(message));
356     }
357 
AcceptWithResponder(Message message,std::unique_ptr<MessageReceiver> responder)358     void AcceptWithResponder(Message message,
359                              std::unique_ptr<MessageReceiver> responder) {
360       ptr_.internal_state()->ForwardMessageWithResponder(std::move(message),
361                                                          std::move(responder));
362     }
363 
DeleteOnCorrectThread()364     void DeleteOnCorrectThread() const {
365       if (!task_runner_->RunsTasksInCurrentSequence()) {
366         // NOTE: This is only called when there are no more references to
367         // |this|, so binding it unretained is both safe and necessary.
368         task_runner_->PostTask(FROM_HERE,
369                                base::Bind(&PtrWrapper::DeleteOnCorrectThread,
370                                           base::Unretained(this)));
371       } else {
372         delete this;
373       }
374     }
375 
376     InterfacePtrType ptr_;
377     const scoped_refptr<base::SequencedTaskRunner> task_runner_;
378     AssociatedGroup associated_group_;
379 
380     DISALLOW_COPY_AND_ASSIGN(PtrWrapper);
381   };
382 
383   struct PtrWrapperDeleter {
DestructPtrWrapperDeleter384     static void Destruct(const PtrWrapper* interface_ptr) {
385       interface_ptr->DeleteOnCorrectThread();
386     }
387   };
388 
~ThreadSafeInterfacePtrBase()389   ~ThreadSafeInterfacePtrBase() {}
390 
391   const std::unique_ptr<ThreadSafeForwarder<InterfaceType>> forwarder_;
392 
393   DISALLOW_COPY_AND_ASSIGN(ThreadSafeInterfacePtrBase);
394 };
395 
396 template <typename Interface>
397 using ThreadSafeAssociatedInterfacePtr =
398     ThreadSafeInterfacePtrBase<AssociatedInterfacePtr<Interface>>;
399 
400 template <typename Interface>
401 using ThreadSafeInterfacePtr =
402     ThreadSafeInterfacePtrBase<InterfacePtr<Interface>>;
403 
404 }  // namespace mojo
405 
406 #endif  // MOJO_PUBLIC_CPP_BINDINGS_THREAD_SAFE_INTERFACE_PTR_H_
407