1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6
7 #include <stdint.h>
8
9 #include <utility>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/macros.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/single_thread_task_runner.h"
16 #include "base/stl_util.h"
17 #include "mojo/public/cpp/bindings/associated_group.h"
18 #include "mojo/public/cpp/bindings/associated_group_controller.h"
19 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
20 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
21
22 namespace mojo {
23
24 // ----------------------------------------------------------------------------
25
26 namespace {
27
DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient> & client,const std::string & message)28 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client,
29 const std::string& message) {
30 bool is_valid = client && !client->encountered_error();
31 DCHECK(!is_valid) << message;
32 }
33
// When receiving an incoming message which expects a response,
35 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
36 // incoming message receiver. When the receiver finishes processing the message,
37 // it can provide a response using this object.
38 class ResponderThunk : public MessageReceiverWithStatus {
39 public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SingleThreadTaskRunner> runner)40 explicit ResponderThunk(
41 const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
42 scoped_refptr<base::SingleThreadTaskRunner> runner)
43 : endpoint_client_(endpoint_client),
44 accept_was_invoked_(false),
45 task_runner_(std::move(runner)) {}
~ResponderThunk()46 ~ResponderThunk() override {
47 if (!accept_was_invoked_) {
48 // The Mojo application handled a message that was expecting a response
49 // but did not send a response.
50 // We raise an error to signal the calling application that an error
51 // condition occurred. Without this the calling application would have no
52 // way of knowing it should stop waiting for a response.
53 if (task_runner_->RunsTasksOnCurrentThread()) {
54 // Please note that even if this code is run from a different task
55 // runner on the same thread as |task_runner_|, it is okay to directly
56 // call InterfaceEndpointClient::RaiseError(), because it will raise
57 // error from the correct task runner asynchronously.
58 if (endpoint_client_) {
59 endpoint_client_->RaiseError();
60 }
61 } else {
62 task_runner_->PostTask(
63 FROM_HERE,
64 base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
65 }
66 }
67 }
68
69 // MessageReceiver implementation:
Accept(Message * message)70 bool Accept(Message* message) override {
71 DCHECK(task_runner_->RunsTasksOnCurrentThread());
72 accept_was_invoked_ = true;
73 DCHECK(message->has_flag(Message::kFlagIsResponse));
74
75 bool result = false;
76
77 if (endpoint_client_)
78 result = endpoint_client_->Accept(message);
79
80 return result;
81 }
82
83 // MessageReceiverWithStatus implementation:
IsValid()84 bool IsValid() override {
85 DCHECK(task_runner_->RunsTasksOnCurrentThread());
86 return endpoint_client_ && !endpoint_client_->encountered_error();
87 }
88
DCheckInvalid(const std::string & message)89 void DCheckInvalid(const std::string& message) override {
90 if (task_runner_->RunsTasksOnCurrentThread()) {
91 DCheckIfInvalid(endpoint_client_, message);
92 } else {
93 task_runner_->PostTask(
94 FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message));
95 }
96 }
97
98 private:
99 base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
100 bool accept_was_invoked_;
101 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
102
103 DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
104 };
105
106 } // namespace
107
108 // ----------------------------------------------------------------------------
109
SyncResponseInfo(bool * in_response_received)110 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
111 bool* in_response_received)
112 : response_received(in_response_received) {}
113
~SyncResponseInfo()114 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
115
116 // ----------------------------------------------------------------------------
117
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)118 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
119 InterfaceEndpointClient* owner)
120 : owner_(owner) {}
121
122 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()123 ~HandleIncomingMessageThunk() {}
124
Accept(Message * message)125 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
126 Message* message) {
127 return owner_->HandleValidatedMessage(message);
128 }
129
130 // ----------------------------------------------------------------------------
131
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageFilter> payload_validator,bool expect_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner)132 InterfaceEndpointClient::InterfaceEndpointClient(
133 ScopedInterfaceEndpointHandle handle,
134 MessageReceiverWithResponderStatus* receiver,
135 std::unique_ptr<MessageFilter> payload_validator,
136 bool expect_sync_requests,
137 scoped_refptr<base::SingleThreadTaskRunner> runner)
138 : handle_(std::move(handle)),
139 incoming_receiver_(receiver),
140 payload_validator_(std::move(payload_validator)),
141 thunk_(this),
142 next_request_id_(1),
143 encountered_error_(false),
144 task_runner_(std::move(runner)),
145 weak_ptr_factory_(this) {
146 DCHECK(handle_.is_valid());
147 DCHECK(handle_.is_local());
148
149 // TODO(yzshen): the way to use validator (or message filter in general)
150 // directly is a little awkward.
151 payload_validator_->set_sink(&thunk_);
152
153 controller_ = handle_.group_controller()->AttachEndpointClient(
154 handle_, this, task_runner_);
155 if (expect_sync_requests)
156 controller_->AllowWokenUpBySyncWatchOnSameThread();
157 }
158
~InterfaceEndpointClient()159 InterfaceEndpointClient::~InterfaceEndpointClient() {
160 DCHECK(thread_checker_.CalledOnValidThread());
161
162 handle_.group_controller()->DetachEndpointClient(handle_);
163 }
164
associated_group()165 AssociatedGroup* InterfaceEndpointClient::associated_group() {
166 if (!associated_group_)
167 associated_group_ = handle_.group_controller()->CreateAssociatedGroup();
168 return associated_group_.get();
169 }
170
interface_id() const171 uint32_t InterfaceEndpointClient::interface_id() const {
172 DCHECK(thread_checker_.CalledOnValidThread());
173 return handle_.id();
174 }
175
PassHandle()176 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
177 DCHECK(thread_checker_.CalledOnValidThread());
178 DCHECK(!has_pending_responders());
179
180 if (!handle_.is_valid())
181 return ScopedInterfaceEndpointHandle();
182
183 controller_ = nullptr;
184 handle_.group_controller()->DetachEndpointClient(handle_);
185
186 return std::move(handle_);
187 }
188
RaiseError()189 void InterfaceEndpointClient::RaiseError() {
190 DCHECK(thread_checker_.CalledOnValidThread());
191
192 handle_.group_controller()->RaiseError();
193 }
194
Accept(Message * message)195 bool InterfaceEndpointClient::Accept(Message* message) {
196 DCHECK(thread_checker_.CalledOnValidThread());
197 DCHECK(controller_);
198 DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
199
200 if (encountered_error_)
201 return false;
202
203 return controller_->SendMessage(message);
204 }
205
AcceptWithResponder(Message * message,MessageReceiver * responder)206 bool InterfaceEndpointClient::AcceptWithResponder(Message* message,
207 MessageReceiver* responder) {
208 DCHECK(thread_checker_.CalledOnValidThread());
209 DCHECK(controller_);
210 DCHECK(message->has_flag(Message::kFlagExpectsResponse));
211
212 if (encountered_error_)
213 return false;
214
215 // Reserve 0 in case we want it to convey special meaning in the future.
216 uint64_t request_id = next_request_id_++;
217 if (request_id == 0)
218 request_id = next_request_id_++;
219
220 message->set_request_id(request_id);
221
222 bool is_sync = message->has_flag(Message::kFlagIsSync);
223 if (!controller_->SendMessage(message))
224 return false;
225
226 if (!is_sync) {
227 // We assume ownership of |responder|.
228 async_responders_[request_id] = base::WrapUnique(responder);
229 return true;
230 }
231
232 SyncCallRestrictions::AssertSyncCallAllowed();
233
234 bool response_received = false;
235 std::unique_ptr<MessageReceiver> sync_responder(responder);
236 sync_responses_.insert(std::make_pair(
237 request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));
238
239 base::WeakPtr<InterfaceEndpointClient> weak_self =
240 weak_ptr_factory_.GetWeakPtr();
241 controller_->SyncWatch(&response_received);
242 // Make sure that this instance hasn't been destroyed.
243 if (weak_self) {
244 DCHECK(ContainsKey(sync_responses_, request_id));
245 auto iter = sync_responses_.find(request_id);
246 DCHECK_EQ(&response_received, iter->second->response_received);
247 if (response_received) {
248 std::unique_ptr<Message> response = std::move(iter->second->response);
249 ignore_result(sync_responder->Accept(response.get()));
250 }
251 sync_responses_.erase(iter);
252 }
253
254 // Return true means that we take ownership of |responder|.
255 return true;
256 }
257
HandleIncomingMessage(Message * message)258 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
259 DCHECK(thread_checker_.CalledOnValidThread());
260
261 return payload_validator_->Accept(message);
262 }
263
NotifyError()264 void InterfaceEndpointClient::NotifyError() {
265 DCHECK(thread_checker_.CalledOnValidThread());
266
267 if (encountered_error_)
268 return;
269 encountered_error_ = true;
270 if (!error_handler_.is_null())
271 error_handler_.Run();
272 }
273
HandleValidatedMessage(Message * message)274 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
275 DCHECK_EQ(handle_.id(), message->interface_id());
276
277 if (message->has_flag(Message::kFlagExpectsResponse)) {
278 if (!incoming_receiver_)
279 return false;
280
281 MessageReceiverWithStatus* responder =
282 new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_);
283 bool ok = incoming_receiver_->AcceptWithResponder(message, responder);
284 if (!ok)
285 delete responder;
286 return ok;
287 } else if (message->has_flag(Message::kFlagIsResponse)) {
288 uint64_t request_id = message->request_id();
289
290 if (message->has_flag(Message::kFlagIsSync)) {
291 auto it = sync_responses_.find(request_id);
292 if (it == sync_responses_.end())
293 return false;
294 it->second->response.reset(new Message());
295 message->MoveTo(it->second->response.get());
296 *it->second->response_received = true;
297 return true;
298 }
299
300 auto it = async_responders_.find(request_id);
301 if (it == async_responders_.end())
302 return false;
303 std::unique_ptr<MessageReceiver> responder = std::move(it->second);
304 async_responders_.erase(it);
305 return responder->Accept(message);
306 } else {
307 if (!incoming_receiver_)
308 return false;
309
310 return incoming_receiver_->Accept(message);
311 }
312 }
313
314 } // namespace mojo
315