1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/ipc/host_impl.h"
18
19 #include <algorithm>
20 #include <cinttypes>
21 #include <utility>
22
23 #include "perfetto/base/build_config.h"
24 #include "perfetto/base/task_runner.h"
25 #include "perfetto/base/time.h"
26 #include "perfetto/ext/base/crash_keys.h"
27 #include "perfetto/ext/base/utils.h"
28 #include "perfetto/ext/ipc/service.h"
29 #include "perfetto/ext/ipc/service_descriptor.h"
30
31 #include "protos/perfetto/ipc/wire_protocol.gen.h"
32
33 // TODO(primiano): put limits on #connections/uid and req. queue (b/69093705).
34
35 namespace perfetto {
36 namespace ipc {
37
38 namespace {
39
40 constexpr base::SockFamily kHostSockFamily =
41 kUseTCPSocket ? base::SockFamily::kInet : base::SockFamily::kUnix;
42
43 base::CrashKey g_crash_key_uid("ipc_uid");
44
GetPosixPeerUid(base::UnixSocket * sock)45 uid_t GetPosixPeerUid(base::UnixSocket* sock) {
46 #if PERFETTO_BUILDFLAG(PERFETTO_OS_WIN)
47 base::ignore_result(sock);
48 // Unsupported. Must be != kInvalidUid or the PacketValidator will fail.
49 return 0;
50 #else
51 return sock->peer_uid_posix();
52 #endif
53 }
54
GetLinuxPeerPid(base::UnixSocket * sock)55 pid_t GetLinuxPeerPid(base::UnixSocket* sock) {
56 #if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
57 PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
58 return sock->peer_pid_linux();
59 #else
60 base::ignore_result(sock);
61 return base::kInvalidPid; // Unsupported.
62 #endif
63 }
64
65 } // namespace
66
67 // static
CreateInstance(const char * socket_name,base::TaskRunner * task_runner)68 std::unique_ptr<Host> Host::CreateInstance(const char* socket_name,
69 base::TaskRunner* task_runner) {
70 std::unique_ptr<HostImpl> host(new HostImpl(socket_name, task_runner));
71 if (!host->sock() || !host->sock()->is_listening())
72 return nullptr;
73 return std::unique_ptr<Host>(std::move(host));
74 }
75
76 // static
CreateInstance(base::ScopedSocketHandle socket_fd,base::TaskRunner * task_runner)77 std::unique_ptr<Host> Host::CreateInstance(base::ScopedSocketHandle socket_fd,
78 base::TaskRunner* task_runner) {
79 std::unique_ptr<HostImpl> host(
80 new HostImpl(std::move(socket_fd), task_runner));
81 if (!host->sock() || !host->sock()->is_listening())
82 return nullptr;
83 return std::unique_ptr<Host>(std::move(host));
84 }
85
HostImpl(base::ScopedSocketHandle socket_fd,base::TaskRunner * task_runner)86 HostImpl::HostImpl(base::ScopedSocketHandle socket_fd,
87 base::TaskRunner* task_runner)
88 : task_runner_(task_runner), weak_ptr_factory_(this) {
89 PERFETTO_DCHECK_THREAD(thread_checker_);
90 sock_ = base::UnixSocket::Listen(std::move(socket_fd), this, task_runner_,
91 kHostSockFamily, base::SockType::kStream);
92 }
93
HostImpl(const char * socket_name,base::TaskRunner * task_runner)94 HostImpl::HostImpl(const char* socket_name, base::TaskRunner* task_runner)
95 : task_runner_(task_runner), weak_ptr_factory_(this) {
96 PERFETTO_DCHECK_THREAD(thread_checker_);
97 sock_ = base::UnixSocket::Listen(socket_name, this, task_runner_,
98 kHostSockFamily, base::SockType::kStream);
99 if (!sock_) {
100 PERFETTO_PLOG("Failed to create %s", socket_name);
101 }
102 }
103
104 HostImpl::~HostImpl() = default;
105
ExposeService(std::unique_ptr<Service> service)106 bool HostImpl::ExposeService(std::unique_ptr<Service> service) {
107 PERFETTO_DCHECK_THREAD(thread_checker_);
108 const std::string& service_name = service->GetDescriptor().service_name;
109 if (GetServiceByName(service_name)) {
110 PERFETTO_DLOG("Duplicate ExposeService(): %s", service_name.c_str());
111 return false;
112 }
113 ServiceID sid = ++last_service_id_;
114 ExposedService exposed_service(sid, service_name, std::move(service));
115 services_.emplace(sid, std::move(exposed_service));
116 return true;
117 }
118
SetSocketSendTimeoutMs(uint32_t timeout_ms)119 void HostImpl::SetSocketSendTimeoutMs(uint32_t timeout_ms) {
120 PERFETTO_DCHECK_THREAD(thread_checker_);
121 // Should be less than the watchdog period (30s).
122 socket_tx_timeout_ms_ = timeout_ms;
123 }
124
OnNewIncomingConnection(base::UnixSocket *,std::unique_ptr<base::UnixSocket> new_conn)125 void HostImpl::OnNewIncomingConnection(
126 base::UnixSocket*,
127 std::unique_ptr<base::UnixSocket> new_conn) {
128 PERFETTO_DCHECK_THREAD(thread_checker_);
129 std::unique_ptr<ClientConnection> client(new ClientConnection());
130 ClientID client_id = ++last_client_id_;
131 clients_by_socket_[new_conn.get()] = client.get();
132 client->id = client_id;
133 client->sock = std::move(new_conn);
134 client->sock->SetTxTimeout(socket_tx_timeout_ms_);
135 clients_[client_id] = std::move(client);
136 }
137
OnDataAvailable(base::UnixSocket * sock)138 void HostImpl::OnDataAvailable(base::UnixSocket* sock) {
139 PERFETTO_DCHECK_THREAD(thread_checker_);
140 auto it = clients_by_socket_.find(sock);
141 if (it == clients_by_socket_.end())
142 return;
143 ClientConnection* client = it->second;
144 BufferedFrameDeserializer& frame_deserializer = client->frame_deserializer;
145
146 auto peer_uid = GetPosixPeerUid(client->sock.get());
147 auto scoped_key = g_crash_key_uid.SetScoped(static_cast<int64_t>(peer_uid));
148
149 size_t rsize;
150 do {
151 auto buf = frame_deserializer.BeginReceive();
152 base::ScopedFile fd;
153 rsize = client->sock->Receive(buf.data, buf.size, &fd);
154 if (fd) {
155 PERFETTO_DCHECK(!client->received_fd);
156 client->received_fd = std::move(fd);
157 }
158 if (!frame_deserializer.EndReceive(rsize))
159 return OnDisconnect(client->sock.get());
160 } while (rsize > 0);
161
162 for (;;) {
163 std::unique_ptr<Frame> frame = frame_deserializer.PopNextFrame();
164 if (!frame)
165 break;
166 OnReceivedFrame(client, *frame);
167 }
168 }
169
OnReceivedFrame(ClientConnection * client,const Frame & req_frame)170 void HostImpl::OnReceivedFrame(ClientConnection* client,
171 const Frame& req_frame) {
172 if (req_frame.has_msg_bind_service())
173 return OnBindService(client, req_frame);
174 if (req_frame.has_msg_invoke_method())
175 return OnInvokeMethod(client, req_frame);
176
177 PERFETTO_DLOG("Received invalid RPC frame from client %" PRIu64, client->id);
178 Frame reply_frame;
179 reply_frame.set_request_id(req_frame.request_id());
180 reply_frame.mutable_msg_request_error()->set_error("unknown request");
181 SendFrame(client, reply_frame);
182 }
183
OnBindService(ClientConnection * client,const Frame & req_frame)184 void HostImpl::OnBindService(ClientConnection* client, const Frame& req_frame) {
185 // Binding a service doesn't do anything major. It just returns back the
186 // service id and its method map.
187 const Frame::BindService& req = req_frame.msg_bind_service();
188 Frame reply_frame;
189 reply_frame.set_request_id(req_frame.request_id());
190 auto* reply = reply_frame.mutable_msg_bind_service_reply();
191 const ExposedService* service = GetServiceByName(req.service_name());
192 if (service) {
193 reply->set_success(true);
194 reply->set_service_id(service->id);
195 uint32_t method_id = 1; // method ids start at index 1.
196 for (const auto& desc_method : service->instance->GetDescriptor().methods) {
197 Frame::BindServiceReply::MethodInfo* method_info = reply->add_methods();
198 method_info->set_name(desc_method.name);
199 method_info->set_id(method_id++);
200 }
201 }
202 SendFrame(client, reply_frame);
203 }
204
OnInvokeMethod(ClientConnection * client,const Frame & req_frame)205 void HostImpl::OnInvokeMethod(ClientConnection* client,
206 const Frame& req_frame) {
207 const Frame::InvokeMethod& req = req_frame.msg_invoke_method();
208 Frame reply_frame;
209 RequestID request_id = req_frame.request_id();
210 reply_frame.set_request_id(request_id);
211 reply_frame.mutable_msg_invoke_method_reply()->set_success(false);
212 auto svc_it = services_.find(req.service_id());
213 if (svc_it == services_.end())
214 return SendFrame(client, reply_frame); // |success| == false by default.
215
216 Service* service = svc_it->second.instance.get();
217 const ServiceDescriptor& svc = service->GetDescriptor();
218 const auto& methods = svc.methods;
219 const uint32_t method_id = req.method_id();
220 if (method_id == 0 || method_id > methods.size())
221 return SendFrame(client, reply_frame);
222
223 const ServiceDescriptor::Method& method = methods[method_id - 1];
224 std::unique_ptr<ProtoMessage> decoded_req_args(
225 method.request_proto_decoder(req.args_proto()));
226 if (!decoded_req_args)
227 return SendFrame(client, reply_frame);
228
229 Deferred<ProtoMessage> deferred_reply;
230 base::WeakPtr<HostImpl> host_weak_ptr = weak_ptr_factory_.GetWeakPtr();
231 ClientID client_id = client->id;
232
233 if (!req.drop_reply()) {
234 deferred_reply.Bind([host_weak_ptr, client_id,
235 request_id](AsyncResult<ProtoMessage> reply) {
236 if (!host_weak_ptr)
237 return; // The reply came too late, the HostImpl has gone.
238 host_weak_ptr->ReplyToMethodInvocation(client_id, request_id,
239 std::move(reply));
240 });
241 }
242
243 auto peer_uid = GetPosixPeerUid(client->sock.get());
244 auto scoped_key = g_crash_key_uid.SetScoped(static_cast<int64_t>(peer_uid));
245 service->client_info_ =
246 ClientInfo(client->id, peer_uid, GetLinuxPeerPid(client->sock.get()));
247 service->received_fd_ = &client->received_fd;
248 method.invoker(service, *decoded_req_args, std::move(deferred_reply));
249 service->received_fd_ = nullptr;
250 service->client_info_ = ClientInfo();
251 }
252
ReplyToMethodInvocation(ClientID client_id,RequestID request_id,AsyncResult<ProtoMessage> reply)253 void HostImpl::ReplyToMethodInvocation(ClientID client_id,
254 RequestID request_id,
255 AsyncResult<ProtoMessage> reply) {
256 auto client_iter = clients_.find(client_id);
257 if (client_iter == clients_.end())
258 return; // client has disconnected by the time we got the async reply.
259
260 ClientConnection* client = client_iter->second.get();
261 Frame reply_frame;
262 reply_frame.set_request_id(request_id);
263
264 // TODO(fmayer): add a test to guarantee that the reply is consumed within the
265 // same call stack and not kept around. ConsumerIPCService::OnTraceData()
266 // relies on this behavior.
267 auto* reply_frame_data = reply_frame.mutable_msg_invoke_method_reply();
268 reply_frame_data->set_has_more(reply.has_more());
269 if (reply.success()) {
270 std::string reply_proto = reply->SerializeAsString();
271 reply_frame_data->set_reply_proto(reply_proto);
272 reply_frame_data->set_success(true);
273 }
274 SendFrame(client, reply_frame, reply.fd());
275 }
276
277 // static
SendFrame(ClientConnection * client,const Frame & frame,int fd)278 void HostImpl::SendFrame(ClientConnection* client, const Frame& frame, int fd) {
279 auto peer_uid = GetPosixPeerUid(client->sock.get());
280 auto scoped_key = g_crash_key_uid.SetScoped(static_cast<int64_t>(peer_uid));
281
282 std::string buf = BufferedFrameDeserializer::Serialize(frame);
283
284 // When a new Client connects in OnNewClientConnection we set a timeout on
285 // Send (see call to SetTxTimeout).
286 //
287 // The old behaviour was to do a blocking I/O call, which caused crashes from
288 // misbehaving producers (see b/169051440).
289 bool res = client->sock->Send(buf.data(), buf.size(), fd);
290 // If we timeout |res| will be false, but the UnixSocket will have called
291 // UnixSocket::ShutDown() and thus |is_connected()| is false.
292 PERFETTO_CHECK(res || !client->sock->is_connected());
293 }
294
OnDisconnect(base::UnixSocket * sock)295 void HostImpl::OnDisconnect(base::UnixSocket* sock) {
296 PERFETTO_DCHECK_THREAD(thread_checker_);
297 auto it = clients_by_socket_.find(sock);
298 if (it == clients_by_socket_.end())
299 return;
300 ClientID client_id = it->second->id;
301
302 ClientInfo client_info(client_id, GetPosixPeerUid(sock),
303 GetLinuxPeerPid(sock));
304 clients_by_socket_.erase(it);
305 PERFETTO_DCHECK(clients_.count(client_id));
306 clients_.erase(client_id);
307
308 for (const auto& service_it : services_) {
309 Service& service = *service_it.second.instance;
310 service.client_info_ = client_info;
311 service.OnClientDisconnected();
312 service.client_info_ = ClientInfo();
313 }
314 }
315
GetServiceByName(const std::string & name)316 const HostImpl::ExposedService* HostImpl::GetServiceByName(
317 const std::string& name) {
318 // This could be optimized by using another map<name,ServiceID>. However this
319 // is used only by Bind/ExposeService that are quite rare (once per client
320 // connection and once per service instance), not worth it.
321 for (const auto& it : services_) {
322 if (it.second.name == name)
323 return &it.second;
324 }
325 return nullptr;
326 }
327
ExposedService(ServiceID id_,const std::string & name_,std::unique_ptr<Service> instance_)328 HostImpl::ExposedService::ExposedService(ServiceID id_,
329 const std::string& name_,
330 std::unique_ptr<Service> instance_)
331 : id(id_), name(name_), instance(std::move(instance_)) {}
332
333 HostImpl::ExposedService::ExposedService(ExposedService&&) noexcept = default;
334 HostImpl::ExposedService& HostImpl::ExposedService::operator=(
335 HostImpl::ExposedService&&) = default;
336 HostImpl::ExposedService::~ExposedService() = default;
337
338 HostImpl::ClientConnection::~ClientConnection() = default;
339
340 } // namespace ipc
341 } // namespace perfetto
342