1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/platform/named_platform_channel.h"
6
7 #include <errno.h>
8 #include <sys/socket.h>
9 #include <sys/un.h>
10 #include <unistd.h>
11
12 #include "base/files/file_util.h"
13 #include "base/files/scoped_file.h"
14 #include "base/logging.h"
15 #include "base/posix/eintr_wrapper.h"
16 #include "base/rand_util.h"
17 #include "base/strings/string_number_conversions.h"
18
19 namespace mojo {
20
21 namespace {
22
// Produces a server name by appending a random 64-bit number, rendered in
// decimal, as a leaf component of |options.socket_dir|.
NamedPlatformChannel::ServerName GenerateRandomServerName(
    const NamedPlatformChannel::Options& options) {
  const std::string random_leaf = base::NumberToString(base::RandUint64());
  const base::FilePath random_path = options.socket_dir.AppendASCII(random_leaf);
  return random_path.value();
}
29
30 // This function fills in |unix_addr| with the appropriate data for the socket,
31 // and sets |unix_addr_len| to the length of the data therein.
32 // Returns true on success, or false on failure (typically because |server_name|
33 // violated the naming rules).
MakeUnixAddr(const NamedPlatformChannel::ServerName & server_name,struct sockaddr_un * unix_addr,size_t * unix_addr_len)34 bool MakeUnixAddr(const NamedPlatformChannel::ServerName& server_name,
35 struct sockaddr_un* unix_addr,
36 size_t* unix_addr_len) {
37 DCHECK(unix_addr);
38 DCHECK(unix_addr_len);
39 DCHECK(!server_name.empty());
40
41 constexpr size_t kMaxSocketNameLength = 104;
42
43 // We reject server_name.length() == kMaxSocketNameLength to make room for the
44 // NUL terminator at the end of the string.
45 if (server_name.length() >= kMaxSocketNameLength) {
46 LOG(ERROR) << "Socket name too long: " << server_name;
47 return false;
48 }
49
50 // Create unix_addr structure.
51 memset(unix_addr, 0, sizeof(struct sockaddr_un));
52 unix_addr->sun_family = AF_UNIX;
53 strncpy(unix_addr->sun_path, server_name.c_str(), kMaxSocketNameLength);
54 *unix_addr_len =
55 offsetof(struct sockaddr_un, sun_path) + server_name.length();
56 return true;
57 }
58
59 // This function creates a unix domain socket, and set it as non-blocking.
60 // If successful, this returns a PlatformHandle containing the socket.
61 // Otherwise, this returns an invalid PlatformHandle.
CreateUnixDomainSocket()62 PlatformHandle CreateUnixDomainSocket() {
63 // Create the unix domain socket.
64 PlatformHandle handle(base::ScopedFD(socket(AF_UNIX, SOCK_STREAM, 0)));
65 if (!handle.is_valid()) {
66 PLOG(ERROR) << "Failed to create AF_UNIX socket.";
67 return PlatformHandle();
68 }
69
70 // Now set it as non-blocking.
71 if (!base::SetNonBlocking(handle.GetFD().get())) {
72 PLOG(ERROR) << "base::SetNonBlocking() failed " << handle.GetFD().get();
73 return PlatformHandle();
74 }
75 return handle;
76 }
77
78 } // namespace
79
80 // static
CreateServerEndpoint(const Options & options,ServerName * server_name)81 PlatformChannelServerEndpoint NamedPlatformChannel::CreateServerEndpoint(
82 const Options& options,
83 ServerName* server_name) {
84 ServerName name = options.server_name;
85 if (name.empty())
86 name = GenerateRandomServerName(options);
87
88 // Make sure the path we need exists.
89 base::FilePath socket_dir = base::FilePath(name).DirName();
90 if (!base::CreateDirectory(socket_dir)) {
91 LOG(ERROR) << "Couldn't create directory: " << socket_dir.value();
92 return PlatformChannelServerEndpoint();
93 }
94
95 // Delete any old FS instances.
96 if (unlink(name.c_str()) < 0 && errno != ENOENT) {
97 PLOG(ERROR) << "unlink " << name;
98 return PlatformChannelServerEndpoint();
99 }
100
101 struct sockaddr_un unix_addr;
102 size_t unix_addr_len;
103 if (!MakeUnixAddr(name, &unix_addr, &unix_addr_len))
104 return PlatformChannelServerEndpoint();
105
106 PlatformHandle handle = CreateUnixDomainSocket();
107 if (!handle.is_valid())
108 return PlatformChannelServerEndpoint();
109
110 // Bind the socket.
111 if (bind(handle.GetFD().get(), reinterpret_cast<const sockaddr*>(&unix_addr),
112 unix_addr_len) < 0) {
113 PLOG(ERROR) << "bind " << name;
114 return PlatformChannelServerEndpoint();
115 }
116
117 // Start listening on the socket.
118 if (listen(handle.GetFD().get(), SOMAXCONN) < 0) {
119 PLOG(ERROR) << "listen " << name;
120 unlink(name.c_str());
121 return PlatformChannelServerEndpoint();
122 }
123
124 *server_name = name;
125 return PlatformChannelServerEndpoint(std::move(handle));
126 }
127
128 // static
CreateClientEndpoint(const ServerName & server_name)129 PlatformChannelEndpoint NamedPlatformChannel::CreateClientEndpoint(
130 const ServerName& server_name) {
131 DCHECK(!server_name.empty());
132
133 struct sockaddr_un unix_addr;
134 size_t unix_addr_len;
135 if (!MakeUnixAddr(server_name, &unix_addr, &unix_addr_len))
136 return PlatformChannelEndpoint();
137
138 PlatformHandle handle = CreateUnixDomainSocket();
139 if (!handle.is_valid())
140 return PlatformChannelEndpoint();
141
142 if (HANDLE_EINTR(connect(handle.GetFD().get(),
143 reinterpret_cast<sockaddr*>(&unix_addr),
144 unix_addr_len)) < 0) {
145 PLOG(ERROR) << "connect " << server_name;
146 return PlatformChannelEndpoint();
147 }
148 return PlatformChannelEndpoint(std::move(handle));
149 }
150
151 } // namespace mojo
152