1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/core/broker.h"
6
7 #include <fcntl.h>
8 #include <unistd.h>
9
10 #include <utility>
11 #include <vector>
12
13 #include "base/logging.h"
14 #include "base/memory/platform_shared_memory_region.h"
15 #include "build/build_config.h"
16 #include "mojo/core/broker_messages.h"
17 #include "mojo/core/channel.h"
18 #include "mojo/core/platform_handle_utils.h"
19 #include "mojo/public/cpp/platform/socket_utils_posix.h"
20
21 namespace mojo {
22 namespace core {
23
24 namespace {
25
WaitForBrokerMessage(int socket_fd,BrokerMessageType expected_type,size_t expected_num_handles,size_t expected_data_size,std::vector<PlatformHandle> * incoming_handles)26 Channel::MessagePtr WaitForBrokerMessage(
27 int socket_fd,
28 BrokerMessageType expected_type,
29 size_t expected_num_handles,
30 size_t expected_data_size,
31 std::vector<PlatformHandle>* incoming_handles) {
32 Channel::MessagePtr message(new Channel::Message(
33 sizeof(BrokerMessageHeader) + expected_data_size, expected_num_handles));
34 std::vector<base::ScopedFD> incoming_fds;
35 ssize_t read_result =
36 SocketRecvmsg(socket_fd, const_cast<void*>(message->data()),
37 message->data_num_bytes(), &incoming_fds, true /* block */);
38 bool error = false;
39 if (read_result < 0) {
40 PLOG(ERROR) << "Recvmsg error";
41 error = true;
42 } else if (static_cast<size_t>(read_result) != message->data_num_bytes()) {
43 LOG(ERROR) << "Invalid node channel message";
44 error = true;
45 } else if (incoming_fds.size() != expected_num_handles) {
46 LOG(ERROR) << "Received unexpected number of handles";
47 error = true;
48 }
49
50 if (error)
51 return nullptr;
52
53 const BrokerMessageHeader* header =
54 reinterpret_cast<const BrokerMessageHeader*>(message->payload());
55 if (header->type != expected_type) {
56 LOG(ERROR) << "Unexpected message";
57 return nullptr;
58 }
59
60 incoming_handles->reserve(incoming_fds.size());
61 for (size_t i = 0; i < incoming_fds.size(); ++i)
62 incoming_handles->emplace_back(std::move(incoming_fds[i]));
63
64 return message;
65 }
66
67 } // namespace
68
Broker(PlatformHandle handle)69 Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) {
70 CHECK(sync_channel_.is_valid());
71
72 int fd = sync_channel_.GetFD().get();
73 // Mark the channel as blocking.
74 int flags = fcntl(fd, F_GETFL);
75 PCHECK(flags != -1);
76 flags = fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
77 PCHECK(flags != -1);
78
79 // Wait for the first message, which should contain a handle.
80 std::vector<PlatformHandle> incoming_platform_handles;
81 if (WaitForBrokerMessage(fd, BrokerMessageType::INIT, 1, 0,
82 &incoming_platform_handles)) {
83 inviter_endpoint_ =
84 PlatformChannelEndpoint(std::move(incoming_platform_handles[0]));
85 }
86 }
87
88 Broker::~Broker() = default;
89
GetInviterEndpoint()90 PlatformChannelEndpoint Broker::GetInviterEndpoint() {
91 return std::move(inviter_endpoint_);
92 }
93
GetWritableSharedMemoryRegion(size_t num_bytes)94 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
95 size_t num_bytes) {
96 base::AutoLock lock(lock_);
97
98 BufferRequestData* buffer_request;
99 Channel::MessagePtr out_message = CreateBrokerMessage(
100 BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
101 buffer_request->size = num_bytes;
102 ssize_t write_result =
103 SocketWrite(sync_channel_.GetFD().get(), out_message->data(),
104 out_message->data_num_bytes());
105 if (write_result < 0) {
106 PLOG(ERROR) << "Error sending sync broker message";
107 return base::WritableSharedMemoryRegion();
108 } else if (static_cast<size_t>(write_result) !=
109 out_message->data_num_bytes()) {
110 LOG(ERROR) << "Error sending complete broker message";
111 return base::WritableSharedMemoryRegion();
112 }
113
114 #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \
115 (defined(OS_MACOSX) && !defined(OS_IOS))
116 // Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use
117 // a single handle to represent a writable region.
118 constexpr size_t kNumExpectedHandles = 1;
119 #else
120 constexpr size_t kNumExpectedHandles = 2;
121 #endif
122
123 std::vector<PlatformHandle> handles;
124 Channel::MessagePtr message = WaitForBrokerMessage(
125 sync_channel_.GetFD().get(), BrokerMessageType::BUFFER_RESPONSE,
126 kNumExpectedHandles, sizeof(BufferResponseData), &handles);
127 if (message) {
128 const BufferResponseData* data;
129 if (!GetBrokerMessageData(message.get(), &data))
130 return base::WritableSharedMemoryRegion();
131
132 if (handles.size() == 1)
133 handles.emplace_back();
134 return base::WritableSharedMemoryRegion::Deserialize(
135 base::subtle::PlatformSharedMemoryRegion::Take(
136 CreateSharedMemoryRegionHandleFromPlatformHandles(
137 std::move(handles[0]), std::move(handles[1])),
138 base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
139 num_bytes,
140 base::UnguessableToken::Deserialize(data->guid_high,
141 data->guid_low)));
142 }
143
144 return base::WritableSharedMemoryRegion();
145 }
146
147 } // namespace core
148 } // namespace mojo
149