1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <windows.h>
6
7 #include <limits>
8 #include <utility>
9
10 #include "base/debug/alias.h"
11 #include "base/memory/platform_shared_memory_region.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/strings/string_piece.h"
14 #include "mojo/core/broker.h"
15 #include "mojo/core/broker_messages.h"
16 #include "mojo/core/channel.h"
17 #include "mojo/core/platform_handle_utils.h"
18 #include "mojo/public/cpp/platform/named_platform_channel.h"
19
20 namespace mojo {
21 namespace core {
22
23 namespace {
24
// Upper bound, in bytes, on a single message read from the synchronous broker
// pipe. WaitForBrokerMessage() reads into a stack buffer of exactly this size.
constexpr size_t kMaxBrokerMessageSize = 256;
27
TakeHandlesFromBrokerMessage(Channel::Message * message,size_t num_handles,PlatformHandle * out_handles)28 bool TakeHandlesFromBrokerMessage(Channel::Message* message,
29 size_t num_handles,
30 PlatformHandle* out_handles) {
31 if (message->num_handles() != num_handles) {
32 DLOG(ERROR) << "Received unexpected number of handles in broker message";
33 return false;
34 }
35
36 std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
37 DCHECK_EQ(handles.size(), num_handles);
38 DCHECK(out_handles);
39
40 for (size_t i = 0; i < num_handles; ++i)
41 out_handles[i] = handles[i].TakeHandle();
42 return true;
43 }
44
WaitForBrokerMessage(HANDLE pipe_handle,BrokerMessageType expected_type)45 Channel::MessagePtr WaitForBrokerMessage(HANDLE pipe_handle,
46 BrokerMessageType expected_type) {
47 char buffer[kMaxBrokerMessageSize];
48 DWORD bytes_read = 0;
49 BOOL result = ::ReadFile(pipe_handle, buffer, kMaxBrokerMessageSize,
50 &bytes_read, nullptr);
51 if (!result) {
52 // The pipe may be broken if the browser side has been closed, e.g. during
53 // browser shutdown. In that case the ReadFile call will fail and we
54 // shouldn't continue waiting.
55 PLOG(ERROR) << "Error reading broker pipe";
56 return nullptr;
57 }
58
59 Channel::MessagePtr message =
60 Channel::Message::Deserialize(buffer, static_cast<size_t>(bytes_read));
61 if (!message || message->payload_size() < sizeof(BrokerMessageHeader)) {
62 LOG(ERROR) << "Invalid broker message";
63
64 base::debug::Alias(&buffer[0]);
65 base::debug::Alias(&bytes_read);
66 CHECK(false);
67 return nullptr;
68 }
69
70 const BrokerMessageHeader* header =
71 reinterpret_cast<const BrokerMessageHeader*>(message->payload());
72 if (header->type != expected_type) {
73 LOG(ERROR) << "Unexpected broker message type";
74
75 base::debug::Alias(&buffer[0]);
76 base::debug::Alias(&bytes_read);
77 CHECK(false);
78 return nullptr;
79 }
80
81 return message;
82 }
83
84 } // namespace
85
Broker(PlatformHandle handle)86 Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) {
87 CHECK(sync_channel_.is_valid());
88 Channel::MessagePtr message = WaitForBrokerMessage(
89 sync_channel_.GetHandle().Get(), BrokerMessageType::INIT);
90
91 // If we fail to read a message (broken pipe), just return early. The inviter
92 // handle will be null and callers must handle this gracefully.
93 if (!message)
94 return;
95
96 PlatformHandle endpoint_handle;
97 if (TakeHandlesFromBrokerMessage(message.get(), 1, &endpoint_handle)) {
98 inviter_endpoint_ = PlatformChannelEndpoint(std::move(endpoint_handle));
99 } else {
100 // If the message has no handles, we expect it to carry pipe name instead.
101 const BrokerMessageHeader* header =
102 static_cast<const BrokerMessageHeader*>(message->payload());
103 CHECK_GE(message->payload_size(),
104 sizeof(BrokerMessageHeader) + sizeof(InitData));
105 const InitData* data = reinterpret_cast<const InitData*>(header + 1);
106 CHECK_EQ(message->payload_size(),
107 sizeof(BrokerMessageHeader) + sizeof(InitData) +
108 data->pipe_name_length * sizeof(base::char16));
109 const base::char16* name_data =
110 reinterpret_cast<const base::char16*>(data + 1);
111 CHECK(data->pipe_name_length);
112 inviter_endpoint_ = NamedPlatformChannel::ConnectToServer(
113 base::StringPiece16(name_data, data->pipe_name_length).as_string());
114 }
115 }
116
~Broker()117 Broker::~Broker() {}
118
GetInviterEndpoint()119 PlatformChannelEndpoint Broker::GetInviterEndpoint() {
120 return std::move(inviter_endpoint_);
121 }
122
GetWritableSharedMemoryRegion(size_t num_bytes)123 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
124 size_t num_bytes) {
125 base::AutoLock lock(lock_);
126 BufferRequestData* buffer_request;
127 Channel::MessagePtr out_message = CreateBrokerMessage(
128 BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
129 buffer_request->size = base::checked_cast<uint32_t>(num_bytes);
130 DWORD bytes_written = 0;
131 BOOL result =
132 ::WriteFile(sync_channel_.GetHandle().Get(), out_message->data(),
133 static_cast<DWORD>(out_message->data_num_bytes()),
134 &bytes_written, nullptr);
135 if (!result ||
136 static_cast<size_t>(bytes_written) != out_message->data_num_bytes()) {
137 PLOG(ERROR) << "Error sending sync broker message";
138 return base::WritableSharedMemoryRegion();
139 }
140
141 PlatformHandle handle;
142 Channel::MessagePtr response = WaitForBrokerMessage(
143 sync_channel_.GetHandle().Get(), BrokerMessageType::BUFFER_RESPONSE);
144 if (response && TakeHandlesFromBrokerMessage(response.get(), 1, &handle)) {
145 BufferResponseData* data;
146 if (!GetBrokerMessageData(response.get(), &data))
147 return base::WritableSharedMemoryRegion();
148 return base::WritableSharedMemoryRegion::Deserialize(
149 base::subtle::PlatformSharedMemoryRegion::Take(
150 CreateSharedMemoryRegionHandleFromPlatformHandles(std::move(handle),
151 PlatformHandle()),
152 base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
153 num_bytes,
154 base::UnguessableToken::Deserialize(data->guid_high,
155 data->guid_low)));
156 }
157
158 return base::WritableSharedMemoryRegion();
159 }
160
161 } // namespace core
162 } // namespace mojo
163