• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/edk/system/channel.h"
6 
7 #include <stdint.h>
8 #include <windows.h>
9 
10 #include <algorithm>
11 #include <deque>
12 #include <limits>
13 #include <memory>
14 
15 #include "base/bind.h"
16 #include "base/location.h"
17 #include "base/macros.h"
18 #include "base/memory/ref_counted.h"
19 #include "base/message_loop/message_loop.h"
20 #include "base/synchronization/lock.h"
21 #include "base/task_runner.h"
22 #include "mojo/edk/embedder/platform_handle_vector.h"
23 
24 namespace mojo {
25 namespace edk {
26 
27 namespace {
28 
29 // A view over a Channel::Message object. The write queue uses these since
30 // large messages may need to be sent in chunks.
31 class MessageView {
32  public:
33   // Owns |message|. |offset| indexes the first unsent byte in the message.
MessageView(Channel::MessagePtr message,size_t offset)34   MessageView(Channel::MessagePtr message, size_t offset)
35       : message_(std::move(message)),
36         offset_(offset) {
37     DCHECK_GT(message_->data_num_bytes(), offset_);
38   }
39 
MessageView(MessageView && other)40   MessageView(MessageView&& other) { *this = std::move(other); }
41 
operator =(MessageView && other)42   MessageView& operator=(MessageView&& other) {
43     message_ = std::move(other.message_);
44     offset_ = other.offset_;
45     return *this;
46   }
47 
~MessageView()48   ~MessageView() {}
49 
data() const50   const void* data() const {
51     return static_cast<const char*>(message_->data()) + offset_;
52   }
53 
data_num_bytes() const54   size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; }
55 
data_offset() const56   size_t data_offset() const { return offset_; }
advance_data_offset(size_t num_bytes)57   void advance_data_offset(size_t num_bytes) {
58     DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes);
59     offset_ += num_bytes;
60   }
61 
TakeChannelMessage()62   Channel::MessagePtr TakeChannelMessage() { return std::move(message_); }
63 
64  private:
65   Channel::MessagePtr message_;
66   size_t offset_;
67 
68   DISALLOW_COPY_AND_ASSIGN(MessageView);
69 };
70 
71 class ChannelWin : public Channel,
72                    public base::MessageLoop::DestructionObserver,
73                    public base::MessageLoopForIO::IOHandler {
74  public:
ChannelWin(Delegate * delegate,ScopedPlatformHandle handle,scoped_refptr<base::TaskRunner> io_task_runner)75   ChannelWin(Delegate* delegate,
76              ScopedPlatformHandle handle,
77              scoped_refptr<base::TaskRunner> io_task_runner)
78       : Channel(delegate),
79         self_(this),
80         handle_(std::move(handle)),
81         io_task_runner_(io_task_runner) {
82     CHECK(handle_.is_valid());
83 
84     wait_for_connect_ = handle_.get().needs_connection;
85   }
86 
Start()87   void Start() override {
88     io_task_runner_->PostTask(
89         FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this));
90   }
91 
ShutDownImpl()92   void ShutDownImpl() override {
93     // Always shut down asynchronously when called through the public interface.
94     io_task_runner_->PostTask(
95         FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this));
96   }
97 
Write(MessagePtr message)98   void Write(MessagePtr message) override {
99     bool write_error = false;
100     {
101       base::AutoLock lock(write_lock_);
102       if (reject_writes_)
103         return;
104 
105       bool write_now = !delay_writes_ && outgoing_messages_.empty();
106       outgoing_messages_.emplace_back(std::move(message), 0);
107 
108       if (write_now && !WriteNoLock(outgoing_messages_.front()))
109         reject_writes_ = write_error = true;
110     }
111     if (write_error) {
112       // Do not synchronously invoke OnError(). Write() may have been called by
113       // the delegate and we don't want to re-enter it.
114       io_task_runner_->PostTask(FROM_HERE,
115                                 base::Bind(&ChannelWin::OnError, this));
116     }
117   }
118 
LeakHandle()119   void LeakHandle() override {
120     DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
121     leak_handle_ = true;
122   }
123 
GetReadPlatformHandles(size_t num_handles,const void * extra_header,size_t extra_header_size,ScopedPlatformHandleVectorPtr * handles)124   bool GetReadPlatformHandles(
125       size_t num_handles,
126       const void* extra_header,
127       size_t extra_header_size,
128       ScopedPlatformHandleVectorPtr* handles) override {
129     if (num_handles > std::numeric_limits<uint16_t>::max())
130       return false;
131     using HandleEntry = Channel::Message::HandleEntry;
132     size_t handles_size = sizeof(HandleEntry) * num_handles;
133     if (handles_size > extra_header_size)
134       return false;
135     DCHECK(extra_header);
136     handles->reset(new PlatformHandleVector(num_handles));
137     const HandleEntry* extra_header_handles =
138         reinterpret_cast<const HandleEntry*>(extra_header);
139     for (size_t i = 0; i < num_handles; i++) {
140       (*handles)->at(i).handle = reinterpret_cast<HANDLE>(
141           static_cast<uintptr_t>(extra_header_handles[i].handle));
142     }
143     return true;
144   }
145 
146  private:
147   // May run on any thread.
~ChannelWin()148   ~ChannelWin() override {}
149 
StartOnIOThread()150   void StartOnIOThread() {
151     base::MessageLoop::current()->AddDestructionObserver(this);
152     base::MessageLoopForIO::current()->RegisterIOHandler(
153         handle_.get().handle, this);
154 
155     if (wait_for_connect_) {
156       BOOL ok = ConnectNamedPipe(handle_.get().handle,
157                                  &connect_context_.overlapped);
158       if (ok) {
159         PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
160         OnError();
161         return;
162       }
163 
164       const DWORD err = GetLastError();
165       switch (err) {
166         case ERROR_PIPE_CONNECTED:
167           wait_for_connect_ = false;
168           break;
169         case ERROR_IO_PENDING:
170           AddRef();
171           return;
172         case ERROR_NO_DATA:
173           OnError();
174           return;
175       }
176     }
177 
178     // Now that we have registered our IOHandler, we can start writing.
179     {
180       base::AutoLock lock(write_lock_);
181       if (delay_writes_) {
182         delay_writes_ = false;
183         WriteNextNoLock();
184       }
185     }
186 
187     // Keep this alive in case we synchronously run shutdown.
188     scoped_refptr<ChannelWin> keep_alive(this);
189     ReadMore(0);
190   }
191 
ShutDownOnIOThread()192   void ShutDownOnIOThread() {
193     base::MessageLoop::current()->RemoveDestructionObserver(this);
194 
195     // BUG(crbug.com/583525): This function is expected to be called once, and
196     // |handle_| should be valid at this point.
197     CHECK(handle_.is_valid());
198     CancelIo(handle_.get().handle);
199     if (leak_handle_)
200       ignore_result(handle_.release());
201     handle_.reset();
202 
203     // May destroy the |this| if it was the last reference.
204     self_ = nullptr;
205   }
206 
207   // base::MessageLoop::DestructionObserver:
WillDestroyCurrentMessageLoop()208   void WillDestroyCurrentMessageLoop() override {
209     DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
210     if (self_)
211       ShutDownOnIOThread();
212   }
213 
214   // base::MessageLoop::IOHandler:
OnIOCompleted(base::MessageLoopForIO::IOContext * context,DWORD bytes_transfered,DWORD error)215   void OnIOCompleted(base::MessageLoopForIO::IOContext* context,
216                      DWORD bytes_transfered,
217                      DWORD error) override {
218     if (error != ERROR_SUCCESS) {
219       OnError();
220     } else if (context == &connect_context_) {
221       DCHECK(wait_for_connect_);
222       wait_for_connect_ = false;
223       ReadMore(0);
224 
225       base::AutoLock lock(write_lock_);
226       if (delay_writes_) {
227         delay_writes_ = false;
228         WriteNextNoLock();
229       }
230     } else if (context == &read_context_) {
231       OnReadDone(static_cast<size_t>(bytes_transfered));
232     } else {
233       CHECK(context == &write_context_);
234       OnWriteDone(static_cast<size_t>(bytes_transfered));
235     }
236     Release();  // Balancing reference taken after ReadFile / WriteFile.
237   }
238 
OnReadDone(size_t bytes_read)239   void OnReadDone(size_t bytes_read) {
240     if (bytes_read > 0) {
241       size_t next_read_size = 0;
242       if (OnReadComplete(bytes_read, &next_read_size)) {
243         ReadMore(next_read_size);
244       } else {
245         OnError();
246       }
247     } else if (bytes_read == 0) {
248       OnError();
249     }
250   }
251 
OnWriteDone(size_t bytes_written)252   void OnWriteDone(size_t bytes_written) {
253     if (bytes_written == 0)
254       return;
255 
256     bool write_error = false;
257     {
258       base::AutoLock lock(write_lock_);
259 
260       DCHECK(!outgoing_messages_.empty());
261 
262       MessageView& message_view = outgoing_messages_.front();
263       message_view.advance_data_offset(bytes_written);
264       if (message_view.data_num_bytes() == 0) {
265         Channel::MessagePtr message = message_view.TakeChannelMessage();
266         outgoing_messages_.pop_front();
267 
268         // Clear any handles so they don't get closed on destruction.
269         ScopedPlatformHandleVectorPtr handles = message->TakeHandles();
270         if (handles)
271           handles->clear();
272       }
273 
274       if (!WriteNextNoLock())
275         reject_writes_ = write_error = true;
276     }
277     if (write_error)
278       OnError();
279   }
280 
ReadMore(size_t next_read_size_hint)281   void ReadMore(size_t next_read_size_hint) {
282     size_t buffer_capacity = next_read_size_hint;
283     char* buffer = GetReadBuffer(&buffer_capacity);
284     DCHECK_GT(buffer_capacity, 0u);
285 
286     BOOL ok = ReadFile(handle_.get().handle,
287                        buffer,
288                        static_cast<DWORD>(buffer_capacity),
289                        NULL,
290                        &read_context_.overlapped);
291 
292     if (ok || GetLastError() == ERROR_IO_PENDING) {
293       AddRef();  // Will be balanced in OnIOCompleted
294     } else {
295       OnError();
296     }
297   }
298 
299   // Attempts to write a message directly to the channel. If the full message
300   // cannot be written, it's queued and a wait is initiated to write the message
301   // ASAP on the I/O thread.
WriteNoLock(const MessageView & message_view)302   bool WriteNoLock(const MessageView& message_view) {
303     BOOL ok = WriteFile(handle_.get().handle,
304                         message_view.data(),
305                         static_cast<DWORD>(message_view.data_num_bytes()),
306                         NULL,
307                         &write_context_.overlapped);
308 
309     if (ok || GetLastError() == ERROR_IO_PENDING) {
310       AddRef();  // Will be balanced in OnIOCompleted.
311       return true;
312     }
313     return false;
314   }
315 
WriteNextNoLock()316   bool WriteNextNoLock() {
317     if (outgoing_messages_.empty())
318       return true;
319     return WriteNoLock(outgoing_messages_.front());
320   }
321 
322   // Keeps the Channel alive at least until explicit shutdown on the IO thread.
323   scoped_refptr<Channel> self_;
324 
325   ScopedPlatformHandle handle_;
326   scoped_refptr<base::TaskRunner> io_task_runner_;
327 
328   base::MessageLoopForIO::IOContext connect_context_;
329   base::MessageLoopForIO::IOContext read_context_;
330   base::MessageLoopForIO::IOContext write_context_;
331 
332   // Protects |reject_writes_| and |outgoing_messages_|.
333   base::Lock write_lock_;
334 
335   bool delay_writes_ = true;
336 
337   bool reject_writes_ = false;
338   std::deque<MessageView> outgoing_messages_;
339 
340   bool wait_for_connect_;
341 
342   bool leak_handle_ = false;
343 
344   DISALLOW_COPY_AND_ASSIGN(ChannelWin);
345 };
346 
347 }  // namespace
348 
349 // static
Create(Delegate * delegate,ScopedPlatformHandle platform_handle,scoped_refptr<base::TaskRunner> io_task_runner)350 scoped_refptr<Channel> Channel::Create(
351     Delegate* delegate,
352     ScopedPlatformHandle platform_handle,
353     scoped_refptr<base::TaskRunner> io_task_runner) {
354   return new ChannelWin(delegate, std::move(platform_handle), io_task_runner);
355 }
356 
357 }  // namespace edk
358 }  // namespace mojo
359