1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/bindings/lib/router.h"
6
7 #include <stdint.h>
8
9 #include <utility>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/stl_util.h"
16 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
17
18 namespace mojo {
19 namespace internal {
20
21 // ----------------------------------------------------------------------------
22
23 namespace {
24
DCheckIfInvalid(const base::WeakPtr<Router> & router,const std::string & message)25 void DCheckIfInvalid(const base::WeakPtr<Router>& router,
26 const std::string& message) {
27 bool is_valid = router && !router->encountered_error() && router->is_valid();
28 DCHECK(!is_valid) << message;
29 }
30
31 class ResponderThunk : public MessageReceiverWithStatus {
32 public:
ResponderThunk(const base::WeakPtr<Router> & router,scoped_refptr<base::SingleThreadTaskRunner> runner)33 explicit ResponderThunk(const base::WeakPtr<Router>& router,
34 scoped_refptr<base::SingleThreadTaskRunner> runner)
35 : router_(router),
36 accept_was_invoked_(false),
37 task_runner_(std::move(runner)) {}
~ResponderThunk()38 ~ResponderThunk() override {
39 if (!accept_was_invoked_) {
40 // The Mojo application handled a message that was expecting a response
41 // but did not send a response.
42 // We raise an error to signal the calling application that an error
43 // condition occurred. Without this the calling application would have no
44 // way of knowing it should stop waiting for a response.
45 if (task_runner_->RunsTasksOnCurrentThread()) {
46 // Please note that even if this code is run from a different task
47 // runner on the same thread as |task_runner_|, it is okay to directly
48 // call Router::RaiseError(), because it will raise error from the
49 // correct task runner asynchronously.
50 if (router_)
51 router_->RaiseError();
52 } else {
53 task_runner_->PostTask(FROM_HERE,
54 base::Bind(&Router::RaiseError, router_));
55 }
56 }
57 }
58
59 // MessageReceiver implementation:
Accept(Message * message)60 bool Accept(Message* message) override {
61 DCHECK(task_runner_->RunsTasksOnCurrentThread());
62 accept_was_invoked_ = true;
63 DCHECK(message->has_flag(Message::kFlagIsResponse));
64
65 bool result = false;
66
67 if (router_)
68 result = router_->Accept(message);
69
70 return result;
71 }
72
73 // MessageReceiverWithStatus implementation:
IsValid()74 bool IsValid() override {
75 DCHECK(task_runner_->RunsTasksOnCurrentThread());
76 return router_ && !router_->encountered_error() && router_->is_valid();
77 }
78
DCheckInvalid(const std::string & message)79 void DCheckInvalid(const std::string& message) override {
80 if (task_runner_->RunsTasksOnCurrentThread()) {
81 DCheckIfInvalid(router_, message);
82 } else {
83 task_runner_->PostTask(FROM_HERE,
84 base::Bind(&DCheckIfInvalid, router_, message));
85 }
86 }
87
88 private:
89 base::WeakPtr<Router> router_;
90 bool accept_was_invoked_;
91 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
92 };
93
94 } // namespace
95
96 // ----------------------------------------------------------------------------
97
SyncResponseInfo(bool * in_response_received)98 Router::SyncResponseInfo::SyncResponseInfo(bool* in_response_received)
99 : response_received(in_response_received) {}
100
~SyncResponseInfo()101 Router::SyncResponseInfo::~SyncResponseInfo() {}
102
103 // ----------------------------------------------------------------------------
104
HandleIncomingMessageThunk(Router * router)105 Router::HandleIncomingMessageThunk::HandleIncomingMessageThunk(Router* router)
106 : router_(router) {
107 }
108
~HandleIncomingMessageThunk()109 Router::HandleIncomingMessageThunk::~HandleIncomingMessageThunk() {
110 }
111
Accept(Message * message)112 bool Router::HandleIncomingMessageThunk::Accept(Message* message) {
113 return router_->HandleIncomingMessage(message);
114 }
115
116 // ----------------------------------------------------------------------------
117
Router(ScopedMessagePipeHandle message_pipe,FilterChain filters,bool expects_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner)118 Router::Router(ScopedMessagePipeHandle message_pipe,
119 FilterChain filters,
120 bool expects_sync_requests,
121 scoped_refptr<base::SingleThreadTaskRunner> runner)
122 : thunk_(this),
123 filters_(std::move(filters)),
124 connector_(std::move(message_pipe),
125 Connector::SINGLE_THREADED_SEND,
126 std::move(runner)),
127 incoming_receiver_(nullptr),
128 next_request_id_(0),
129 testing_mode_(false),
130 pending_task_for_messages_(false),
131 encountered_error_(false),
132 weak_factory_(this) {
133 filters_.SetSink(&thunk_);
134 if (expects_sync_requests)
135 connector_.AllowWokenUpBySyncWatchOnSameThread();
136 connector_.set_incoming_receiver(filters_.GetHead());
137 connector_.set_connection_error_handler(
138 base::Bind(&Router::OnConnectionError, base::Unretained(this)));
139 }
140
~Router()141 Router::~Router() {}
142
Accept(Message * message)143 bool Router::Accept(Message* message) {
144 DCHECK(thread_checker_.CalledOnValidThread());
145 DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
146 return connector_.Accept(message);
147 }
148
AcceptWithResponder(Message * message,MessageReceiver * responder)149 bool Router::AcceptWithResponder(Message* message, MessageReceiver* responder) {
150 DCHECK(thread_checker_.CalledOnValidThread());
151 DCHECK(message->has_flag(Message::kFlagExpectsResponse));
152
153 // Reserve 0 in case we want it to convey special meaning in the future.
154 uint64_t request_id = next_request_id_++;
155 if (request_id == 0)
156 request_id = next_request_id_++;
157
158 bool is_sync = message->has_flag(Message::kFlagIsSync);
159 message->set_request_id(request_id);
160 if (!connector_.Accept(message))
161 return false;
162
163 if (!is_sync) {
164 // We assume ownership of |responder|.
165 async_responders_[request_id] = base::WrapUnique(responder);
166 return true;
167 }
168
169 SyncCallRestrictions::AssertSyncCallAllowed();
170
171 bool response_received = false;
172 std::unique_ptr<MessageReceiver> sync_responder(responder);
173 sync_responses_.insert(std::make_pair(
174 request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));
175
176 base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr();
177 connector_.SyncWatch(&response_received);
178 // Make sure that this instance hasn't been destroyed.
179 if (weak_self) {
180 DCHECK(ContainsKey(sync_responses_, request_id));
181 auto iter = sync_responses_.find(request_id);
182 DCHECK_EQ(&response_received, iter->second->response_received);
183 if (response_received) {
184 std::unique_ptr<Message> response = std::move(iter->second->response);
185 ignore_result(sync_responder->Accept(response.get()));
186 }
187 sync_responses_.erase(iter);
188 }
189
190 // Return true means that we take ownership of |responder|.
191 return true;
192 }
193
EnableTestingMode()194 void Router::EnableTestingMode() {
195 DCHECK(thread_checker_.CalledOnValidThread());
196 testing_mode_ = true;
197 connector_.set_enforce_errors_from_incoming_receiver(false);
198 }
199
HandleIncomingMessage(Message * message)200 bool Router::HandleIncomingMessage(Message* message) {
201 DCHECK(thread_checker_.CalledOnValidThread());
202
203 const bool during_sync_call =
204 connector_.during_sync_handle_watcher_callback();
205 if (!message->has_flag(Message::kFlagIsSync) &&
206 (during_sync_call || !pending_messages_.empty())) {
207 std::unique_ptr<Message> pending_message(new Message);
208 message->MoveTo(pending_message.get());
209 pending_messages_.push(std::move(pending_message));
210
211 if (!pending_task_for_messages_) {
212 pending_task_for_messages_ = true;
213 connector_.task_runner()->PostTask(
214 FROM_HERE, base::Bind(&Router::HandleQueuedMessages,
215 weak_factory_.GetWeakPtr()));
216 }
217
218 return true;
219 }
220
221 return HandleMessageInternal(message);
222 }
223
HandleQueuedMessages()224 void Router::HandleQueuedMessages() {
225 DCHECK(thread_checker_.CalledOnValidThread());
226 DCHECK(pending_task_for_messages_);
227
228 base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr();
229 while (!pending_messages_.empty()) {
230 std::unique_ptr<Message> message(std::move(pending_messages_.front()));
231 pending_messages_.pop();
232
233 bool result = HandleMessageInternal(message.get());
234 if (!weak_self)
235 return;
236
237 if (!result && !testing_mode_) {
238 connector_.RaiseError();
239 break;
240 }
241 }
242
243 pending_task_for_messages_ = false;
244
245 // We may have already seen a connection error from the connector, but
246 // haven't notified the user because we want to process all the queued
247 // messages first. We should do it now.
248 if (connector_.encountered_error() && !encountered_error_)
249 OnConnectionError();
250 }
251
HandleMessageInternal(Message * message)252 bool Router::HandleMessageInternal(Message* message) {
253 if (message->has_flag(Message::kFlagExpectsResponse)) {
254 if (!incoming_receiver_)
255 return false;
256
257 MessageReceiverWithStatus* responder = new ResponderThunk(
258 weak_factory_.GetWeakPtr(), connector_.task_runner());
259 bool ok = incoming_receiver_->AcceptWithResponder(message, responder);
260 if (!ok)
261 delete responder;
262 return ok;
263
264 } else if (message->has_flag(Message::kFlagIsResponse)) {
265 uint64_t request_id = message->request_id();
266
267 if (message->has_flag(Message::kFlagIsSync)) {
268 auto it = sync_responses_.find(request_id);
269 if (it == sync_responses_.end()) {
270 DCHECK(testing_mode_);
271 return false;
272 }
273 it->second->response.reset(new Message());
274 message->MoveTo(it->second->response.get());
275 *it->second->response_received = true;
276 return true;
277 }
278
279 auto it = async_responders_.find(request_id);
280 if (it == async_responders_.end()) {
281 DCHECK(testing_mode_);
282 return false;
283 }
284 std::unique_ptr<MessageReceiver> responder = std::move(it->second);
285 async_responders_.erase(it);
286 return responder->Accept(message);
287 } else {
288 if (!incoming_receiver_)
289 return false;
290
291 return incoming_receiver_->Accept(message);
292 }
293 }
294
OnConnectionError()295 void Router::OnConnectionError() {
296 if (encountered_error_)
297 return;
298
299 if (!pending_messages_.empty()) {
300 // After all the pending messages are processed, we will check whether an
301 // error has been encountered and run the user's connection error handler
302 // if necessary.
303 DCHECK(pending_task_for_messages_);
304 return;
305 }
306
307 if (connector_.during_sync_handle_watcher_callback()) {
308 // We don't want the error handler to reenter an ongoing sync call.
309 connector_.task_runner()->PostTask(
310 FROM_HERE,
311 base::Bind(&Router::OnConnectionError, weak_factory_.GetWeakPtr()));
312 return;
313 }
314
315 encountered_error_ = true;
316 if (!error_handler_.is_null())
317 error_handler_.Run();
318 }
319
320 // ----------------------------------------------------------------------------
321
322 } // namespace internal
323 } // namespace mojo
324