1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6
7 #include <stdint.h>
8
9 #include "base/bind.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/macros.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/sequenced_task_runner.h"
15 #include "base/stl_util.h"
16 #include "mojo/public/cpp/bindings/associated_group.h"
17 #include "mojo/public/cpp/bindings/associated_group_controller.h"
18 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
19 #include "mojo/public/cpp/bindings/lib/task_runner_helper.h"
20 #include "mojo/public/cpp/bindings/lib/validation_util.h"
21 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
22
23 namespace mojo {
24
25 // ----------------------------------------------------------------------------
26
27 namespace {
28
DetermineIfEndpointIsConnected(const base::WeakPtr<InterfaceEndpointClient> & client,base::OnceCallback<void (bool)> callback)29 void DetermineIfEndpointIsConnected(
30 const base::WeakPtr<InterfaceEndpointClient>& client,
31 base::OnceCallback<void(bool)> callback) {
32 std::move(callback).Run(client && !client->encountered_error());
33 }
34
// When receiving an incoming message which expects a response,
// InterfaceEndpointClient creates a ResponderThunk object and passes it to the
// incoming message receiver. When the receiver finishes processing the message,
// it can provide a response using this object.
39 class ResponderThunk : public MessageReceiverWithStatus {
40 public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SequencedTaskRunner> runner)41 explicit ResponderThunk(
42 const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
43 scoped_refptr<base::SequencedTaskRunner> runner)
44 : endpoint_client_(endpoint_client),
45 accept_was_invoked_(false),
46 task_runner_(std::move(runner)) {}
~ResponderThunk()47 ~ResponderThunk() override {
48 if (!accept_was_invoked_) {
49 // The Service handled a message that was expecting a response
50 // but did not send a response.
51 // We raise an error to signal the calling application that an error
52 // condition occurred. Without this the calling application would have no
53 // way of knowing it should stop waiting for a response.
54 if (task_runner_->RunsTasksInCurrentSequence()) {
55 // Please note that even if this code is run from a different task
56 // runner on the same thread as |task_runner_|, it is okay to directly
57 // call InterfaceEndpointClient::RaiseError(), because it will raise
58 // error from the correct task runner asynchronously.
59 if (endpoint_client_) {
60 endpoint_client_->RaiseError();
61 }
62 } else {
63 task_runner_->PostTask(
64 FROM_HERE,
65 base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
66 }
67 }
68 }
69
70 // MessageReceiver implementation:
PrefersSerializedMessages()71 bool PrefersSerializedMessages() override {
72 return endpoint_client_ && endpoint_client_->PrefersSerializedMessages();
73 }
74
Accept(Message * message)75 bool Accept(Message* message) override {
76 DCHECK(task_runner_->RunsTasksInCurrentSequence());
77 accept_was_invoked_ = true;
78 DCHECK(message->has_flag(Message::kFlagIsResponse));
79
80 bool result = false;
81
82 if (endpoint_client_)
83 result = endpoint_client_->Accept(message);
84
85 return result;
86 }
87
88 // MessageReceiverWithStatus implementation:
IsConnected()89 bool IsConnected() override {
90 DCHECK(task_runner_->RunsTasksInCurrentSequence());
91 return endpoint_client_ && !endpoint_client_->encountered_error();
92 }
93
IsConnectedAsync(base::OnceCallback<void (bool)> callback)94 void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override {
95 if (task_runner_->RunsTasksInCurrentSequence()) {
96 DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback));
97 } else {
98 task_runner_->PostTask(
99 FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected,
100 endpoint_client_, std::move(callback)));
101 }
102 }
103
104 private:
105 base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
106 bool accept_was_invoked_;
107 scoped_refptr<base::SequencedTaskRunner> task_runner_;
108
109 DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
110 };
111
112 } // namespace
113
114 // ----------------------------------------------------------------------------
115
SyncResponseInfo(bool * in_response_received)116 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
117 bool* in_response_received)
118 : response_received(in_response_received) {}
119
~SyncResponseInfo()120 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
121
122 // ----------------------------------------------------------------------------
123
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)124 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
125 InterfaceEndpointClient* owner)
126 : owner_(owner) {}
127
128 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()129 ~HandleIncomingMessageThunk() {}
130
Accept(Message * message)131 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
132 Message* message) {
133 return owner_->HandleValidatedMessage(message);
134 }
135
136 // ----------------------------------------------------------------------------
137
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageReceiver> payload_validator,bool expect_sync_requests,scoped_refptr<base::SequencedTaskRunner> runner,uint32_t interface_version)138 InterfaceEndpointClient::InterfaceEndpointClient(
139 ScopedInterfaceEndpointHandle handle,
140 MessageReceiverWithResponderStatus* receiver,
141 std::unique_ptr<MessageReceiver> payload_validator,
142 bool expect_sync_requests,
143 scoped_refptr<base::SequencedTaskRunner> runner,
144 uint32_t interface_version)
145 : expect_sync_requests_(expect_sync_requests),
146 handle_(std::move(handle)),
147 incoming_receiver_(receiver),
148 thunk_(this),
149 filters_(&thunk_),
150 task_runner_(std::move(runner)),
151 control_message_proxy_(this),
152 control_message_handler_(interface_version),
153 weak_ptr_factory_(this) {
154 DCHECK(handle_.is_valid());
155
156 // TODO(yzshen): the way to use validator (or message filter in general)
157 // directly is a little awkward.
158 if (payload_validator)
159 filters_.Append(std::move(payload_validator));
160
161 if (handle_.pending_association()) {
162 handle_.SetAssociationEventHandler(base::Bind(
163 &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this)));
164 } else {
165 InitControllerIfNecessary();
166 }
167 }
168
~InterfaceEndpointClient()169 InterfaceEndpointClient::~InterfaceEndpointClient() {
170 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
171 if (controller_)
172 handle_.group_controller()->DetachEndpointClient(handle_);
173 }
174
associated_group()175 AssociatedGroup* InterfaceEndpointClient::associated_group() {
176 if (!associated_group_)
177 associated_group_ = std::make_unique<AssociatedGroup>(handle_);
178 return associated_group_.get();
179 }
180
PassHandle()181 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
182 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
183 DCHECK(!has_pending_responders());
184
185 if (!handle_.is_valid())
186 return ScopedInterfaceEndpointHandle();
187
188 handle_.SetAssociationEventHandler(
189 ScopedInterfaceEndpointHandle::AssociationEventCallback());
190
191 if (controller_) {
192 controller_ = nullptr;
193 handle_.group_controller()->DetachEndpointClient(handle_);
194 }
195
196 return std::move(handle_);
197 }
198
AddFilter(std::unique_ptr<MessageReceiver> filter)199 void InterfaceEndpointClient::AddFilter(
200 std::unique_ptr<MessageReceiver> filter) {
201 filters_.Append(std::move(filter));
202 }
203
RaiseError()204 void InterfaceEndpointClient::RaiseError() {
205 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
206
207 if (!handle_.pending_association())
208 handle_.group_controller()->RaiseError();
209 }
210
CloseWithReason(uint32_t custom_reason,const std::string & description)211 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
212 const std::string& description) {
213 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
214
215 auto handle = PassHandle();
216 handle.ResetWithReason(custom_reason, description);
217 }
218
PrefersSerializedMessages()219 bool InterfaceEndpointClient::PrefersSerializedMessages() {
220 auto* controller = handle_.group_controller();
221 return controller && controller->PrefersSerializedMessages();
222 }
223
Accept(Message * message)224 bool InterfaceEndpointClient::Accept(Message* message) {
225 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
226 DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
227 DCHECK(!handle_.pending_association());
228
229 // This has to been done even if connection error has occurred. For example,
230 // the message contains a pending associated request. The user may try to use
231 // the corresponding associated interface pointer after sending this message.
232 // That associated interface pointer has to join an associated group in order
233 // to work properly.
234 if (!message->associated_endpoint_handles()->empty())
235 message->SerializeAssociatedEndpointHandles(handle_.group_controller());
236
237 if (encountered_error_)
238 return false;
239
240 InitControllerIfNecessary();
241
242 return controller_->SendMessage(message);
243 }
244
AcceptWithResponder(Message * message,std::unique_ptr<MessageReceiver> responder)245 bool InterfaceEndpointClient::AcceptWithResponder(
246 Message* message,
247 std::unique_ptr<MessageReceiver> responder) {
248 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
249 DCHECK(message->has_flag(Message::kFlagExpectsResponse));
250 DCHECK(!handle_.pending_association());
251
252 // Please see comments in Accept().
253 if (!message->associated_endpoint_handles()->empty())
254 message->SerializeAssociatedEndpointHandles(handle_.group_controller());
255
256 if (encountered_error_)
257 return false;
258
259 InitControllerIfNecessary();
260
261 // Reserve 0 in case we want it to convey special meaning in the future.
262 uint64_t request_id = next_request_id_++;
263 if (request_id == 0)
264 request_id = next_request_id_++;
265
266 message->set_request_id(request_id);
267
268 bool is_sync = message->has_flag(Message::kFlagIsSync);
269 if (!controller_->SendMessage(message))
270 return false;
271
272 if (!is_sync) {
273 async_responders_[request_id] = std::move(responder);
274 return true;
275 }
276
277 SyncCallRestrictions::AssertSyncCallAllowed();
278
279 bool response_received = false;
280 sync_responses_.insert(std::make_pair(
281 request_id, std::make_unique<SyncResponseInfo>(&response_received)));
282
283 base::WeakPtr<InterfaceEndpointClient> weak_self =
284 weak_ptr_factory_.GetWeakPtr();
285 controller_->SyncWatch(&response_received);
286 // Make sure that this instance hasn't been destroyed.
287 if (weak_self) {
288 DCHECK(base::ContainsKey(sync_responses_, request_id));
289 auto iter = sync_responses_.find(request_id);
290 DCHECK_EQ(&response_received, iter->second->response_received);
291 if (response_received) {
292 ignore_result(responder->Accept(&iter->second->response));
293 } else {
294 DVLOG(1) << "Mojo sync call returns without receiving a response. "
295 << "Typcially it is because the interface has been "
296 << "disconnected.";
297 }
298 sync_responses_.erase(iter);
299 }
300
301 return true;
302 }
303
HandleIncomingMessage(Message * message)304 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
305 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
306 return filters_.Accept(message);
307 }
308
NotifyError(const base::Optional<DisconnectReason> & reason)309 void InterfaceEndpointClient::NotifyError(
310 const base::Optional<DisconnectReason>& reason) {
311 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
312
313 if (encountered_error_)
314 return;
315 encountered_error_ = true;
316
317 // Response callbacks may hold on to resource, and there's no need to keep
318 // them alive any longer. Note that it's allowed that a pending response
319 // callback may own this endpoint, so we simply move the responders onto the
320 // stack here and let them be destroyed when the stack unwinds.
321 AsyncResponderMap responders = std::move(async_responders_);
322
323 control_message_proxy_.OnConnectionError();
324
325 if (error_handler_) {
326 std::move(error_handler_).Run();
327 } else if (error_with_reason_handler_) {
328 if (reason) {
329 std::move(error_with_reason_handler_)
330 .Run(reason->custom_reason, reason->description);
331 } else {
332 std::move(error_with_reason_handler_).Run(0, std::string());
333 }
334 }
335 }
336
QueryVersion(const base::Callback<void (uint32_t)> & callback)337 void InterfaceEndpointClient::QueryVersion(
338 const base::Callback<void(uint32_t)>& callback) {
339 control_message_proxy_.QueryVersion(callback);
340 }
341
RequireVersion(uint32_t version)342 void InterfaceEndpointClient::RequireVersion(uint32_t version) {
343 control_message_proxy_.RequireVersion(version);
344 }
345
FlushForTesting()346 void InterfaceEndpointClient::FlushForTesting() {
347 control_message_proxy_.FlushForTesting();
348 }
349
InitControllerIfNecessary()350 void InterfaceEndpointClient::InitControllerIfNecessary() {
351 if (controller_ || handle_.pending_association())
352 return;
353
354 controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this,
355 task_runner_);
356 if (expect_sync_requests_)
357 controller_->AllowWokenUpBySyncWatchOnSameThread();
358 }
359
OnAssociationEvent(ScopedInterfaceEndpointHandle::AssociationEvent event)360 void InterfaceEndpointClient::OnAssociationEvent(
361 ScopedInterfaceEndpointHandle::AssociationEvent event) {
362 if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
363 InitControllerIfNecessary();
364 } else if (event ==
365 ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) {
366 task_runner_->PostTask(FROM_HERE,
367 base::Bind(&InterfaceEndpointClient::NotifyError,
368 weak_ptr_factory_.GetWeakPtr(),
369 handle_.disconnect_reason()));
370 }
371 }
372
HandleValidatedMessage(Message * message)373 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
374 DCHECK_EQ(handle_.id(), message->interface_id());
375
376 if (encountered_error_) {
377 // This message is received after error has been encountered. For associated
378 // interfaces, this means the remote side sends a
379 // PeerAssociatedEndpointClosed event but continues to send more messages
380 // for the same interface. Close the pipe because this shouldn't happen.
381 DVLOG(1) << "A message is received for an interface after it has been "
382 << "disconnected. Closing the pipe.";
383 return false;
384 }
385
386 if (message->has_flag(Message::kFlagExpectsResponse)) {
387 std::unique_ptr<MessageReceiverWithStatus> responder =
388 std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(),
389 task_runner_);
390 if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
391 return control_message_handler_.AcceptWithResponder(message,
392 std::move(responder));
393 } else {
394 return incoming_receiver_->AcceptWithResponder(message,
395 std::move(responder));
396 }
397 } else if (message->has_flag(Message::kFlagIsResponse)) {
398 uint64_t request_id = message->request_id();
399
400 if (message->has_flag(Message::kFlagIsSync)) {
401 auto it = sync_responses_.find(request_id);
402 if (it == sync_responses_.end())
403 return false;
404 it->second->response = std::move(*message);
405 *it->second->response_received = true;
406 return true;
407 }
408
409 auto it = async_responders_.find(request_id);
410 if (it == async_responders_.end())
411 return false;
412 std::unique_ptr<MessageReceiver> responder = std::move(it->second);
413 async_responders_.erase(it);
414 return responder->Accept(message);
415 } else {
416 if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
417 return control_message_handler_.Accept(message);
418
419 return incoming_receiver_->Accept(message);
420 }
421 }
422
423 } // namespace mojo
424