• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6 
7 #include <stdint.h>
8 
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/ptr_util.h"
16 #include "base/single_thread_task_runner.h"
17 #include "base/stl_util.h"
18 #include "mojo/public/cpp/bindings/associated_group.h"
19 #include "mojo/public/cpp/bindings/associated_group_controller.h"
20 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
21 #include "mojo/public/cpp/bindings/lib/validation_util.h"
22 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
23 
24 namespace mojo {
25 
26 // ----------------------------------------------------------------------------
27 
28 namespace {
29 
DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient> & client,const std::string & message)30 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client,
31                    const std::string& message) {
32   bool is_valid = client && !client->encountered_error();
33   DCHECK(!is_valid) << message;
34 }
35 
36 // When receiving an incoming message which expects a repsonse,
37 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
38 // incoming message receiver. When the receiver finishes processing the message,
39 // it can provide a response using this object.
40 class ResponderThunk : public MessageReceiverWithStatus {
41  public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SingleThreadTaskRunner> runner)42   explicit ResponderThunk(
43       const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
44       scoped_refptr<base::SingleThreadTaskRunner> runner)
45       : endpoint_client_(endpoint_client),
46         accept_was_invoked_(false),
47         task_runner_(std::move(runner)) {}
~ResponderThunk()48   ~ResponderThunk() override {
49     if (!accept_was_invoked_) {
50       // The Service handled a message that was expecting a response
51       // but did not send a response.
52       // We raise an error to signal the calling application that an error
53       // condition occurred. Without this the calling application would have no
54       // way of knowing it should stop waiting for a response.
55       if (task_runner_->RunsTasksOnCurrentThread()) {
56         // Please note that even if this code is run from a different task
57         // runner on the same thread as |task_runner_|, it is okay to directly
58         // call InterfaceEndpointClient::RaiseError(), because it will raise
59         // error from the correct task runner asynchronously.
60         if (endpoint_client_) {
61           endpoint_client_->RaiseError();
62         }
63       } else {
64         task_runner_->PostTask(
65             FROM_HERE,
66             base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
67       }
68     }
69   }
70 
71   // MessageReceiver implementation:
Accept(Message * message)72   bool Accept(Message* message) override {
73     DCHECK(task_runner_->RunsTasksOnCurrentThread());
74     accept_was_invoked_ = true;
75     DCHECK(message->has_flag(Message::kFlagIsResponse));
76 
77     bool result = false;
78 
79     if (endpoint_client_)
80       result = endpoint_client_->Accept(message);
81 
82     return result;
83   }
84 
85   // MessageReceiverWithStatus implementation:
IsValid()86   bool IsValid() override {
87     DCHECK(task_runner_->RunsTasksOnCurrentThread());
88     return endpoint_client_ && !endpoint_client_->encountered_error();
89   }
90 
DCheckInvalid(const std::string & message)91   void DCheckInvalid(const std::string& message) override {
92     if (task_runner_->RunsTasksOnCurrentThread()) {
93       DCheckIfInvalid(endpoint_client_, message);
94     } else {
95       task_runner_->PostTask(
96           FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message));
97     }
98  }
99 
100  private:
101   base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
102   bool accept_was_invoked_;
103   scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
104 
105   DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
106 };
107 
108 }  // namespace
109 
110 // ----------------------------------------------------------------------------
111 
SyncResponseInfo(bool * in_response_received)112 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
113     bool* in_response_received)
114     : response_received(in_response_received) {}
115 
~SyncResponseInfo()116 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
117 
118 // ----------------------------------------------------------------------------
119 
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)120 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
121     InterfaceEndpointClient* owner)
122     : owner_(owner) {}
123 
124 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()125     ~HandleIncomingMessageThunk() {}
126 
Accept(Message * message)127 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
128     Message* message) {
129   return owner_->HandleValidatedMessage(message);
130 }
131 
132 // ----------------------------------------------------------------------------
133 
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageReceiver> payload_validator,bool expect_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner,uint32_t interface_version)134 InterfaceEndpointClient::InterfaceEndpointClient(
135     ScopedInterfaceEndpointHandle handle,
136     MessageReceiverWithResponderStatus* receiver,
137     std::unique_ptr<MessageReceiver> payload_validator,
138     bool expect_sync_requests,
139     scoped_refptr<base::SingleThreadTaskRunner> runner,
140     uint32_t interface_version)
141     : expect_sync_requests_(expect_sync_requests),
142       handle_(std::move(handle)),
143       incoming_receiver_(receiver),
144       thunk_(this),
145       filters_(&thunk_),
146       task_runner_(std::move(runner)),
147       control_message_proxy_(this),
148       control_message_handler_(interface_version),
149       weak_ptr_factory_(this) {
150   DCHECK(handle_.is_valid());
151 
152   // TODO(yzshen): the way to use validator (or message filter in general)
153   // directly is a little awkward.
154   if (payload_validator)
155     filters_.Append(std::move(payload_validator));
156 
157   if (handle_.pending_association()) {
158     handle_.SetAssociationEventHandler(base::Bind(
159         &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this)));
160   } else {
161     InitControllerIfNecessary();
162   }
163 }
164 
~InterfaceEndpointClient()165 InterfaceEndpointClient::~InterfaceEndpointClient() {
166   DCHECK(thread_checker_.CalledOnValidThread());
167 
168   if (controller_)
169     handle_.group_controller()->DetachEndpointClient(handle_);
170 }
171 
associated_group()172 AssociatedGroup* InterfaceEndpointClient::associated_group() {
173   if (!associated_group_)
174     associated_group_ = base::MakeUnique<AssociatedGroup>(handle_);
175   return associated_group_.get();
176 }
177 
PassHandle()178 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
179   DCHECK(thread_checker_.CalledOnValidThread());
180   DCHECK(!has_pending_responders());
181 
182   if (!handle_.is_valid())
183     return ScopedInterfaceEndpointHandle();
184 
185   handle_.SetAssociationEventHandler(
186       ScopedInterfaceEndpointHandle::AssociationEventCallback());
187 
188   if (controller_) {
189     controller_ = nullptr;
190     handle_.group_controller()->DetachEndpointClient(handle_);
191   }
192 
193   return std::move(handle_);
194 }
195 
AddFilter(std::unique_ptr<MessageReceiver> filter)196 void InterfaceEndpointClient::AddFilter(
197     std::unique_ptr<MessageReceiver> filter) {
198   filters_.Append(std::move(filter));
199 }
200 
RaiseError()201 void InterfaceEndpointClient::RaiseError() {
202   DCHECK(thread_checker_.CalledOnValidThread());
203 
204   if (!handle_.pending_association())
205     handle_.group_controller()->RaiseError();
206 }
207 
CloseWithReason(uint32_t custom_reason,const std::string & description)208 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
209                                               const std::string& description) {
210   DCHECK(thread_checker_.CalledOnValidThread());
211 
212   auto handle = PassHandle();
213   handle.ResetWithReason(custom_reason, description);
214 }
215 
Accept(Message * message)216 bool InterfaceEndpointClient::Accept(Message* message) {
217   DCHECK(thread_checker_.CalledOnValidThread());
218   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
219   DCHECK(!handle_.pending_association());
220 
221   // This has to been done even if connection error has occurred. For example,
222   // the message contains a pending associated request. The user may try to use
223   // the corresponding associated interface pointer after sending this message.
224   // That associated interface pointer has to join an associated group in order
225   // to work properly.
226   if (!message->associated_endpoint_handles()->empty())
227     message->SerializeAssociatedEndpointHandles(handle_.group_controller());
228 
229   if (encountered_error_)
230     return false;
231 
232   InitControllerIfNecessary();
233 
234   return controller_->SendMessage(message);
235 }
236 
AcceptWithResponder(Message * message,MessageReceiver * responder)237 bool InterfaceEndpointClient::AcceptWithResponder(Message* message,
238                                                   MessageReceiver* responder) {
239   DCHECK(thread_checker_.CalledOnValidThread());
240   DCHECK(message->has_flag(Message::kFlagExpectsResponse));
241   DCHECK(!handle_.pending_association());
242 
243   // Please see comments in Accept().
244   if (!message->associated_endpoint_handles()->empty())
245     message->SerializeAssociatedEndpointHandles(handle_.group_controller());
246 
247   if (encountered_error_)
248     return false;
249 
250   InitControllerIfNecessary();
251 
252   // Reserve 0 in case we want it to convey special meaning in the future.
253   uint64_t request_id = next_request_id_++;
254   if (request_id == 0)
255     request_id = next_request_id_++;
256 
257   message->set_request_id(request_id);
258 
259   bool is_sync = message->has_flag(Message::kFlagIsSync);
260   if (!controller_->SendMessage(message))
261     return false;
262 
263   if (!is_sync) {
264     // We assume ownership of |responder|.
265     async_responders_[request_id] = base::WrapUnique(responder);
266     return true;
267   }
268 
269   SyncCallRestrictions::AssertSyncCallAllowed();
270 
271   bool response_received = false;
272   std::unique_ptr<MessageReceiver> sync_responder(responder);
273   sync_responses_.insert(std::make_pair(
274       request_id, base::MakeUnique<SyncResponseInfo>(&response_received)));
275 
276   base::WeakPtr<InterfaceEndpointClient> weak_self =
277       weak_ptr_factory_.GetWeakPtr();
278   controller_->SyncWatch(&response_received);
279   // Make sure that this instance hasn't been destroyed.
280   if (weak_self) {
281     DCHECK(base::ContainsKey(sync_responses_, request_id));
282     auto iter = sync_responses_.find(request_id);
283     DCHECK_EQ(&response_received, iter->second->response_received);
284     if (response_received)
285       ignore_result(sync_responder->Accept(&iter->second->response));
286     sync_responses_.erase(iter);
287   }
288 
289   // Return true means that we take ownership of |responder|.
290   return true;
291 }
292 
HandleIncomingMessage(Message * message)293 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
294   DCHECK(thread_checker_.CalledOnValidThread());
295   return filters_.Accept(message);
296 }
297 
NotifyError(const base::Optional<DisconnectReason> & reason)298 void InterfaceEndpointClient::NotifyError(
299     const base::Optional<DisconnectReason>& reason) {
300   DCHECK(thread_checker_.CalledOnValidThread());
301 
302   if (encountered_error_)
303     return;
304   encountered_error_ = true;
305 
306   // Response callbacks may hold on to resource, and there's no need to keep
307   // them alive any longer. Note that it's allowed that a pending response
308   // callback may own this endpoint, so we simply move the responders onto the
309   // stack here and let them be destroyed when the stack unwinds.
310   AsyncResponderMap responders = std::move(async_responders_);
311 
312   control_message_proxy_.OnConnectionError();
313 
314   if (!error_handler_.is_null()) {
315     base::Closure error_handler = std::move(error_handler_);
316     error_handler.Run();
317   } else if (!error_with_reason_handler_.is_null()) {
318     ConnectionErrorWithReasonCallback error_with_reason_handler =
319         std::move(error_with_reason_handler_);
320     if (reason) {
321       error_with_reason_handler.Run(reason->custom_reason, reason->description);
322     } else {
323       error_with_reason_handler.Run(0, std::string());
324     }
325   }
326 }
327 
QueryVersion(const base::Callback<void (uint32_t)> & callback)328 void InterfaceEndpointClient::QueryVersion(
329     const base::Callback<void(uint32_t)>& callback) {
330   control_message_proxy_.QueryVersion(callback);
331 }
332 
RequireVersion(uint32_t version)333 void InterfaceEndpointClient::RequireVersion(uint32_t version) {
334   control_message_proxy_.RequireVersion(version);
335 }
336 
FlushForTesting()337 void InterfaceEndpointClient::FlushForTesting() {
338   control_message_proxy_.FlushForTesting();
339 }
340 
InitControllerIfNecessary()341 void InterfaceEndpointClient::InitControllerIfNecessary() {
342   if (controller_ || handle_.pending_association())
343     return;
344 
345   controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this,
346                                                                  task_runner_);
347   if (expect_sync_requests_)
348     controller_->AllowWokenUpBySyncWatchOnSameThread();
349 }
350 
OnAssociationEvent(ScopedInterfaceEndpointHandle::AssociationEvent event)351 void InterfaceEndpointClient::OnAssociationEvent(
352     ScopedInterfaceEndpointHandle::AssociationEvent event) {
353   if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
354     InitControllerIfNecessary();
355   } else if (event ==
356              ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) {
357     task_runner_->PostTask(FROM_HERE,
358                            base::Bind(&InterfaceEndpointClient::NotifyError,
359                                       weak_ptr_factory_.GetWeakPtr(),
360                                       handle_.disconnect_reason()));
361   }
362 }
363 
HandleValidatedMessage(Message * message)364 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
365   DCHECK_EQ(handle_.id(), message->interface_id());
366 
367   if (encountered_error_) {
368     // This message is received after error has been encountered. For associated
369     // interfaces, this means the remote side sends a
370     // PeerAssociatedEndpointClosed event but continues to send more messages
371     // for the same interface. Close the pipe because this shouldn't happen.
372     DVLOG(1) << "A message is received for an interface after it has been "
373              << "disconnected. Closing the pipe.";
374     return false;
375   }
376 
377   if (message->has_flag(Message::kFlagExpectsResponse)) {
378     MessageReceiverWithStatus* responder =
379         new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_);
380     bool ok = false;
381     if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
382       ok = control_message_handler_.AcceptWithResponder(message, responder);
383     } else {
384       ok = incoming_receiver_->AcceptWithResponder(message, responder);
385     }
386     if (!ok)
387       delete responder;
388     return ok;
389   } else if (message->has_flag(Message::kFlagIsResponse)) {
390     uint64_t request_id = message->request_id();
391 
392     if (message->has_flag(Message::kFlagIsSync)) {
393       auto it = sync_responses_.find(request_id);
394       if (it == sync_responses_.end())
395         return false;
396       it->second->response = std::move(*message);
397       *it->second->response_received = true;
398       return true;
399     }
400 
401     auto it = async_responders_.find(request_id);
402     if (it == async_responders_.end())
403       return false;
404     std::unique_ptr<MessageReceiver> responder = std::move(it->second);
405     async_responders_.erase(it);
406     return responder->Accept(message);
407   } else {
408     if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
409       return control_message_handler_.Accept(message);
410 
411     return incoming_receiver_->Accept(message);
412   }
413 }
414 
415 }  // namespace mojo
416