• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/core/channel.h"
6 
7 #include <stdint.h>
8 #include <windows.h>
9 
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13 
14 #include "base/bind.h"
15 #include "base/containers/queue.h"
16 #include "base/location.h"
17 #include "base/macros.h"
18 #include "base/memory/ref_counted.h"
19 #include "base/message_loop/message_loop_current.h"
20 #include "base/message_loop/message_pump_for_io.h"
21 #include "base/process/process_handle.h"
22 #include "base/synchronization/lock.h"
23 #include "base/task_runner.h"
24 #include "base/win/scoped_handle.h"
25 #include "base/win/win_util.h"
26 
27 namespace mojo {
28 namespace core {
29 
30 namespace {
31 
32 class ChannelWin : public Channel,
33                    public base::MessageLoopCurrent::DestructionObserver,
34                    public base::MessagePumpForIO::IOHandler {
35  public:
ChannelWin(Delegate * delegate,ConnectionParams connection_params,scoped_refptr<base::TaskRunner> io_task_runner)36   ChannelWin(Delegate* delegate,
37              ConnectionParams connection_params,
38              scoped_refptr<base::TaskRunner> io_task_runner)
39       : Channel(delegate), self_(this), io_task_runner_(io_task_runner) {
40     if (connection_params.server_endpoint().is_valid()) {
41       handle_ = connection_params.TakeServerEndpoint()
42                     .TakePlatformHandle()
43                     .TakeHandle();
44       needs_connection_ = true;
45     } else {
46       handle_ =
47           connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
48     }
49 
50     CHECK(handle_.IsValid());
51   }
52 
Start()53   void Start() override {
54     io_task_runner_->PostTask(
55         FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
56   }
57 
ShutDownImpl()58   void ShutDownImpl() override {
59     // Always shut down asynchronously when called through the public interface.
60     io_task_runner_->PostTask(
61         FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
62   }
63 
Write(MessagePtr message)64   void Write(MessagePtr message) override {
65     if (remote_process().is_valid()) {
66       // If we know the remote process handle, we transfer all outgoing handles
67       // to the process now rewriting them in the message.
68       std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
69       for (auto& handle : handles) {
70         if (handle.handle().is_valid())
71           handle.TransferToProcess(remote_process().Clone());
72       }
73       message->SetHandles(std::move(handles));
74     }
75 
76     bool write_error = false;
77     {
78       base::AutoLock lock(write_lock_);
79       if (reject_writes_)
80         return;
81 
82       bool write_now = !delay_writes_ && outgoing_messages_.empty();
83       outgoing_messages_.emplace_back(std::move(message));
84       if (write_now && !WriteNoLock(outgoing_messages_.front()))
85         reject_writes_ = write_error = true;
86     }
87     if (write_error) {
88       // Do not synchronously invoke OnWriteError(). Write() may have been
89       // called by the delegate and we don't want to re-enter it.
90       io_task_runner_->PostTask(FROM_HERE,
91                                 base::BindOnce(&ChannelWin::OnWriteError, this,
92                                                Error::kDisconnected));
93     }
94   }
95 
LeakHandle()96   void LeakHandle() override {
97     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
98     leak_handle_ = true;
99   }
100 
GetReadPlatformHandles(const void * payload,size_t payload_size,size_t num_handles,const void * extra_header,size_t extra_header_size,std::vector<PlatformHandle> * handles,bool * deferred)101   bool GetReadPlatformHandles(const void* payload,
102                               size_t payload_size,
103                               size_t num_handles,
104                               const void* extra_header,
105                               size_t extra_header_size,
106                               std::vector<PlatformHandle>* handles,
107                               bool* deferred) override {
108     DCHECK(extra_header);
109     if (num_handles > std::numeric_limits<uint16_t>::max())
110       return false;
111     using HandleEntry = Channel::Message::HandleEntry;
112     size_t handles_size = sizeof(HandleEntry) * num_handles;
113     if (handles_size > extra_header_size)
114       return false;
115     handles->reserve(num_handles);
116     const HandleEntry* extra_header_handles =
117         reinterpret_cast<const HandleEntry*>(extra_header);
118     for (size_t i = 0; i < num_handles; i++) {
119       HANDLE handle_value =
120           base::win::Uint32ToHandle(extra_header_handles[i].handle);
121       if (remote_process().is_valid()) {
122         // If we know the remote process's handle, we assume it doesn't know
123         // ours; that means any handle values still belong to that process, and
124         // we need to transfer them to this process.
125         handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
126                            handle_value, remote_process().get())
127                            .ReleaseHandle();
128       }
129       handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
130     }
131     return true;
132   }
133 
134  private:
135   // May run on any thread.
~ChannelWin()136   ~ChannelWin() override {}
137 
StartOnIOThread()138   void StartOnIOThread() {
139     base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
140     base::MessageLoopCurrentForIO::Get()->RegisterIOHandler(handle_.Get(),
141                                                             this);
142 
143     if (needs_connection_) {
144       BOOL ok = ::ConnectNamedPipe(handle_.Get(), &connect_context_.overlapped);
145       if (ok) {
146         PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
147         OnError(Error::kConnectionFailed);
148         return;
149       }
150 
151       const DWORD err = GetLastError();
152       switch (err) {
153         case ERROR_PIPE_CONNECTED:
154           break;
155         case ERROR_IO_PENDING:
156           is_connect_pending_ = true;
157           AddRef();
158           return;
159         case ERROR_NO_DATA:
160         default:
161           OnError(Error::kConnectionFailed);
162           return;
163       }
164     }
165 
166     // Now that we have registered our IOHandler, we can start writing.
167     {
168       base::AutoLock lock(write_lock_);
169       if (delay_writes_) {
170         delay_writes_ = false;
171         WriteNextNoLock();
172       }
173     }
174 
175     // Keep this alive in case we synchronously run shutdown, via OnError(),
176     // as a result of a ReadFile() failure on the channel.
177     scoped_refptr<ChannelWin> keep_alive(this);
178     ReadMore(0);
179   }
180 
ShutDownOnIOThread()181   void ShutDownOnIOThread() {
182     base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
183 
184     // TODO(https://crbug.com/583525): This function is expected to be called
185     // once, and |handle_| should be valid at this point.
186     CHECK(handle_.IsValid());
187     CancelIo(handle_.Get());
188     if (leak_handle_)
189       ignore_result(handle_.Take());
190     else
191       handle_.Close();
192 
193     // Allow |this| to be destroyed as soon as no IO is pending.
194     self_ = nullptr;
195   }
196 
197   // base::MessageLoopCurrent::DestructionObserver:
WillDestroyCurrentMessageLoop()198   void WillDestroyCurrentMessageLoop() override {
199     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
200     if (self_)
201       ShutDownOnIOThread();
202   }
203 
204   // base::MessageLoop::IOHandler:
OnIOCompleted(base::MessagePumpForIO::IOContext * context,DWORD bytes_transfered,DWORD error)205   void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
206                      DWORD bytes_transfered,
207                      DWORD error) override {
208     if (error != ERROR_SUCCESS) {
209       if (context == &write_context_) {
210         {
211           base::AutoLock lock(write_lock_);
212           reject_writes_ = true;
213         }
214         OnWriteError(Error::kDisconnected);
215       } else {
216         OnError(Error::kDisconnected);
217       }
218     } else if (context == &connect_context_) {
219       DCHECK(is_connect_pending_);
220       is_connect_pending_ = false;
221       ReadMore(0);
222 
223       base::AutoLock lock(write_lock_);
224       if (delay_writes_) {
225         delay_writes_ = false;
226         WriteNextNoLock();
227       }
228     } else if (context == &read_context_) {
229       OnReadDone(static_cast<size_t>(bytes_transfered));
230     } else {
231       CHECK(context == &write_context_);
232       OnWriteDone(static_cast<size_t>(bytes_transfered));
233     }
234     Release();
235   }
236 
OnReadDone(size_t bytes_read)237   void OnReadDone(size_t bytes_read) {
238     DCHECK(is_read_pending_);
239     is_read_pending_ = false;
240 
241     if (bytes_read > 0) {
242       size_t next_read_size = 0;
243       if (OnReadComplete(bytes_read, &next_read_size)) {
244         ReadMore(next_read_size);
245       } else {
246         OnError(Error::kReceivedMalformedData);
247       }
248     } else if (bytes_read == 0) {
249       OnError(Error::kDisconnected);
250     }
251   }
252 
OnWriteDone(size_t bytes_written)253   void OnWriteDone(size_t bytes_written) {
254     if (bytes_written == 0)
255       return;
256 
257     bool write_error = false;
258     {
259       base::AutoLock lock(write_lock_);
260 
261       DCHECK(is_write_pending_);
262       is_write_pending_ = false;
263       DCHECK(!outgoing_messages_.empty());
264 
265       Channel::MessagePtr message = std::move(outgoing_messages_.front());
266       outgoing_messages_.pop_front();
267 
268       // Invalidate all the scoped handles so we don't attempt to close them.
269       std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
270       for (auto& handle : handles)
271         handle.CompleteTransit();
272 
273       // Overlapped WriteFile() to a pipe should always fully complete.
274       if (message->data_num_bytes() != bytes_written)
275         reject_writes_ = write_error = true;
276       else if (!WriteNextNoLock())
277         reject_writes_ = write_error = true;
278     }
279     if (write_error)
280       OnWriteError(Error::kDisconnected);
281   }
282 
ReadMore(size_t next_read_size_hint)283   void ReadMore(size_t next_read_size_hint) {
284     DCHECK(!is_read_pending_);
285 
286     size_t buffer_capacity = next_read_size_hint;
287     char* buffer = GetReadBuffer(&buffer_capacity);
288     DCHECK_GT(buffer_capacity, 0u);
289 
290     BOOL ok =
291         ::ReadFile(handle_.Get(), buffer, static_cast<DWORD>(buffer_capacity),
292                    NULL, &read_context_.overlapped);
293     if (ok || GetLastError() == ERROR_IO_PENDING) {
294       is_read_pending_ = true;
295       AddRef();
296     } else {
297       OnError(Error::kDisconnected);
298     }
299   }
300 
301   // Attempts to write a message directly to the channel. If the full message
302   // cannot be written, it's queued and a wait is initiated to write the message
303   // ASAP on the I/O thread.
WriteNoLock(const Channel::MessagePtr & message)304   bool WriteNoLock(const Channel::MessagePtr& message) {
305     BOOL ok = WriteFile(handle_.Get(), message->data(),
306                         static_cast<DWORD>(message->data_num_bytes()), NULL,
307                         &write_context_.overlapped);
308     if (ok || GetLastError() == ERROR_IO_PENDING) {
309       is_write_pending_ = true;
310       AddRef();
311       return true;
312     }
313     return false;
314   }
315 
WriteNextNoLock()316   bool WriteNextNoLock() {
317     if (outgoing_messages_.empty())
318       return true;
319     return WriteNoLock(outgoing_messages_.front());
320   }
321 
OnWriteError(Error error)322   void OnWriteError(Error error) {
323     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
324     DCHECK(reject_writes_);
325 
326     if (error == Error::kDisconnected) {
327       // If we can't write because the pipe is disconnected then continue
328       // reading to fetch any in-flight messages, relying on end-of-stream to
329       // signal the actual disconnection.
330       if (is_read_pending_ || is_connect_pending_)
331         return;
332     }
333 
334     OnError(error);
335   }
336 
337   // Keeps the Channel alive at least until explicit shutdown on the IO thread.
338   scoped_refptr<Channel> self_;
339 
340   // The pipe handle this Channel uses for communication.
341   base::win::ScopedHandle handle_;
342 
343   // Indicates whether |handle_| must wait for a connection.
344   bool needs_connection_ = false;
345 
346   const scoped_refptr<base::TaskRunner> io_task_runner_;
347 
348   base::MessagePumpForIO::IOContext connect_context_;
349   base::MessagePumpForIO::IOContext read_context_;
350   bool is_connect_pending_ = false;
351   bool is_read_pending_ = false;
352 
353   // Protects all fields potentially accessed on multiple threads via Write().
354   base::Lock write_lock_;
355   base::MessagePumpForIO::IOContext write_context_;
356   base::circular_deque<Channel::MessagePtr> outgoing_messages_;
357   bool delay_writes_ = true;
358   bool reject_writes_ = false;
359   bool is_write_pending_ = false;
360 
361   bool leak_handle_ = false;
362 
363   DISALLOW_COPY_AND_ASSIGN(ChannelWin);
364 };
365 
366 }  // namespace
367 
368 // static
Create(Delegate * delegate,ConnectionParams connection_params,scoped_refptr<base::TaskRunner> io_task_runner)369 scoped_refptr<Channel> Channel::Create(
370     Delegate* delegate,
371     ConnectionParams connection_params,
372     scoped_refptr<base::TaskRunner> io_task_runner) {
373   return new ChannelWin(delegate, std::move(connection_params), io_task_runner);
374 }
375 
376 }  // namespace core
377 }  // namespace mojo
378