1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/edk/system/channel.h"
6
7 #include <stdint.h>
8 #include <windows.h>
9
10 #include <algorithm>
11 #include <deque>
12 #include <limits>
13 #include <memory>
14
15 #include "base/bind.h"
16 #include "base/location.h"
17 #include "base/macros.h"
18 #include "base/memory/ref_counted.h"
19 #include "base/message_loop/message_loop.h"
20 #include "base/synchronization/lock.h"
21 #include "base/task_runner.h"
22 #include "mojo/edk/embedder/platform_handle_vector.h"
23
24 namespace mojo {
25 namespace edk {
26
27 namespace {
28
29 // A view over a Channel::Message object. The write queue uses these since
30 // large messages may need to be sent in chunks.
31 class MessageView {
32 public:
33 // Owns |message|. |offset| indexes the first unsent byte in the message.
MessageView(Channel::MessagePtr message,size_t offset)34 MessageView(Channel::MessagePtr message, size_t offset)
35 : message_(std::move(message)),
36 offset_(offset) {
37 DCHECK_GT(message_->data_num_bytes(), offset_);
38 }
39
MessageView(MessageView && other)40 MessageView(MessageView&& other) { *this = std::move(other); }
41
operator =(MessageView && other)42 MessageView& operator=(MessageView&& other) {
43 message_ = std::move(other.message_);
44 offset_ = other.offset_;
45 return *this;
46 }
47
~MessageView()48 ~MessageView() {}
49
data() const50 const void* data() const {
51 return static_cast<const char*>(message_->data()) + offset_;
52 }
53
data_num_bytes() const54 size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; }
55
data_offset() const56 size_t data_offset() const { return offset_; }
advance_data_offset(size_t num_bytes)57 void advance_data_offset(size_t num_bytes) {
58 DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes);
59 offset_ += num_bytes;
60 }
61
TakeChannelMessage()62 Channel::MessagePtr TakeChannelMessage() { return std::move(message_); }
63
64 private:
65 Channel::MessagePtr message_;
66 size_t offset_;
67
68 DISALLOW_COPY_AND_ASSIGN(MessageView);
69 };
70
71 class ChannelWin : public Channel,
72 public base::MessageLoop::DestructionObserver,
73 public base::MessageLoopForIO::IOHandler {
74 public:
ChannelWin(Delegate * delegate,ScopedPlatformHandle handle,scoped_refptr<base::TaskRunner> io_task_runner)75 ChannelWin(Delegate* delegate,
76 ScopedPlatformHandle handle,
77 scoped_refptr<base::TaskRunner> io_task_runner)
78 : Channel(delegate),
79 self_(this),
80 handle_(std::move(handle)),
81 io_task_runner_(io_task_runner) {
82 CHECK(handle_.is_valid());
83
84 wait_for_connect_ = handle_.get().needs_connection;
85 }
86
Start()87 void Start() override {
88 io_task_runner_->PostTask(
89 FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this));
90 }
91
ShutDownImpl()92 void ShutDownImpl() override {
93 // Always shut down asynchronously when called through the public interface.
94 io_task_runner_->PostTask(
95 FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this));
96 }
97
Write(MessagePtr message)98 void Write(MessagePtr message) override {
99 bool write_error = false;
100 {
101 base::AutoLock lock(write_lock_);
102 if (reject_writes_)
103 return;
104
105 bool write_now = !delay_writes_ && outgoing_messages_.empty();
106 outgoing_messages_.emplace_back(std::move(message), 0);
107
108 if (write_now && !WriteNoLock(outgoing_messages_.front()))
109 reject_writes_ = write_error = true;
110 }
111 if (write_error) {
112 // Do not synchronously invoke OnError(). Write() may have been called by
113 // the delegate and we don't want to re-enter it.
114 io_task_runner_->PostTask(FROM_HERE,
115 base::Bind(&ChannelWin::OnError, this));
116 }
117 }
118
LeakHandle()119 void LeakHandle() override {
120 DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
121 leak_handle_ = true;
122 }
123
GetReadPlatformHandles(size_t num_handles,const void * extra_header,size_t extra_header_size,ScopedPlatformHandleVectorPtr * handles)124 bool GetReadPlatformHandles(
125 size_t num_handles,
126 const void* extra_header,
127 size_t extra_header_size,
128 ScopedPlatformHandleVectorPtr* handles) override {
129 if (num_handles > std::numeric_limits<uint16_t>::max())
130 return false;
131 using HandleEntry = Channel::Message::HandleEntry;
132 size_t handles_size = sizeof(HandleEntry) * num_handles;
133 if (handles_size > extra_header_size)
134 return false;
135 DCHECK(extra_header);
136 handles->reset(new PlatformHandleVector(num_handles));
137 const HandleEntry* extra_header_handles =
138 reinterpret_cast<const HandleEntry*>(extra_header);
139 for (size_t i = 0; i < num_handles; i++) {
140 (*handles)->at(i).handle = reinterpret_cast<HANDLE>(
141 static_cast<uintptr_t>(extra_header_handles[i].handle));
142 }
143 return true;
144 }
145
146 private:
147 // May run on any thread.
~ChannelWin()148 ~ChannelWin() override {}
149
StartOnIOThread()150 void StartOnIOThread() {
151 base::MessageLoop::current()->AddDestructionObserver(this);
152 base::MessageLoopForIO::current()->RegisterIOHandler(
153 handle_.get().handle, this);
154
155 if (wait_for_connect_) {
156 BOOL ok = ConnectNamedPipe(handle_.get().handle,
157 &connect_context_.overlapped);
158 if (ok) {
159 PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
160 OnError();
161 return;
162 }
163
164 const DWORD err = GetLastError();
165 switch (err) {
166 case ERROR_PIPE_CONNECTED:
167 wait_for_connect_ = false;
168 break;
169 case ERROR_IO_PENDING:
170 AddRef();
171 return;
172 case ERROR_NO_DATA:
173 OnError();
174 return;
175 }
176 }
177
178 // Now that we have registered our IOHandler, we can start writing.
179 {
180 base::AutoLock lock(write_lock_);
181 if (delay_writes_) {
182 delay_writes_ = false;
183 WriteNextNoLock();
184 }
185 }
186
187 // Keep this alive in case we synchronously run shutdown.
188 scoped_refptr<ChannelWin> keep_alive(this);
189 ReadMore(0);
190 }
191
ShutDownOnIOThread()192 void ShutDownOnIOThread() {
193 base::MessageLoop::current()->RemoveDestructionObserver(this);
194
195 // BUG(crbug.com/583525): This function is expected to be called once, and
196 // |handle_| should be valid at this point.
197 CHECK(handle_.is_valid());
198 CancelIo(handle_.get().handle);
199 if (leak_handle_)
200 ignore_result(handle_.release());
201 handle_.reset();
202
203 // May destroy the |this| if it was the last reference.
204 self_ = nullptr;
205 }
206
207 // base::MessageLoop::DestructionObserver:
WillDestroyCurrentMessageLoop()208 void WillDestroyCurrentMessageLoop() override {
209 DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
210 if (self_)
211 ShutDownOnIOThread();
212 }
213
214 // base::MessageLoop::IOHandler:
OnIOCompleted(base::MessageLoopForIO::IOContext * context,DWORD bytes_transfered,DWORD error)215 void OnIOCompleted(base::MessageLoopForIO::IOContext* context,
216 DWORD bytes_transfered,
217 DWORD error) override {
218 if (error != ERROR_SUCCESS) {
219 OnError();
220 } else if (context == &connect_context_) {
221 DCHECK(wait_for_connect_);
222 wait_for_connect_ = false;
223 ReadMore(0);
224
225 base::AutoLock lock(write_lock_);
226 if (delay_writes_) {
227 delay_writes_ = false;
228 WriteNextNoLock();
229 }
230 } else if (context == &read_context_) {
231 OnReadDone(static_cast<size_t>(bytes_transfered));
232 } else {
233 CHECK(context == &write_context_);
234 OnWriteDone(static_cast<size_t>(bytes_transfered));
235 }
236 Release(); // Balancing reference taken after ReadFile / WriteFile.
237 }
238
OnReadDone(size_t bytes_read)239 void OnReadDone(size_t bytes_read) {
240 if (bytes_read > 0) {
241 size_t next_read_size = 0;
242 if (OnReadComplete(bytes_read, &next_read_size)) {
243 ReadMore(next_read_size);
244 } else {
245 OnError();
246 }
247 } else if (bytes_read == 0) {
248 OnError();
249 }
250 }
251
OnWriteDone(size_t bytes_written)252 void OnWriteDone(size_t bytes_written) {
253 if (bytes_written == 0)
254 return;
255
256 bool write_error = false;
257 {
258 base::AutoLock lock(write_lock_);
259
260 DCHECK(!outgoing_messages_.empty());
261
262 MessageView& message_view = outgoing_messages_.front();
263 message_view.advance_data_offset(bytes_written);
264 if (message_view.data_num_bytes() == 0) {
265 Channel::MessagePtr message = message_view.TakeChannelMessage();
266 outgoing_messages_.pop_front();
267
268 // Clear any handles so they don't get closed on destruction.
269 ScopedPlatformHandleVectorPtr handles = message->TakeHandles();
270 if (handles)
271 handles->clear();
272 }
273
274 if (!WriteNextNoLock())
275 reject_writes_ = write_error = true;
276 }
277 if (write_error)
278 OnError();
279 }
280
ReadMore(size_t next_read_size_hint)281 void ReadMore(size_t next_read_size_hint) {
282 size_t buffer_capacity = next_read_size_hint;
283 char* buffer = GetReadBuffer(&buffer_capacity);
284 DCHECK_GT(buffer_capacity, 0u);
285
286 BOOL ok = ReadFile(handle_.get().handle,
287 buffer,
288 static_cast<DWORD>(buffer_capacity),
289 NULL,
290 &read_context_.overlapped);
291
292 if (ok || GetLastError() == ERROR_IO_PENDING) {
293 AddRef(); // Will be balanced in OnIOCompleted
294 } else {
295 OnError();
296 }
297 }
298
299 // Attempts to write a message directly to the channel. If the full message
300 // cannot be written, it's queued and a wait is initiated to write the message
301 // ASAP on the I/O thread.
WriteNoLock(const MessageView & message_view)302 bool WriteNoLock(const MessageView& message_view) {
303 BOOL ok = WriteFile(handle_.get().handle,
304 message_view.data(),
305 static_cast<DWORD>(message_view.data_num_bytes()),
306 NULL,
307 &write_context_.overlapped);
308
309 if (ok || GetLastError() == ERROR_IO_PENDING) {
310 AddRef(); // Will be balanced in OnIOCompleted.
311 return true;
312 }
313 return false;
314 }
315
WriteNextNoLock()316 bool WriteNextNoLock() {
317 if (outgoing_messages_.empty())
318 return true;
319 return WriteNoLock(outgoing_messages_.front());
320 }
321
322 // Keeps the Channel alive at least until explicit shutdown on the IO thread.
323 scoped_refptr<Channel> self_;
324
325 ScopedPlatformHandle handle_;
326 scoped_refptr<base::TaskRunner> io_task_runner_;
327
328 base::MessageLoopForIO::IOContext connect_context_;
329 base::MessageLoopForIO::IOContext read_context_;
330 base::MessageLoopForIO::IOContext write_context_;
331
332 // Protects |reject_writes_| and |outgoing_messages_|.
333 base::Lock write_lock_;
334
335 bool delay_writes_ = true;
336
337 bool reject_writes_ = false;
338 std::deque<MessageView> outgoing_messages_;
339
340 bool wait_for_connect_;
341
342 bool leak_handle_ = false;
343
344 DISALLOW_COPY_AND_ASSIGN(ChannelWin);
345 };
346
347 } // namespace
348
349 // static
Create(Delegate * delegate,ScopedPlatformHandle platform_handle,scoped_refptr<base::TaskRunner> io_task_runner)350 scoped_refptr<Channel> Channel::Create(
351 Delegate* delegate,
352 ScopedPlatformHandle platform_handle,
353 scoped_refptr<base::TaskRunner> io_task_runner) {
354 return new ChannelWin(delegate, std::move(platform_handle), io_task_runner);
355 }
356
357 } // namespace edk
358 } // namespace mojo
359