1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6
7 #include <stdint.h>
8
9 #include <utility>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/ptr_util.h"
16 #include "base/single_thread_task_runner.h"
17 #include "base/stl_util.h"
18 #include "mojo/public/cpp/bindings/associated_group.h"
19 #include "mojo/public/cpp/bindings/associated_group_controller.h"
20 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
21 #include "mojo/public/cpp/bindings/lib/validation_util.h"
22 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
23
24 namespace mojo {
25
26 // ----------------------------------------------------------------------------
27
28 namespace {
29
DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient> & client,const std::string & message)30 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client,
31 const std::string& message) {
32 bool is_valid = client && !client->encountered_error();
33 DCHECK(!is_valid) << message;
34 }
35
36 // When receiving an incoming message which expects a repsonse,
37 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
38 // incoming message receiver. When the receiver finishes processing the message,
39 // it can provide a response using this object.
40 class ResponderThunk : public MessageReceiverWithStatus {
41 public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SingleThreadTaskRunner> runner)42 explicit ResponderThunk(
43 const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
44 scoped_refptr<base::SingleThreadTaskRunner> runner)
45 : endpoint_client_(endpoint_client),
46 accept_was_invoked_(false),
47 task_runner_(std::move(runner)) {}
~ResponderThunk()48 ~ResponderThunk() override {
49 if (!accept_was_invoked_) {
50 // The Service handled a message that was expecting a response
51 // but did not send a response.
52 // We raise an error to signal the calling application that an error
53 // condition occurred. Without this the calling application would have no
54 // way of knowing it should stop waiting for a response.
55 if (task_runner_->RunsTasksOnCurrentThread()) {
56 // Please note that even if this code is run from a different task
57 // runner on the same thread as |task_runner_|, it is okay to directly
58 // call InterfaceEndpointClient::RaiseError(), because it will raise
59 // error from the correct task runner asynchronously.
60 if (endpoint_client_) {
61 endpoint_client_->RaiseError();
62 }
63 } else {
64 task_runner_->PostTask(
65 FROM_HERE,
66 base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
67 }
68 }
69 }
70
71 // MessageReceiver implementation:
Accept(Message * message)72 bool Accept(Message* message) override {
73 DCHECK(task_runner_->RunsTasksOnCurrentThread());
74 accept_was_invoked_ = true;
75 DCHECK(message->has_flag(Message::kFlagIsResponse));
76
77 bool result = false;
78
79 if (endpoint_client_)
80 result = endpoint_client_->Accept(message);
81
82 return result;
83 }
84
85 // MessageReceiverWithStatus implementation:
IsValid()86 bool IsValid() override {
87 DCHECK(task_runner_->RunsTasksOnCurrentThread());
88 return endpoint_client_ && !endpoint_client_->encountered_error();
89 }
90
DCheckInvalid(const std::string & message)91 void DCheckInvalid(const std::string& message) override {
92 if (task_runner_->RunsTasksOnCurrentThread()) {
93 DCheckIfInvalid(endpoint_client_, message);
94 } else {
95 task_runner_->PostTask(
96 FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message));
97 }
98 }
99
100 private:
101 base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
102 bool accept_was_invoked_;
103 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
104
105 DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
106 };
107
108 } // namespace
109
110 // ----------------------------------------------------------------------------
111
SyncResponseInfo(bool * in_response_received)112 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
113 bool* in_response_received)
114 : response_received(in_response_received) {}
115
~SyncResponseInfo()116 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
117
118 // ----------------------------------------------------------------------------
119
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)120 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
121 InterfaceEndpointClient* owner)
122 : owner_(owner) {}
123
124 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()125 ~HandleIncomingMessageThunk() {}
126
Accept(Message * message)127 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
128 Message* message) {
129 return owner_->HandleValidatedMessage(message);
130 }
131
132 // ----------------------------------------------------------------------------
133
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageReceiver> payload_validator,bool expect_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner,uint32_t interface_version)134 InterfaceEndpointClient::InterfaceEndpointClient(
135 ScopedInterfaceEndpointHandle handle,
136 MessageReceiverWithResponderStatus* receiver,
137 std::unique_ptr<MessageReceiver> payload_validator,
138 bool expect_sync_requests,
139 scoped_refptr<base::SingleThreadTaskRunner> runner,
140 uint32_t interface_version)
141 : expect_sync_requests_(expect_sync_requests),
142 handle_(std::move(handle)),
143 incoming_receiver_(receiver),
144 thunk_(this),
145 filters_(&thunk_),
146 task_runner_(std::move(runner)),
147 control_message_proxy_(this),
148 control_message_handler_(interface_version),
149 weak_ptr_factory_(this) {
150 DCHECK(handle_.is_valid());
151
152 // TODO(yzshen): the way to use validator (or message filter in general)
153 // directly is a little awkward.
154 if (payload_validator)
155 filters_.Append(std::move(payload_validator));
156
157 if (handle_.pending_association()) {
158 handle_.SetAssociationEventHandler(base::Bind(
159 &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this)));
160 } else {
161 InitControllerIfNecessary();
162 }
163 }
164
~InterfaceEndpointClient()165 InterfaceEndpointClient::~InterfaceEndpointClient() {
166 DCHECK(thread_checker_.CalledOnValidThread());
167
168 if (controller_)
169 handle_.group_controller()->DetachEndpointClient(handle_);
170 }
171
associated_group()172 AssociatedGroup* InterfaceEndpointClient::associated_group() {
173 if (!associated_group_)
174 associated_group_ = base::MakeUnique<AssociatedGroup>(handle_);
175 return associated_group_.get();
176 }
177
PassHandle()178 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
179 DCHECK(thread_checker_.CalledOnValidThread());
180 DCHECK(!has_pending_responders());
181
182 if (!handle_.is_valid())
183 return ScopedInterfaceEndpointHandle();
184
185 handle_.SetAssociationEventHandler(
186 ScopedInterfaceEndpointHandle::AssociationEventCallback());
187
188 if (controller_) {
189 controller_ = nullptr;
190 handle_.group_controller()->DetachEndpointClient(handle_);
191 }
192
193 return std::move(handle_);
194 }
195
AddFilter(std::unique_ptr<MessageReceiver> filter)196 void InterfaceEndpointClient::AddFilter(
197 std::unique_ptr<MessageReceiver> filter) {
198 filters_.Append(std::move(filter));
199 }
200
RaiseError()201 void InterfaceEndpointClient::RaiseError() {
202 DCHECK(thread_checker_.CalledOnValidThread());
203
204 if (!handle_.pending_association())
205 handle_.group_controller()->RaiseError();
206 }
207
CloseWithReason(uint32_t custom_reason,const std::string & description)208 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
209 const std::string& description) {
210 DCHECK(thread_checker_.CalledOnValidThread());
211
212 auto handle = PassHandle();
213 handle.ResetWithReason(custom_reason, description);
214 }
215
Accept(Message * message)216 bool InterfaceEndpointClient::Accept(Message* message) {
217 DCHECK(thread_checker_.CalledOnValidThread());
218 DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
219 DCHECK(!handle_.pending_association());
220
221 // This has to been done even if connection error has occurred. For example,
222 // the message contains a pending associated request. The user may try to use
223 // the corresponding associated interface pointer after sending this message.
224 // That associated interface pointer has to join an associated group in order
225 // to work properly.
226 if (!message->associated_endpoint_handles()->empty())
227 message->SerializeAssociatedEndpointHandles(handle_.group_controller());
228
229 if (encountered_error_)
230 return false;
231
232 InitControllerIfNecessary();
233
234 return controller_->SendMessage(message);
235 }
236
AcceptWithResponder(Message * message,MessageReceiver * responder)237 bool InterfaceEndpointClient::AcceptWithResponder(Message* message,
238 MessageReceiver* responder) {
239 DCHECK(thread_checker_.CalledOnValidThread());
240 DCHECK(message->has_flag(Message::kFlagExpectsResponse));
241 DCHECK(!handle_.pending_association());
242
243 // Please see comments in Accept().
244 if (!message->associated_endpoint_handles()->empty())
245 message->SerializeAssociatedEndpointHandles(handle_.group_controller());
246
247 if (encountered_error_)
248 return false;
249
250 InitControllerIfNecessary();
251
252 // Reserve 0 in case we want it to convey special meaning in the future.
253 uint64_t request_id = next_request_id_++;
254 if (request_id == 0)
255 request_id = next_request_id_++;
256
257 message->set_request_id(request_id);
258
259 bool is_sync = message->has_flag(Message::kFlagIsSync);
260 if (!controller_->SendMessage(message))
261 return false;
262
263 if (!is_sync) {
264 // We assume ownership of |responder|.
265 async_responders_[request_id] = base::WrapUnique(responder);
266 return true;
267 }
268
269 SyncCallRestrictions::AssertSyncCallAllowed();
270
271 bool response_received = false;
272 std::unique_ptr<MessageReceiver> sync_responder(responder);
273 sync_responses_.insert(std::make_pair(
274 request_id, base::MakeUnique<SyncResponseInfo>(&response_received)));
275
276 base::WeakPtr<InterfaceEndpointClient> weak_self =
277 weak_ptr_factory_.GetWeakPtr();
278 controller_->SyncWatch(&response_received);
279 // Make sure that this instance hasn't been destroyed.
280 if (weak_self) {
281 DCHECK(base::ContainsKey(sync_responses_, request_id));
282 auto iter = sync_responses_.find(request_id);
283 DCHECK_EQ(&response_received, iter->second->response_received);
284 if (response_received)
285 ignore_result(sync_responder->Accept(&iter->second->response));
286 sync_responses_.erase(iter);
287 }
288
289 // Return true means that we take ownership of |responder|.
290 return true;
291 }
292
HandleIncomingMessage(Message * message)293 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
294 DCHECK(thread_checker_.CalledOnValidThread());
295 return filters_.Accept(message);
296 }
297
NotifyError(const base::Optional<DisconnectReason> & reason)298 void InterfaceEndpointClient::NotifyError(
299 const base::Optional<DisconnectReason>& reason) {
300 DCHECK(thread_checker_.CalledOnValidThread());
301
302 if (encountered_error_)
303 return;
304 encountered_error_ = true;
305
306 // Response callbacks may hold on to resource, and there's no need to keep
307 // them alive any longer. Note that it's allowed that a pending response
308 // callback may own this endpoint, so we simply move the responders onto the
309 // stack here and let them be destroyed when the stack unwinds.
310 AsyncResponderMap responders = std::move(async_responders_);
311
312 control_message_proxy_.OnConnectionError();
313
314 if (!error_handler_.is_null()) {
315 base::Closure error_handler = std::move(error_handler_);
316 error_handler.Run();
317 } else if (!error_with_reason_handler_.is_null()) {
318 ConnectionErrorWithReasonCallback error_with_reason_handler =
319 std::move(error_with_reason_handler_);
320 if (reason) {
321 error_with_reason_handler.Run(reason->custom_reason, reason->description);
322 } else {
323 error_with_reason_handler.Run(0, std::string());
324 }
325 }
326 }
327
QueryVersion(const base::Callback<void (uint32_t)> & callback)328 void InterfaceEndpointClient::QueryVersion(
329 const base::Callback<void(uint32_t)>& callback) {
330 control_message_proxy_.QueryVersion(callback);
331 }
332
RequireVersion(uint32_t version)333 void InterfaceEndpointClient::RequireVersion(uint32_t version) {
334 control_message_proxy_.RequireVersion(version);
335 }
336
FlushForTesting()337 void InterfaceEndpointClient::FlushForTesting() {
338 control_message_proxy_.FlushForTesting();
339 }
340
InitControllerIfNecessary()341 void InterfaceEndpointClient::InitControllerIfNecessary() {
342 if (controller_ || handle_.pending_association())
343 return;
344
345 controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this,
346 task_runner_);
347 if (expect_sync_requests_)
348 controller_->AllowWokenUpBySyncWatchOnSameThread();
349 }
350
OnAssociationEvent(ScopedInterfaceEndpointHandle::AssociationEvent event)351 void InterfaceEndpointClient::OnAssociationEvent(
352 ScopedInterfaceEndpointHandle::AssociationEvent event) {
353 if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
354 InitControllerIfNecessary();
355 } else if (event ==
356 ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) {
357 task_runner_->PostTask(FROM_HERE,
358 base::Bind(&InterfaceEndpointClient::NotifyError,
359 weak_ptr_factory_.GetWeakPtr(),
360 handle_.disconnect_reason()));
361 }
362 }
363
HandleValidatedMessage(Message * message)364 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
365 DCHECK_EQ(handle_.id(), message->interface_id());
366
367 if (encountered_error_) {
368 // This message is received after error has been encountered. For associated
369 // interfaces, this means the remote side sends a
370 // PeerAssociatedEndpointClosed event but continues to send more messages
371 // for the same interface. Close the pipe because this shouldn't happen.
372 DVLOG(1) << "A message is received for an interface after it has been "
373 << "disconnected. Closing the pipe.";
374 return false;
375 }
376
377 if (message->has_flag(Message::kFlagExpectsResponse)) {
378 MessageReceiverWithStatus* responder =
379 new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_);
380 bool ok = false;
381 if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
382 ok = control_message_handler_.AcceptWithResponder(message, responder);
383 } else {
384 ok = incoming_receiver_->AcceptWithResponder(message, responder);
385 }
386 if (!ok)
387 delete responder;
388 return ok;
389 } else if (message->has_flag(Message::kFlagIsResponse)) {
390 uint64_t request_id = message->request_id();
391
392 if (message->has_flag(Message::kFlagIsSync)) {
393 auto it = sync_responses_.find(request_id);
394 if (it == sync_responses_.end())
395 return false;
396 it->second->response = std::move(*message);
397 *it->second->response_received = true;
398 return true;
399 }
400
401 auto it = async_responders_.find(request_id);
402 if (it == async_responders_.end())
403 return false;
404 std::unique_ptr<MessageReceiver> responder = std::move(it->second);
405 async_responders_.erase(it);
406 return responder->Accept(message);
407 } else {
408 if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
409 return control_message_handler_.Accept(message);
410
411 return incoming_receiver_->Accept(message);
412 }
413 }
414
415 } // namespace mojo
416